cdxml-toolkit 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. cdxml_toolkit/__init__.py +18 -0
  2. cdxml_toolkit/_jre/__init__.py +2 -0
  3. cdxml_toolkit/_jre/temurin-21-jre-win-x64.zip +0 -0
  4. cdxml_toolkit/analysis/__init__.py +35 -0
  5. cdxml_toolkit/analysis/deterministic/__init__.py +12 -0
  6. cdxml_toolkit/analysis/deterministic/discover_experiment_files.py +413 -0
  7. cdxml_toolkit/analysis/deterministic/lab_book_formatter.py +701 -0
  8. cdxml_toolkit/analysis/deterministic/lcms_file_categorizer.py +928 -0
  9. cdxml_toolkit/analysis/deterministic/lcms_identifier.py +598 -0
  10. cdxml_toolkit/analysis/deterministic/mass_resolver.py +654 -0
  11. cdxml_toolkit/analysis/deterministic/multi_lcms_analyzer.py +1412 -0
  12. cdxml_toolkit/analysis/deterministic/procedure_writer.py +446 -0
  13. cdxml_toolkit/analysis/extract_nmr.py +47 -0
  14. cdxml_toolkit/analysis/format_procedure_entry.py +479 -0
  15. cdxml_toolkit/analysis/lcms_analyzer.py +1299 -0
  16. cdxml_toolkit/analysis/parse_analysis_file.py +134 -0
  17. cdxml_toolkit/cdxml_builder.py +920 -0
  18. cdxml_toolkit/cdxml_utils.py +342 -0
  19. cdxml_toolkit/chemdraw/__init__.py +5 -0
  20. cdxml_toolkit/chemdraw/_chemscript_server.py +562 -0
  21. cdxml_toolkit/chemdraw/cdx_converter.py +527 -0
  22. cdxml_toolkit/chemdraw/cdxml_to_image.py +262 -0
  23. cdxml_toolkit/chemdraw/cdxml_to_image_rdkit.py +296 -0
  24. cdxml_toolkit/chemdraw/chemscript_bridge.py +901 -0
  25. cdxml_toolkit/constants.py +304 -0
  26. cdxml_toolkit/coord_normalizer.py +438 -0
  27. cdxml_toolkit/deterministic_pipeline/__init__.py +6 -0
  28. cdxml_toolkit/deterministic_pipeline/legacy/__init__.py +5 -0
  29. cdxml_toolkit/deterministic_pipeline/legacy/eln_cdx_cleanup.py +509 -0
  30. cdxml_toolkit/deterministic_pipeline/legacy/eln_enrichment.py +1394 -0
  31. cdxml_toolkit/deterministic_pipeline/legacy/scheme_aligner.py +428 -0
  32. cdxml_toolkit/deterministic_pipeline/legacy/scheme_polisher.py +1337 -0
  33. cdxml_toolkit/deterministic_pipeline/legacy/scheme_polisher_v2.py +1340 -0
  34. cdxml_toolkit/deterministic_pipeline/scheme_reader_audit.py +931 -0
  35. cdxml_toolkit/deterministic_pipeline/scheme_reader_verify.py +1160 -0
  36. cdxml_toolkit/image/__init__.py +15 -0
  37. cdxml_toolkit/image/reaction_from_image.py +2103 -0
  38. cdxml_toolkit/image/structure_from_image.py +1711 -0
  39. cdxml_toolkit/layout/__init__.py +5 -0
  40. cdxml_toolkit/layout/alignment.py +1642 -0
  41. cdxml_toolkit/layout/reaction_cleanup.py +1002 -0
  42. cdxml_toolkit/layout/scheme_merger.py +2260 -0
  43. cdxml_toolkit/mcp_server/__init__.py +0 -0
  44. cdxml_toolkit/mcp_server/__main__.py +5 -0
  45. cdxml_toolkit/mcp_server/server.py +1567 -0
  46. cdxml_toolkit/naming/__init__.py +6 -0
  47. cdxml_toolkit/naming/aligned_namer.py +2342 -0
  48. cdxml_toolkit/naming/mol_builder.py +3722 -0
  49. cdxml_toolkit/naming/name_decomposer.py +2843 -0
  50. cdxml_toolkit/naming/reactions_datamol.json +2414 -0
  51. cdxml_toolkit/office/__init__.py +5 -0
  52. cdxml_toolkit/office/doc_from_template.py +722 -0
  53. cdxml_toolkit/office/ole_embedder.py +808 -0
  54. cdxml_toolkit/office/ole_extractor.py +272 -0
  55. cdxml_toolkit/perception/__init__.py +10 -0
  56. cdxml_toolkit/perception/compound_search.py +229 -0
  57. cdxml_toolkit/perception/eln_csv_parser.py +240 -0
  58. cdxml_toolkit/perception/rdf_parser.py +664 -0
  59. cdxml_toolkit/perception/reactant_heuristic.py +1045 -0
  60. cdxml_toolkit/perception/reaction_parser.py +2150 -0
  61. cdxml_toolkit/perception/scheme_reader.py +2948 -0
  62. cdxml_toolkit/perception/scheme_refine.py +1404 -0
  63. cdxml_toolkit/perception/scheme_segmenter.py +619 -0
  64. cdxml_toolkit/perception/spatial_assignment.py +1013 -0
  65. cdxml_toolkit/rdkit_utils.py +605 -0
  66. cdxml_toolkit/render/__init__.py +17 -0
  67. cdxml_toolkit/render/auto_layout.py +229 -0
  68. cdxml_toolkit/render/compact_parser.py +632 -0
  69. cdxml_toolkit/render/parser.py +706 -0
  70. cdxml_toolkit/render/render_scheme.py +267 -0
  71. cdxml_toolkit/render/renderer.py +2387 -0
  72. cdxml_toolkit/render/schema.py +90 -0
  73. cdxml_toolkit/render/scheme_maker.py +1043 -0
  74. cdxml_toolkit/render/scheme_yaml_writer.py +1487 -0
  75. cdxml_toolkit/resolve/__init__.py +13 -0
  76. cdxml_toolkit/resolve/cas_resolver.py +430 -0
  77. cdxml_toolkit/resolve/chemscanner_abbreviations.json +28813 -0
  78. cdxml_toolkit/resolve/condensed_formula.py +493 -0
  79. cdxml_toolkit/resolve/jre_manager.py +195 -0
  80. cdxml_toolkit/resolve/reagent_abbreviations.json +1046 -0
  81. cdxml_toolkit/resolve/reagent_db.py +285 -0
  82. cdxml_toolkit/resolve/superatom_data.json +2856 -0
  83. cdxml_toolkit/resolve/superatom_table.py +146 -0
  84. cdxml_toolkit/text_formatting.py +298 -0
  85. cdxml_toolkit-0.5.0.dist-info/METADATA +318 -0
  86. cdxml_toolkit-0.5.0.dist-info/RECORD +91 -0
  87. cdxml_toolkit-0.5.0.dist-info/WHEEL +5 -0
  88. cdxml_toolkit-0.5.0.dist-info/entry_points.txt +17 -0
  89. cdxml_toolkit-0.5.0.dist-info/licenses/LICENSE +21 -0
  90. cdxml_toolkit-0.5.0.dist-info/licenses/NOTICE.md +37 -0
  91. cdxml_toolkit-0.5.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2260 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ scheme_merger.py -- Merge multiple ELN-enriched reaction schemes.
4
+
5
+ Auto-detects relationships between input schemes:
6
+ Parallel: Same reaction at different scales -> one scheme, stacked run arrows.
7
+ Sequential: Step 1 product = step 2 starting material -> multi-step linear scheme.
8
+ Unrelated: No chemical relationship -> placed adjacent (side by side).
9
+
10
+ Input: ELN-enriched CDXMLs from scheme_polisher_v2.py (via run_pipeline.py).
11
+
12
+ Usage:
13
+ python scheme_merger.py s1.cdxml s2.cdxml s3.cdxml s4.cdxml # auto-detect
14
+ python scheme_merger.py --mode parallel s1.cdxml s2.cdxml # explicit mode
15
+ python scheme_merger.py --mode sequential s1.cdxml s2.cdxml
16
+ python scheme_merger.py s1.cdxml s2.cdxml --no-equiv
17
+ python scheme_merger.py s1.cdxml s2.cdxml --equiv-range
18
+ python scheme_merger.py s1.cdxml s2.cdxml --ref-cdxml ref.cdxml
19
+ python scheme_merger.py s1.cdxml s2.cdxml --no-adjacent # error on unrelated
20
+ """
21
+
22
+ import argparse
23
+ import copy
24
+ import os
25
+ import re
26
+ import subprocess
27
+ import sys
28
+ import tempfile
29
+ import xml.etree.ElementTree as ET
30
+ from dataclasses import dataclass, field
31
+ from typing import Dict, List, Optional, Tuple
32
+
33
+ from ..cdxml_utils import (
34
+ fragment_bbox,
35
+ fragment_bbox_with_label_extension,
36
+ fragment_bottom_has_hanging_label,
37
+ parse_cdxml,
38
+ write_cdxml,
39
+ recompute_text_bbox,
40
+ )
41
+ from ..rdkit_utils import frag_to_smiles, frag_to_mw
42
+ from ..constants import (
43
+ ACS_BOND_LENGTH,
44
+ LAYOUT_ABOVE_GAP,
45
+ LAYOUT_BELOW_GAP,
46
+ LAYOUT_HANGING_LABEL_GAP,
47
+ LAYOUT_FRAG_GAP_BONDS,
48
+ LAYOUT_INTER_GAP_BONDS,
49
+ MW_MATCH_TOLERANCE,
50
+ )
51
+
52
+
53
+ # ============================================================================
54
+ # Data structures
55
+ # ============================================================================
56
+
57
@dataclass
class RunArrowData:
    """Mass and yield labels read off a single run arrow (one experiment)."""

    # Starting-material amount label, e.g. "50.0 mg".
    sm_mass_text: str
    # Product amount / yield label, e.g. "62.6 mg, 77 %".
    yield_text: str
    # Stem of the CDXML file the run came from (used to tell runs apart).
    source_file: str = ""
63
+
64
+
65
@dataclass
class EquivInfo:
    """One reagent's equivalents as stated in a scheme's conditions text."""

    # Text preceding the "(X eq.)" marker, or the whole line if nothing precedes it.
    reagent_name: str
    # Numeric equivalents parsed from the marker.
    equiv_value: float
70
+
71
+
72
@dataclass
class ParsedScheme:
    """Everything scheme_merger needs from one ELN-enriched CDXML scheme."""

    path: str
    tree: ET.ElementTree
    root: ET.Element
    page: ET.Element
    scheme_el: ET.Element
    step: ET.Element

    # id -> element, for the direct children of <page>.
    id_map: Dict[str, ET.Element] = field(default_factory=dict)

    # Role lists copied from the <step> ReactionStep* attributes.
    reactant_ids: List[str] = field(default_factory=list)
    product_ids: List[str] = field(default_factory=list)
    above_arrow_ids: List[str] = field(default_factory=list)
    below_arrow_ids: List[str] = field(default_factory=list)
    arrow_ids: List[str] = field(default_factory=list)

    # Every <fragment> on the page and its SMILES ("" when conversion failed).
    fragments: Dict[str, ET.Element] = field(default_factory=dict)
    fragment_smiles: Dict[str, str] = field(default_factory=dict)

    # Main reaction arrow (plus its legacy <graphic>) and the run arrows.
    main_arrow: Optional[ET.Element] = None
    main_arrow_id: str = ""
    main_graphic: Optional[ET.Element] = None
    run_arrows: List[ET.Element] = field(default_factory=list)
    run_graphics: List[ET.Element] = field(default_factory=list)

    # Geometry of the main arrow.
    arrow_tail_x: float = 0.0
    arrow_head_x: float = 0.0
    arrow_y: float = 0.0

    # One entry per run arrow.
    run_arrow_data: List[RunArrowData] = field(default_factory=list)

    # ids of the <t> labels attached to run arrows (removed when merging).
    run_arrow_text_ids: List[str] = field(default_factory=list)

    # Equivalents parsed from conditions text.
    equiv_values: List[EquivInfo] = field(default_factory=list)

    def get_reactant_smiles_set(self) -> set:
        """Non-empty SMILES of the fragments listed as reactants."""
        candidates = (self.fragment_smiles.get(rid, "") for rid in self.reactant_ids)
        return {smi for smi in candidates if smi}

    def get_product_smiles_set(self) -> set:
        """Non-empty SMILES of the fragments listed as products."""
        candidates = (self.fragment_smiles.get(pid, "") for pid in self.product_ids)
        return {smi for smi in candidates if smi}
134
+
135
+
136
+ # ============================================================================
137
+ # Parsing helpers
138
+ # ============================================================================
139
+
140
def _get_text_content(t_elem: ET.Element) -> str:
    """Join the text of every <s> run under *t_elem* and strip the result."""
    return "".join(run.text for run in t_elem.iter("s") if run.text).strip()
147
+
148
+
149
def _get_max_id(root: ET.Element) -> int:
    """Return the largest integer ``id`` anywhere under *root* (0 if none).

    Non-integer ids are ignored; the result is never below 0.
    """
    best = 0
    for node in root.iter():
        raw = node.get("id", "")
        if not raw:
            continue
        try:
            value = int(raw)
        except ValueError:
            continue
        if value > best:
            best = value
    return best
160
+
161
+
162
def _get_max_z(root: ET.Element) -> int:
    """Return the largest integer ``Z`` attribute under *root* (0 if none).

    Non-integer Z values are skipped; the result is never below 0.
    """
    def _as_int(raw):
        try:
            return int(raw)
        except ValueError:
            return None

    parsed = (_as_int(node.get("Z", "")) for node in root.iter() if node.get("Z", ""))
    return max([0] + [z for z in parsed if z is not None])
173
+
174
+
175
def _get_arrow_coords(arrow_el: ET.Element) -> Tuple[float, float, float]:
    """Return ``(tail_x, head_x, y)`` for an <arrow> or legacy <graphic>.

    * <arrow> with both Head3D and Tail3D: x from each point, y from Head3D.
    * <arrow> without them: x from the BoundingBox edges, y from its middle.
    * anything else (a <graphic>): BoundingBox is "head_x y tail_x y".

    The two x values are swapped if needed so that tail_x <= head_x.
    """
    head3d = arrow_el.get("Head3D", "")
    tail3d = arrow_el.get("Tail3D", "")
    if arrow_el.tag == "arrow" and head3d and tail3d:
        head_pt = head3d.split()
        x_head = float(head_pt[0])
        x_tail = float(tail3d.split()[0])
        y = float(head_pt[1])
    elif arrow_el.tag == "arrow":
        box = arrow_el.get("BoundingBox", "").split()
        x_tail = float(box[0])
        x_head = float(box[2])
        y = (float(box[1]) + float(box[3])) / 2.0
    else:
        box = arrow_el.get("BoundingBox", "").split()
        x_head = float(box[0])
        x_tail = float(box[2])
        y = float(box[1])

    if x_tail > x_head:
        x_tail, x_head = x_head, x_tail
    return x_tail, x_head, y
200
+
201
+
202
def _fragment_centroid_y(page: ET.Element, frag_ids: List[str],
                         id_map: Dict[str, ET.Element]) -> float:
    """Mean vertical centre of the listed <fragment> elements.

    ids that are missing from *id_map*, are not fragments, or have no bbox
    are ignored; returns 0.0 when nothing qualifies.  *page* is unused.
    """
    centres = []
    for fid in frag_ids:
        el = id_map.get(fid)
        if el is None or el.tag != "fragment":
            continue
        box = fragment_bbox(el)
        if not box:
            continue
        centres.append((box[1] + box[3]) / 2.0)
    if not centres:
        return 0.0
    return sum(centres) / len(centres)
213
+
214
+
215
def _build_page_id_map(page: ET.Element) -> Dict[str, ET.Element]:
    """Map ``id`` -> element for the direct children of *page*.

    Children without a (non-empty) id are skipped; a later duplicate id
    replaces an earlier one.
    """
    return {child.get("id"): child for child in page if child.get("id")}
223
+
224
+
225
def parse_scheme(path: str, log=None) -> ParsedScheme:
    """Parse an ELN-enriched CDXML scheme file into a ParsedScheme.

    Args:
        path: Path of the CDXML file to read.
        log: Optional callable taking one message string; silent when None.

    Returns:
        A ParsedScheme with step role ids, fragment SMILES, the main arrow,
        run arrows (top to bottom), run-arrow labels and equiv values.

    Raises:
        ValueError: If the document has no <page>, <scheme> or <step>.
    """
    if log is None:
        log = lambda msg: None

    tree = parse_cdxml(path)
    root = tree.getroot()
    page = root.find(".//page")
    if page is None:
        raise ValueError(f"No <page> found in {path}")

    scheme_el = page.find(".//scheme")
    if scheme_el is None:
        raise ValueError(f"No <scheme> found in {path}")

    step = scheme_el.find("step")
    if step is None:
        raise ValueError(f"No <step> found in {path}")

    ps = ParsedScheme(
        path=path, tree=tree, root=root, page=page,
        scheme_el=scheme_el, step=step,
    )
    ps.id_map = _build_page_id_map(page)

    # Role assignments are space-separated id lists on the <step>.
    ps.reactant_ids = step.get("ReactionStepReactants", "").split()
    ps.product_ids = step.get("ReactionStepProducts", "").split()
    ps.above_arrow_ids = step.get("ReactionStepObjectsAboveArrow", "").split()
    ps.below_arrow_ids = step.get("ReactionStepObjectsBelowArrow", "").split()
    ps.arrow_ids = step.get("ReactionStepArrows", "").split()

    for el in page:
        if el.tag == "fragment":
            fid = el.get("id", "")
            ps.fragments[fid] = el
            try:
                ps.fragment_smiles[fid] = frag_to_smiles(el)
            except Exception:
                # Best effort: a fragment RDKit cannot handle gets no SMILES.
                ps.fragment_smiles[fid] = ""

    _assign_scheme_arrows(ps)

    source_stem = os.path.splitext(os.path.basename(path))[0]
    seen_text_ids = set()
    for ra in ps.run_arrows:
        sm_text, yield_text, text_ids = _collect_run_arrow_text(page, ra)
        # A label can lie within range of two close run arrows; record its id
        # only once so later removal does not try to delete it twice.
        for tid in text_ids:
            if tid not in seen_text_ids:
                seen_text_ids.add(tid)
                ps.run_arrow_text_ids.append(tid)
        ps.run_arrow_data.append(RunArrowData(
            sm_mass_text=sm_text,
            yield_text=yield_text,
            source_file=source_stem,
        ))

    _extract_equiv_values(ps)

    log(f"Parsed {os.path.basename(path)}: "
        f"{len(ps.fragments)} fragments, "
        f"{len(ps.run_arrows)} run arrow(s), "
        f"main arrow at y={ps.arrow_y:.1f}")
    return ps


def _assign_scheme_arrows(ps: "ParsedScheme") -> None:
    """Choose the main arrow and the run arrows below it, filling *ps* in place."""
    arrows = []
    graphic_for = {}  # arrow id -> <graphic> whose SupersededBy names that arrow
    for el in ps.page:
        if el.tag == "arrow":
            arrows.append(el)
        elif el.tag == "graphic" and el.get("SupersededBy"):
            graphic_for[el.get("SupersededBy")] = el

    frag_cy = _fragment_centroid_y(
        ps.page, list(ps.reactant_ids) + list(ps.product_ids), ps.id_map)
    if not arrows:
        return

    # The main arrow is the one vertically closest to the fragment centroid.
    ranked = []
    for a in arrows:
        _, _, ay = _get_arrow_coords(a)
        ranked.append((abs(ay - frag_cy), ay, a))
    ranked.sort(key=lambda item: item[0])

    ps.main_arrow = ranked[0][2]
    ps.main_arrow_id = ps.main_arrow.get("id", "")
    ps.main_graphic = graphic_for.get(ps.main_arrow_id)
    ps.arrow_tail_x, ps.arrow_head_x, ps.arrow_y = _get_arrow_coords(ps.main_arrow)

    # Run arrows lie more than 5 pt below the main arrow (larger y is lower).
    below = [(ay, a) for _, ay, a in ranked[1:] if ay > ps.arrow_y + 5.0]
    # Order top-to-bottom by the same y used above, so arrows that only carry
    # a BoundingBox (no Head3D) are ordered correctly too.
    below.sort(key=lambda item: item[0])
    ps.run_arrows = [a for _, a in below]
    # Keep run_graphics in the same top-to-bottom order as run_arrows.
    ps.run_graphics = [graphic_for[a.get("id", "")] for a in ps.run_arrows
                       if a.get("id", "") in graphic_for]


def _collect_run_arrow_text(page: ET.Element,
                            run_arrow: ET.Element) -> Tuple[str, str, List[str]]:
    """Find the mass / yield labels belonging to one run arrow.

    A page-level <t> belongs to the arrow when its anchor y is within 6 pt of
    ``arrow_y + 2.25``.  Labels anchored left of the arrow's midpoint are the
    starting-material mass, the others the yield; when several match, the
    last one in document order wins.

    Returns:
        ``(sm_text, yield_text, ids of all matching <t>, in document order)``.
    """
    tail_x, head_x, arrow_y = _get_arrow_coords(run_arrow)
    mid_x = (tail_x + head_x) / 2.0
    sm_text = ""
    yield_text = ""
    text_ids: List[str] = []
    for el in page:
        if el.tag != "t":
            continue
        coords = el.get("p", "").split()
        if len(coords) < 2:
            # Missing or malformed anchor: the label cannot be placed.
            continue
        tx, ty = float(coords[0]), float(coords[1])
        if abs(ty - (arrow_y + 2.25)) < 6.0:
            text = _get_text_content(el)
            if tx < mid_x:
                sm_text = text
            else:
                yield_text = text
            text_ids.append(el.get("id", ""))
    return sm_text, yield_text, text_ids
348
+
349
+
350
def _extract_equiv_values(ps: ParsedScheme):
    """Append an EquivInfo to ``ps.equiv_values`` for every "(X eq.)" found.

    Every <t> directly on the page is scanned line by line.  The reagent
    name is the text before the parenthesis; if nothing precedes it, the
    whole line is used as the name.
    """
    pattern = re.compile(r'\((\d+\.?\d*)\s*eq\.\)')
    label_texts = (_get_text_content(el) for el in ps.page if el.tag == "t")
    for text in label_texts:
        if not text:
            continue
        # Conditions may span lines, e.g. "PPh3 (1.6 eq.)\nDEAD (1.55 eq.)\nTHF".
        for raw_line in text.split("\n"):
            line = raw_line.strip()
            match = pattern.search(line)
            if match is None:
                continue
            value = float(match.group(1))
            name = line[:match.start()].strip() or line
            ps.equiv_values.append(EquivInfo(
                reagent_name=name,
                equiv_value=value,
            ))
376
+
377
+
378
+ # ============================================================================
379
+ # Element creation helpers (patterns from eln_enrichment.py)
380
+ # ============================================================================
381
+
382
def _create_text_element(elem_id: int, z_order: int,
                         x: float, y: float,
                         text: str, justify: str = "Left") -> ET.Element:
    """Create a standalone <t> element with plain text content.

    Args:
        elem_id: Value for the ``id`` attribute.
        z_order: Value for the ``Z`` attribute.
        x, y: Anchor written to ``p``.  The estimated BoundingBox places the
            first line from ``y - 9`` to ``y + 3``.
        text: Label text; ``"\\n"`` separates lines.
        justify: "Left" (default), "Center" or "Right".  Center/Right also
            set the CaptionJustification and Justification attributes; any
            other value is laid out as Left.

    Returns:
        A new <t> element holding one <s> run (font 3, size 10, color 0)
        and an estimated BoundingBox.
    """
    t = ET.Element("t")
    t.set("id", str(elem_id))
    t.set("p", f"{x:.2f} {y:.2f}")
    t.set("Z", str(z_order))
    t.set("Warning",
          "Chemical Interpretation is not possible for this label")
    t.set("LineHeight", "auto")

    if justify == "Center":
        t.set("CaptionJustification", "Center")
        t.set("Justification", "Center")
    elif justify == "Right":
        t.set("CaptionJustification", "Right")
        t.set("Justification", "Right")

    s = ET.SubElement(t, "s")
    s.set("font", "3")
    s.set("size", "10")
    s.set("color", "0")
    s.text = text

    # Rough metrics: 5.8 pt per character, 12 pt per line.  Width follows
    # the longest line and height grows with the line count, so multi-line
    # labels get a box that actually encloses them.
    char_w = 5.8
    line_h = 12.0
    lines = text.split("\n")
    w = max(len(line) for line in lines) * char_w
    if justify == "Center":
        x1, x2 = x - w / 2.0, x + w / 2.0
    elif justify == "Right":
        x1, x2 = x - w, x
    else:
        x1, x2 = x, x + w
    y1 = y - line_h + 3.0
    y2 = y + 3.0 + (len(lines) - 1) * line_h
    t.set("BoundingBox", f"{x1:.2f} {y1:.2f} {x2:.2f} {y2:.2f}")
    return t
421
+
422
+
423
def _create_arrow(elem_id: int, z_order: int,
                  tail_x: float, head_x: float,
                  y: float) -> ET.Element:
    """Build a horizontal solid, full-head <arrow> from tail_x to head_x at y."""
    half_len = (head_x - tail_x) / 2.0
    # Center3D and the axis ends are cosmetic; they use fixed offsets.
    centre_x = (tail_x + head_x) / 2.0 + 290.0
    centre_y = y + 129.0
    attributes = {
        "id": str(elem_id),
        "BoundingBox": f"{tail_x:.2f} {y - 1.64:.2f} {head_x:.2f} {y + 1.52:.2f}",
        "Z": str(z_order),
        "FillType": "None",
        "ArrowheadHead": "Full",
        "ArrowheadType": "Solid",
        "HeadSize": "1000",
        "ArrowheadCenterSize": "875",
        "ArrowheadWidth": "250",
        "Head3D": f"{head_x:.2f} {y:.2f} 0",
        "Tail3D": f"{tail_x:.2f} {y:.2f} 0",
        "Center3D": f"{centre_x:.2f} {centre_y:.2f} 0",
        "MajorAxisEnd3D": f"{centre_x + half_len:.2f} {centre_y:.2f} 0",
        "MinorAxisEnd3D": f"{centre_x:.2f} {centre_y + half_len:.2f} 0",
    }
    return ET.Element("arrow", attributes)
452
+
453
+
454
def _create_graphic(elem_id: int, z_order: int,
                    superseded_by: int,
                    tail_x: float, head_x: float,
                    y: float) -> ET.Element:
    """Build the legacy <graphic> line-arrow that *superseded_by* replaces.

    Its BoundingBox lists the head point first: "head_x y tail_x y".
    """
    return ET.Element("graphic", {
        "id": str(elem_id),
        "SupersededBy": str(superseded_by),
        "BoundingBox": f"{head_x:.2f} {y:.2f} {tail_x:.2f} {y:.2f}",
        "Z": str(z_order),
        "GraphicType": "Line",
        "ArrowType": "FullHead",
        "HeadSize": "1000",
    })
469
+
470
+
471
+ # ============================================================================
472
+ # Geometry helpers
473
+ # ============================================================================
474
+
475
def _fragment_main_component_bbox(frag: ET.Element) -> Optional[Tuple[float, float, float, float]]:
    """Get the atom-only bbox of the largest connected component in a fragment.

    For salt products (e.g. amine + HCl) the counterion atoms may sit far
    from the main structure and inflate the bbox; this measures only the
    largest connected component (by atom count, first found on ties).

    Only direct <n> children with an ``id`` and a 2-D ``p`` are considered,
    joined by direct <b> children.  Falls back to ``fragment_bbox(frag)``
    when fewer than two such atoms exist or they form a single component.

    Returns:
        ``(min_x, min_y, max_x, max_y)`` of the component's atom positions,
        or whatever fragment_bbox returns in the fallback cases.
    """
    atoms: Dict[str, Tuple[float, float]] = {}
    for n in frag:
        if n.tag != "n":
            continue
        nid = n.get("id", "")
        p = n.get("p", "")
        if nid and p:
            parts = p.split()
            if len(parts) >= 2:
                atoms[nid] = (float(parts[0]), float(parts[1]))

    if len(atoms) < 2:
        return fragment_bbox(frag)

    adj: Dict[str, List[str]] = {aid: [] for aid in atoms}
    for b in frag:
        if b.tag != "b":
            continue
        begin = b.get("B", "")
        end = b.get("E", "")
        if begin in adj and end in adj:
            adj[begin].append(end)
            adj[end].append(begin)

    # Connected components via iterative DFS; list.pop() is O(1), whereas
    # the previous list.pop(0) queue made traversal quadratic.
    visited = set()
    components: List[List[str]] = []
    for start in atoms:
        if start in visited:
            continue
        visited.add(start)
        comp = []
        stack = [start]
        while stack:
            node = stack.pop()
            comp.append(node)
            for nb in adj[node]:
                if nb not in visited:
                    visited.add(nb)
                    stack.append(nb)
        components.append(comp)

    if len(components) <= 1:
        return fragment_bbox(frag)

    largest = max(components, key=len)
    xs = [atoms[aid][0] for aid in largest]
    ys = [atoms[aid][1] for aid in largest]
    return (min(xs), min(ys), max(xs), max(ys))
536
+
537
+
538
def _get_element_bbox(el: ET.Element) -> Optional[Tuple[float, float, float, float]]:
    """Get the bounding box ``(x1, y1, x2, y2)`` of a page element.

    * <fragment>: delegated to fragment_bbox_with_label_extension.
    * <t>, <arrow>, <graphic>: the BoundingBox attribute when it holds at
      least four numbers.
    * <t> without a usable BoundingBox: estimated from its ``p`` anchor with
      the single-line metrics used by _create_text_element (5.8 pt per
      character, 12 pt lines, first line spanning y-9 .. y+3).  The
      Justification (or CaptionJustification) attribute is honoured, Left
      when absent.  Width follows the longest line and height the line count.

    Returns None for other tags, or when no geometry can be derived
    (including a ``p`` with fewer than two coordinates).
    """
    if el.tag == "fragment":
        return fragment_bbox_with_label_extension(el)
    if el.tag not in ("t", "arrow", "graphic"):
        return None

    bb = el.get("BoundingBox", "")
    if bb:
        vals = [float(v) for v in bb.split()]
        if len(vals) >= 4:
            return (vals[0], vals[1], vals[2], vals[3])
    if el.tag != "t":
        return None

    coords = el.get("p", "").split()
    if len(coords) < 2:
        return None
    x, y = float(coords[0]), float(coords[1])

    lines = _get_text_content(el).split("\n")
    w = max(len(line) for line in lines) * 5.8
    justify = (el.get("Justification") or el.get("CaptionJustification")
               or "Left")
    if justify == "Center":
        x1, x2 = x - w / 2.0, x + w / 2.0
    elif justify == "Right":
        x1, x2 = x - w, x
    else:
        x1, x2 = x, x + w
    y1 = y - 12.0 + 3.0
    y2 = y + 3.0 + (len(lines) - 1) * 12.0
    return (x1, y1, x2, y2)
562
+
563
+
564
def _content_bottom(page: ET.Element, exclude_ids: Optional[set] = None) -> float:
    """Return the largest bottom edge (y2) among *page*'s direct children.

    <scheme> children and any child whose id is in *exclude_ids* are
    ignored.  The result is never below 0.0.
    """
    skip = set() if exclude_ids is None else exclude_ids
    lowest = 0.0
    for child in page:
        if child.get("id", "") in skip or child.tag == "scheme":
            continue
        box = _get_element_bbox(child)
        if box and box[3] > lowest:
            lowest = box[3]
    return lowest
579
+
580
+
581
def _shift_element(el: ET.Element, dx: float, dy: float):
    """Translate a <fragment> or <t> element in place by (dx, dy).

    Elements with any other tag are left untouched.
    """
    if el.tag == "t":
        _shift_text_el(el, dx, dy)
        return
    if el.tag != "fragment":
        return

    # Atom positions, including atoms of nested fragments.
    for node in el.iter("n"):
        anchor = node.get("p")
        if not anchor:
            continue
        xy = anchor.split()
        if len(xy) < 2:
            continue
        node.set("p", f"{float(xy[0]) + dx:.2f} {float(xy[1]) + dy:.2f}")
    # Every <t> inside the fragment.
    for label in el.iter("t"):
        _shift_text_el(label, dx, dy)
    # Cached boxes: the fragment itself first (iter yields it first), then nested ones.
    for frag in el.iter("fragment"):
        _shift_bbox_attr(frag, dx, dy)
600
+
601
+
602
def _shift_text_el(t: ET.Element, dx: float, dy: float):
    """Move a <t> element's ``p`` anchor and BoundingBox by (dx, dy)."""
    anchor = t.get("p")
    coords = anchor.split() if anchor else []
    if len(coords) >= 2:
        t.set("p", f"{float(coords[0]) + dx:.2f} {float(coords[1]) + dy:.2f}")
    _shift_bbox_attr(t, dx, dy)
610
+
611
+
612
def _shift_bbox_attr(el: ET.Element, dx: float, dy: float):
    """Offset the first four numbers of *el*'s BoundingBox by (dx, dy, dx, dy).

    Missing or short (< 4 values) boxes are left alone; any values beyond
    the fourth are kept (reformatted to two decimals).
    """
    raw = el.get("BoundingBox")
    if not raw:
        return
    coords = [float(v) for v in raw.split()]
    if len(coords) < 4:
        return
    offsets = (dx, dy, dx, dy)
    coords[:4] = [c + o for c, o in zip(coords, offsets)]
    el.set("BoundingBox", " ".join(f"{c:.2f}" for c in coords))
623
+
624
+
625
def _move_element_to(el: ET.Element, target_cx: float, target_cy: float):
    """Shift *el* so its bbox centre lands on (target_cx, target_cy).

    Does nothing when no bbox can be determined.
    """
    box = _get_element_bbox(el)
    if box is not None:
        centre_x = (box[0] + box[2]) / 2.0
        centre_y = (box[1] + box[3]) / 2.0
        _shift_element(el, target_cx - centre_x, target_cy - centre_y)
633
+
634
+
635
def _update_document_bbox(root: ET.Element, page: ET.Element):
    """Set *root*'s BoundingBox to the union of its page children's bboxes.

    Leaves *root* unchanged when no child has a measurable bbox.
    """
    inf = float('inf')
    left = top = inf
    right = bottom = -inf
    for child in page:
        box = _get_element_bbox(child)
        if box is None:
            continue
        left = min(left, box[0])
        top = min(top, box[1])
        right = max(right, box[2])
        bottom = max(bottom, box[3])
    if left < inf:
        root.set("BoundingBox",
                 f"{left:.2f} {top:.2f} {right:.2f} {bottom:.2f}")
650
+
651
+
652
+ # ============================================================================
653
+ # Fragment comparison
654
+ # ============================================================================
655
+
656
def _smiles_match(s1: str, s2: str) -> bool:
    """True when both SMILES are non-empty and identical strings."""
    return bool(s1 and s2) and s1 == s2
661
+
662
+
663
def _fragments_same_molecule(frag1: ET.Element, frag2: ET.Element,
                             smiles1: str = "", smiles2: str = "") -> bool:
    """Decide whether two <fragment> elements draw the same molecule.

    SMILES are computed for any side not supplied.  When both are non-empty
    and free of '*' (which abbreviation groups produce), they are compared
    exactly.  Otherwise molecular weights are compared within
    MW_MATCH_TOLERANCE.  Any failure in either computation counts as
    "not available", and the result is False if nothing can be compared.
    """
    def _smiles_or_blank(frag, known):
        if known:
            return known
        try:
            return frag_to_smiles(frag)
        except Exception:
            return ""

    s1 = _smiles_or_blank(frag1, smiles1)
    s2 = _smiles_or_blank(frag2, smiles2)
    if s1 and s2 and "*" not in s1 and "*" not in s2:
        return _smiles_match(s1, s2)

    try:
        mw1 = frag_to_mw(frag1)
        mw2 = frag_to_mw(frag2)
        if mw1 is None or mw2 is None:
            return False
        return abs(mw1 - mw2) < MW_MATCH_TOLERANCE
    except Exception:
        return False
696
+
697
+
698
def _products_match(ps_a: "ParsedScheme", ps_b: "ParsedScheme") -> bool:
    """Check whether two schemes have the same products (as a multiset).

    The product fragments must be equal in number.  Fast path: when every
    product on both sides has a non-empty, wildcard-free SMILES, the sorted
    SMILES lists are compared directly.  Otherwise each product of A is
    greedily paired with an unused product of B via _fragments_same_molecule,
    which falls back to MW comparison for abbreviated ('*') structures.

    The fast path no longer uses the deduplicated SMILES *sets*.  Those sets
    silently dropped fragments whose SMILES failed, ignored count
    differences, and trusted '*' SMILES that _fragments_same_molecule
    refuses to compare.  All three produced false matches.

    Note: pairing is first-fit, not an optimal bipartite matching.
    """
    frags_a = [(pid, ps_a.fragments.get(pid)) for pid in ps_a.product_ids
               if ps_a.fragments.get(pid) is not None]
    frags_b = [(pid, ps_b.fragments.get(pid)) for pid in ps_b.product_ids
               if ps_b.fragments.get(pid) is not None]

    if len(frags_a) != len(frags_b):
        return False

    smiles_a = [ps_a.fragment_smiles.get(pid, "") for pid, _ in frags_a]
    smiles_b = [ps_b.fragment_smiles.get(pid, "") for pid, _ in frags_b]

    def _plain(smi) -> bool:
        return bool(smi) and "*" not in smi

    if all(_plain(s) for s in smiles_a + smiles_b):
        return sorted(smiles_a) == sorted(smiles_b)

    used = set()
    for (aid, afrag), asmiles in zip(frags_a, smiles_a):
        for j, ((bid, bfrag), bsmiles) in enumerate(zip(frags_b, smiles_b)):
            if j in used:
                continue
            if _fragments_same_molecule(afrag, bfrag, asmiles, bsmiles):
                used.add(j)
                break
        else:
            return False
    return True
732
+
733
+
734
def _any_reactant_matches(ps_a: ParsedScheme, ps_b: ParsedScheme) -> bool:
    """True if some reactant fragment of *ps_a* is the same molecule as some of *ps_b*."""
    left = [(rid, ps_a.fragments.get(rid)) for rid in ps_a.reactant_ids]
    right = [(rid, ps_b.fragments.get(rid)) for rid in ps_b.reactant_ids]
    return any(
        _fragments_same_molecule(frag_a, frag_b,
                                 ps_a.fragment_smiles.get(id_a, ""),
                                 ps_b.fragment_smiles.get(id_b, ""))
        for id_a, frag_a in left if frag_a is not None
        for id_b, frag_b in right if frag_b is not None
    )
749
+
750
+
751
def _product_matches_reactant(ps_a: ParsedScheme, ps_b: ParsedScheme) -> bool:
    """True if any product fragment of *ps_a* reappears as a reactant of *ps_b*."""
    for pid in ps_a.product_ids:
        product = ps_a.fragments.get(pid)
        if product is None:
            continue
        product_smiles = ps_a.fragment_smiles.get(pid, "")
        candidates = ((rid, ps_b.fragments.get(rid)) for rid in ps_b.reactant_ids)
        if any(reactant is not None and
               _fragments_same_molecule(product, reactant, product_smiles,
                                        ps_b.fragment_smiles.get(rid, ""))
               for rid, reactant in candidates):
            return True
    return False
766
+
767
+
768
+ # ============================================================================
769
+ # Auto-detection: classify pairs and plan merges
770
+ # ============================================================================
771
+
772
def classify_pair(ps_a: ParsedScheme, ps_b: ParsedScheme) -> str:
    """Classify how two parsed schemes relate.

    Checked in this order:
        "parallel"       same products and at least one shared reactant
        "sequential_ab"  a product of A is a reactant of B
        "sequential_ba"  a product of B is a reactant of A
        "unrelated"      none of the above
    """
    same_reaction = _products_match(ps_a, ps_b) and _any_reactant_matches(ps_a, ps_b)
    if same_reaction:
        return "parallel"
    if _product_matches_reactant(ps_a, ps_b):
        return "sequential_ab"
    return "sequential_ba" if _product_matches_reactant(ps_b, ps_a) else "unrelated"
792
+
793
+
794
@dataclass
class MergePlan:
    """Outcome of auto-detection: how N parsed schemes should be combined."""

    # Clusters of scheme indices that describe the same reaction.
    parallel_groups: List[List[int]] = field(default_factory=list)
    # Indices into parallel_groups, listed in reaction order.
    sequential_chain: List[int] = field(default_factory=list)
    # Indices into parallel_groups that have no sequential link.
    unrelated_groups: List[int] = field(default_factory=list)

    def describe(self) -> str:
        """Return a one-line, human-readable summary of the plan."""
        def joined(group):
            return "+".join(str(idx) for idx in group)

        sections = []
        if self.sequential_chain:
            steps = []
            for gi in self.sequential_chain:
                group = self.parallel_groups[gi]
                steps.append(f"[{joined(group)}]" if len(group) > 1 else str(group[0]))
            sections.append(f"Sequential chain: {' -> '.join(steps)}")
        elif len(self.parallel_groups) == 1 and len(self.parallel_groups[0]) > 1:
            sections.append(f"Parallel merge: {joined(self.parallel_groups[0])}")
        sections.extend(f"Adjacent (unrelated): {joined(self.parallel_groups[gi])}"
                        for gi in self.unrelated_groups)
        return "; ".join(sections) if sections else "Single scheme"
825
+
826
+
827
def auto_detect(schemes: List[ParsedScheme], log=None) -> MergePlan:
    """Work out how a list of parsed schemes should be combined.

    Procedure:
      1. Classify every pair (i < j) with classify_pair.
      2. Cluster "parallel" pairs (same reaction) with union-find.
      3. Convert "sequential_ab" / "sequential_ba" pairs into directed edges
         between clusters (pairs inside one cluster are ignored).
      4. Order clusters with Kahn's topological sort; on a cycle, log a
         warning and fall back to cluster index order.
      5. Clusters touched by no sequential edge are reported as unrelated.

    Args:
        schemes: Parsed schemes in input order.
        log: Optional callable receiving one message string.

    Returns:
        MergePlan. ``parallel_groups`` holds scheme indices; the chain and
        unrelated lists hold indices into ``parallel_groups``.
    """
    if log is None:
        log = lambda msg: None

    count = len(schemes)
    if count == 1:
        return MergePlan(parallel_groups=[[0]], sequential_chain=[0])

    # 1. Pairwise classification, keyed by (lower index, higher index).
    verdicts = {}
    for i in range(count):
        for j in range(i + 1, count):
            verdict = classify_pair(schemes[i], schemes[j])
            verdicts[(i, j)] = verdict
            log(f" {os.path.basename(schemes[i].path)} vs "
                f"{os.path.basename(schemes[j].path)}: {verdict}")

    # 2. Union-find over "parallel" pairs.
    leader = list(range(count))

    def root_of(x):
        while leader[x] != x:
            leader[x] = leader[leader[x]]  # path halving
            x = leader[x]
        return x

    for (i, j), verdict in verdicts.items():
        if verdict != "parallel":
            continue
        ri, rj = root_of(i), root_of(j)
        if ri != rj:
            leader[ri] = rj

    # Clusters are ordered by their lowest scheme index.
    members_by_root = {}
    for idx in range(count):
        members_by_root.setdefault(root_of(idx), []).append(idx)
    groups = list(members_by_root.values())
    group_of = {idx: g for g, members in enumerate(groups) for idx in members}

    # 3. Directed edges (feeding group, fed group).
    seq_edges = set()
    for (i, j), verdict in verdicts.items():
        gi, gj = group_of[i], group_of[j]
        if gi == gj:
            continue  # both schemes already merged in parallel
        if verdict == "sequential_ab":
            seq_edges.add((gi, gj))
        elif verdict == "sequential_ba":
            seq_edges.add((gj, gi))

    if not seq_edges:
        # No sequential links: one cluster is a plain parallel merge,
        # several clusters are all unrelated.
        if len(groups) == 1:
            chain, unrelated = [0], []
        else:
            chain, unrelated = [], list(range(len(groups)))
    else:
        # 4. Kahn's algorithm with a FIFO queue driven by a read cursor.
        ng = len(groups)
        successors = [[] for _ in range(ng)]
        indegree = [0] * ng
        for src, dst in seq_edges:
            successors[src].append(dst)
            indegree[dst] += 1

        ready = [g for g in range(ng) if indegree[g] == 0]
        order = []
        cursor = 0
        while cursor < len(ready):
            node = ready[cursor]
            cursor += 1
            order.append(node)
            for nxt in successors[node]:
                indegree[nxt] -= 1
                if indegree[nxt] == 0:
                    ready.append(nxt)

        if len(order) != ng:
            log("WARNING: Cycle detected in sequential links — "
                "falling back to input order")
            order = list(range(ng))

        # 5. Split by participation in any sequential edge.
        linked = {g for edge in seq_edges for g in edge}
        chain = [g for g in order if g in linked]
        unrelated = [g for g in order if g not in linked]

    plan = MergePlan(
        parallel_groups=groups,
        sequential_chain=chain,
        unrelated_groups=unrelated,
    )
    log(f" Merge plan: {plan.describe()}")
    return plan
947
+
948
+
949
def execute_merge_plan(schemes: List[ParsedScheme], plan: MergePlan, *,
                       equiv_mode: str = "default",
                       ref_cdxml: str = None,
                       allow_adjacent: bool = True,
                       log=None) -> ET.ElementTree:
    """Execute a merge plan: parallel within groups, sequential between.

    Phase 1 turns every parallel group into one tree: a deep copy for a
    single scheme, parallel_merge otherwise. Phase 2 chains the groups listed
    in ``plan.sequential_chain`` with sequential_merge, after writing each
    group tree to a temporary ``.cdxml`` file and re-parsing it with
    parse_scheme. Phase 3 places unrelated groups beside the result with
    adjacent_place.

    Args:
        schemes: All parsed schemes.
        plan: Auto-detected merge plan.
        equiv_mode: Equiv handling for parallel merges.
        ref_cdxml: Reference CDXML for sequential alignment.
        allow_adjacent: If False, raise ValueError for unrelated reactions.
        log: Logging callback.

    Returns:
        The merged ElementTree.

    Raises:
        ValueError: If unrelated groups are present while ``allow_adjacent``
            is False, or if the plan contains no groups at all.
    """
    if log is None:
        log = lambda msg: None

    # Check for unrelated groups
    if plan.unrelated_groups and not plan.sequential_chain:
        # All groups are unrelated (no sequential chain)
        if not allow_adjacent and len(plan.parallel_groups) > 1:
            raise ValueError(
                "Input schemes have no chemical relationship (not parallel, "
                "not sequential). Use --adjacent to place them side by side."
            )

    if plan.unrelated_groups and plan.sequential_chain and not allow_adjacent:
        raise ValueError(
            "Some input schemes are unrelated to the sequential chain. "
            "Use --adjacent to place them side by side."
        )

    if not plan.parallel_groups:
        raise ValueError("Merge plan contains no scheme groups")

    # --- Phase 1: Parallel merge within each group ---
    group_trees = []  # one tree per group (parallel-merged or single)

    for gi, grp in enumerate(plan.parallel_groups):
        if len(grp) == 1:
            # Single scheme — use a copy as-is
            s = schemes[grp[0]]
            group_trees.append(copy.deepcopy(s.tree))
            log(f"Group {gi}: single scheme "
                f"({os.path.basename(s.path)})")
        else:
            grp_schemes = [schemes[si] for si in grp]
            names = [os.path.basename(gs.path) for gs in grp_schemes]
            log(f"Group {gi}: parallel merge of {', '.join(names)}")
            merged = parallel_merge(grp_schemes, equiv_mode=equiv_mode,
                                    log=log)
            group_trees.append(merged)

    # --- Phase 2: Sequential merge across groups in chain order ---
    if len(plan.sequential_chain) > 1:
        # sequential_merge needs ParsedScheme objects, so round-trip each
        # group tree through a temporary file.
        chain_schemes = []
        for gi in plan.sequential_chain:
            tree = group_trees[gi]
            fd, tmp_path = tempfile.mkstemp(suffix=".cdxml")
            # Only the path is needed (write_cdxml and parse_scheme open the
            # file by name), so release the descriptor exactly once, up front.
            os.close(fd)
            try:
                write_cdxml(tree, tmp_path)
                chain_schemes.append(parse_scheme(tmp_path, log=log))
            except Exception as e:
                # Best effort: the failing group is left out of the chain.
                log(f"WARNING: Failed to re-parse group {gi}: {e}")
            finally:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass

        if len(chain_schemes) >= 2:
            log(f"Sequential merge: {len(chain_schemes)} steps")
            result_tree = sequential_merge(
                chain_schemes, ref_cdxml=ref_cdxml, log=log)
        else:
            result_tree = group_trees[plan.sequential_chain[0]]

    elif len(plan.sequential_chain) == 1:
        result_tree = group_trees[plan.sequential_chain[0]]
    elif plan.unrelated_groups:
        # All groups are unrelated — use the first one as starting point
        result_tree = group_trees[plan.unrelated_groups[0]]
    else:
        result_tree = group_trees[0]

    # --- Phase 3: Adjacent placement for unrelated groups ---
    if plan.unrelated_groups and plan.sequential_chain:
        # Chain result first, then each unrelated group to its right
        trees_to_place = [result_tree]
        for gi in plan.unrelated_groups:
            trees_to_place.append(group_trees[gi])
        if len(trees_to_place) > 1:
            result_tree = adjacent_place(trees_to_place, log=log)
    elif not plan.sequential_chain and len(plan.unrelated_groups) > 1:
        # All groups are unrelated — place adjacent
        trees_to_place = [group_trees[gi] for gi in plan.unrelated_groups]
        result_tree = adjacent_place(trees_to_place, log=log)

    return result_tree
1053
+
1054
+
1055
def adjacent_place(trees: List[ET.ElementTree], *,
                   log=None) -> ET.ElementTree:
    """Place multiple independent schemes side by side on one page.

    Each tree keeps its own fragments, arrows, run arrows, and scheme/step
    structure. The first tree is deep-copied as the base. Every later tree's
    page children are deep-copied, given fresh IDs/z-orders, and shifted
    right so they start ``3 * ACS_BOND_LENGTH`` past the current right edge.
    ``SupersededBy`` references are re-resolved once a tree's full ID map is
    known, so a graphic still points at its arrow's new ID even when the
    graphic precedes the arrow on the page.

    Args:
        trees: List of CDXML ElementTrees to place adjacent.
        log: Logging callback.

    Returns:
        Combined ElementTree with all schemes side by side. A single-element
        list is returned unchanged (not copied).

    Raises:
        ValueError: If the first tree has no ``<page>`` element.
    """
    if log is None:
        log = lambda msg: None

    if len(trees) == 1:
        return trees[0]

    log(f"Adjacent placement: {len(trees)} schemes side by side")

    # Use first tree as base
    result_tree = copy.deepcopy(trees[0])
    result_root = result_tree.getroot()
    result_page = result_root.find(".//page")
    if result_page is None:
        raise ValueError("adjacent_place: first tree has no <page> element")

    # Right edge of existing content; -inf when the base page is empty.
    _, _, right_edge, _ = _page_bbox(result_page)

    gap = ACS_BOND_LENGTH * 3.0  # generous horizontal gap

    for tree_idx, extra_tree in enumerate(trees[1:], 2):
        extra_page = extra_tree.getroot().find(".//page")
        if extra_page is None:
            continue

        ex_left, _, ex_right, _ = _page_bbox(extra_page)
        if ex_left >= float('inf'):
            continue  # nothing with a bounding box to place

        if right_edge == float('-inf'):
            # Nothing placed yet: leave this content where it is instead of
            # shifting by an infinite offset.
            dx = 0.0
        else:
            dx = (right_edge + gap) - ex_left

        # Fresh IDs/z-orders above everything already in the result.
        next_id = _get_max_id(result_root) + 1
        next_z = _get_max_z(result_root) + 1
        old_to_new = {}
        # (copied node, original SupersededBy) pairs to resolve afterwards.
        pending_refs = []

        for el in list(extra_page):
            el_copy = copy.deepcopy(el)
            # Capture the pre-remap references; the target arrow may not
            # have been assigned its new ID yet.
            for node in el_copy.iter():
                ref = node.get("SupersededBy")
                if ref:
                    pending_refs.append((node, ref))
            _remap_element_ids(el_copy, old_to_new, next_id, next_z)
            next_id = max(next_id,
                          _get_max_id_in_element(el_copy) + 1)
            next_z = max(next_z,
                         _get_max_z_in_element(el_copy) + 1)
            _shift_element(el_copy, dx, 0)
            result_page.append(el_copy)

        # The ID map for this tree is now complete.
        for node, old_ref in pending_refs:
            new_ref = old_to_new.get(old_ref)
            if new_ref is not None:
                node.set("SupersededBy", new_ref)

        # Shifted right edge of the content just placed.
        right_edge = ex_right + dx
        log(f" Placed scheme {tree_idx} at x offset {dx:.1f}")

    _update_document_bbox(result_root, result_page)
    return result_tree
1125
+
1126
+
1127
def _page_bbox(page: ET.Element) -> Tuple[float, float, float, float]:
    """Return (left, top, right, bottom) enclosing every direct child of
    *page* that has a bounding box.

    Children for which _get_element_bbox returns None are ignored. An empty
    page (or one with no measurable children) yields
    ``(inf, inf, -inf, -inf)``, which callers test for.
    """
    inf = float('inf')
    left, top, right, bottom = inf, inf, -inf, -inf
    for child in page:
        box = _get_element_bbox(child)
        if box is None:
            continue
        left = min(left, box[0])
        top = min(top, box[1])
        right = max(right, box[2])
        bottom = max(bottom, box[3])
    return left, top, right, bottom
1140
+
1141
+
1142
+ # ============================================================================
1143
+ # Parallel merge
1144
+ # ============================================================================
1145
+
1146
# Layout of the stacked run arrows drawn beneath a parallel-merged scheme,
# in points.
RUN_ARROW_SPACING = 16.0  # vertical distance between consecutive run arrows
RUN_ARROW_GAP = 20.0      # distance from content bottom to the first run arrow
1148
+
1149
def parallel_merge(schemes: List[ParsedScheme], *,
                   equiv_mode: str = "default",
                   strict: bool = True,
                   log=None) -> ET.ElementTree:
    """Merge schemes for the same reaction into one with stacked run arrows.

    The first scheme is deep-copied as the base drawing. The main arrow is the
    arrow whose y is closest to the centroid of the base's reactant/product
    fragments. Any other arrow more than 5 pt below it is treated as an
    existing run arrow; it is removed together with its superseded
    ``<graphic>`` and any ``<t>`` whose anchor y lies near that arrow.
    Equivalents are then adjusted per ``equiv_mode``. Finally one new run
    arrow (with optional SM-mass and yield labels) is stacked below the
    content for each entry of every scheme's ``run_arrow_data``.

    Args:
        schemes: Parsed schemes (same reaction, different scales); the first
            is the base drawing.
        equiv_mode: "default" (keep scheme 1), "no-equiv", "equiv-range".
        strict: If True, reject when products don't match (default).
        log: Logging callback.

    Returns:
        Merged ElementTree.

    Raises:
        ValueError: If ``schemes`` is empty, or if strict=True and schemes
            don't represent the same reaction.
    """
    if log is None:
        log = lambda msg: None

    if not schemes:
        raise ValueError("parallel_merge needs at least one scheme")

    base = schemes[0]
    log(f"Parallel merge: {len(schemes)} schemes, "
        f"base = {os.path.basename(base.path)}")

    # Validate: all schemes must have the same products
    for i, s in enumerate(schemes[1:], 2):
        if not _products_match(base, s):
            msg = (f"Scheme {i} ({os.path.basename(s.path)}) has different "
                   f"products from scheme 1 ({os.path.basename(base.path)}) "
                   f"— these are not the same reaction")
            if strict:
                raise ValueError(msg)
            log(f"WARNING: {msg}")

        if not _any_reactant_matches(base, s):
            msg = (f"Scheme {i} ({os.path.basename(s.path)}) has no shared "
                   f"reactants with scheme 1 — reagent drawn vs text?")
            log(f"WARNING: {msg}")

    # Deep copy base scheme
    merged_tree = copy.deepcopy(base.tree)
    merged_root = merged_tree.getroot()
    merged_page = merged_root.find(".//page")

    # IDs of the base's existing run arrows, their graphics and their labels
    run_arrow_ids = set()
    run_graphic_ids = set()
    run_text_ids = set()

    # Index the copy: ID map, arrows, and old-style graphics keyed by the
    # arrow ID that supersedes them.
    copy_id_map = _build_page_id_map(merged_page)
    copy_arrows = [el for el in merged_page if el.tag == "arrow"]
    copy_graphics = {}
    for el in merged_page:
        if el.tag == "graphic" and el.get("SupersededBy"):
            copy_graphics[el.get("SupersededBy")] = el

    # Find the main arrow (closest to fragment centroids)
    all_frag_ids = list(base.reactant_ids) + list(base.product_ids)
    frag_cy = _fragment_centroid_y(merged_page, all_frag_ids, copy_id_map)

    main_arrow_el = None
    main_arrow_y = None
    for a in copy_arrows:
        _, _, ay = _get_arrow_coords(a)
        if main_arrow_el is None or abs(ay - frag_cy) < abs(main_arrow_y - frag_cy):
            main_arrow_el = a
            main_arrow_y = ay

    # Every other arrow more than 5 pt below the main arrow is a run arrow.
    elements_to_remove = []
    for a in copy_arrows:
        if a is main_arrow_el:
            continue
        _, _, ay = _get_arrow_coords(a)
        if ay > main_arrow_y + 5.0:
            aid = a.get("id", "")
            run_arrow_ids.add(aid)
            elements_to_remove.append(a)
            # Also remove corresponding graphic
            if aid in copy_graphics:
                elements_to_remove.append(copy_graphics[aid])
                run_graphic_ids.add(copy_graphics[aid].get("id", ""))

    # Run-arrow labels: text anchored within 6 pt of arrow y + 2.25, which is
    # the same offset used below when creating new labels.
    for a in list(run_arrow_ids):
        a_el = copy_id_map.get(a)
        if a_el is None:
            continue
        _, _, ra_y = _get_arrow_coords(a_el)
        for el in merged_page:
            if el.tag != "t":
                continue
            coords = el.get("p", "").split()
            if len(coords) < 2:
                continue  # no usable anchor point
            ty = float(coords[1])
            if abs(ty - (ra_y + 2.25)) < 6.0:
                elements_to_remove.append(el)
                run_text_ids.add(el.get("id", ""))

    # Remove elements; a label near two run arrows is listed twice.
    for el in elements_to_remove:
        try:
            merged_page.remove(el)
        except ValueError:
            pass  # already removed

    # Step element whose ReactionStepArrows is repointed at the end.
    merged_scheme = merged_page.find(".//scheme")
    merged_step = merged_scheme.find("step") if merged_scheme is not None else None

    # Handle equivalents
    if equiv_mode == "no-equiv":
        _remove_equiv_labels(merged_page, log)
    elif equiv_mode == "equiv-range":
        _apply_equiv_range(merged_page, schemes, log)

    # Collect all run arrow data from all schemes
    all_run_data = []
    for s in schemes:
        if s.run_arrow_data:
            all_run_data.extend(s.run_arrow_data)
        else:
            # Scheme has no run arrow (no ELN enrichment) — skip
            log(f" {os.path.basename(s.path)}: no run arrow data")

    if not all_run_data:
        log("WARNING: No run arrow data found in any scheme")
        _update_document_bbox(merged_root, merged_page)
        return merged_tree

    if main_arrow_el is None:
        # Run arrows copy the main arrow's span; without one they cannot be
        # placed, so return the cleaned copy (same best effort as above).
        log("WARNING: No reaction arrow found in base scheme — "
            "run arrows not added")
        _update_document_bbox(merged_root, merged_page)
        return merged_tree

    # Get main arrow coordinates from the merged copy
    main_tail_x, main_head_x, _ = _get_arrow_coords(main_arrow_el)

    # Find content bottom (excluding removed elements)
    exclude_ids = run_arrow_ids | run_graphic_ids | run_text_ids
    bottom = _content_bottom(merged_page, exclude_ids)

    # Create stacked run arrows
    next_id = _get_max_id(merged_root) + 1
    next_z = _get_max_z(merged_root) + 1
    run_arrow_y = bottom + RUN_ARROW_GAP

    last_arrow_graphic_id = None

    for i, rad in enumerate(all_run_data):
        if i > 0:
            run_arrow_y += RUN_ARROW_SPACING

        # Graphic (old-style ref) gets the lower ID; the arrow supersedes it.
        graphic_id = next_id
        next_id += 1
        arrow_id = next_id
        next_id += 1

        graphic = _create_graphic(
            graphic_id, next_z, arrow_id,
            main_tail_x, main_head_x, run_arrow_y,
        )
        next_z += 1
        merged_page.append(graphic)

        arrow = _create_arrow(
            arrow_id, next_z,
            main_tail_x, main_head_x, run_arrow_y,
        )
        next_z += 1
        merged_page.append(arrow)

        last_arrow_graphic_id = graphic_id

        # SM mass text (left of arrow, right-justified)
        text_y = run_arrow_y + 2.25
        if rad.sm_mass_text:
            sm_label = _create_text_element(
                next_id, next_z,
                main_tail_x - 4.0, text_y,
                rad.sm_mass_text, justify="Right",
            )
            next_id += 1
            next_z += 1
            merged_page.append(sm_label)

        # Yield text (right of arrow, left-justified)
        if rad.yield_text:
            yield_label = _create_text_element(
                next_id, next_z,
                main_head_x + 4.0, text_y,
                rad.yield_text, justify="Left",
            )
            next_id += 1
            next_z += 1
            merged_page.append(yield_label)

        log(f" Run arrow {i+1}: '{rad.sm_mass_text}' -> '{rad.yield_text}' "
            f"at y={run_arrow_y:.1f}")

    # Update step to reference the bottom-most graphic
    # (matching reference file pattern)
    if merged_step is not None and last_arrow_graphic_id is not None:
        merged_step.set("ReactionStepArrows", str(last_arrow_graphic_id))

    _update_document_bbox(merged_root, merged_page)
    return merged_tree
1357
+
1358
+
1359
def _remove_equiv_labels(page: ET.Element, log):
    """Strip equivalent annotations from a page in place.

    Pass 1 deletes every ``<t>`` child whose whole text is a bare label such
    as ``(1.6 eq.)``, logging each one. Pass 2 removes inline ``(X eq.)``
    suffixes from the remaining ``<t>`` children via
    _strip_equiv_from_conditions.
    """
    standalone = re.compile(r'^\s*\(\d+\.?\d*\s*eq\.\)\s*$')

    doomed = []
    for child in page:
        if child.tag == "t":
            label = _get_text_content(child)
            if standalone.match(label):
                doomed.append(child)
                log(f" Removing equiv label: '{label}'")

    # Removal happens after the scan so the page is not mutated mid-iteration.
    for child in doomed:
        page.remove(child)

    for child in page:
        if child.tag == "t":
            _strip_equiv_from_conditions(child)
1379
+
1380
+
1381
def _strip_equiv_from_conditions(t_elem: ET.Element):
    """Delete inline ``(X eq.)`` annotations from every ``<s>`` run under
    *t_elem*, in place.

    Whitespace immediately before the parenthesis goes too, so
    ``"PPh3 (1.6 eq.)"`` becomes ``"PPh3"``. Runs with no text, or whose
    text is unchanged, are not touched.
    """
    pattern = re.compile(r'\s*\(\d+\.?\d*\s*eq\.\)')
    for run in t_elem.iter("s"):
        original = run.text
        if not original:
            continue
        cleaned = pattern.sub("", original)
        if cleaned != original:
            run.text = cleaned
1389
+
1390
+
1391
def _apply_equiv_range(page: ET.Element, schemes: List[ParsedScheme], log):
    """Replace equiv values with ranges when they differ across schemes.

    Equivalents are collected per reagent name (case-insensitive) from every
    scheme's ``equiv_values``. A value found in page text is attributed to a
    reagent purely by numeric value. When that reagent's values differ between
    schemes, ``(X eq.)`` is rewritten as ``(min - max eq.)``.

    Two kinds of ``<t>`` child are handled:

    * Standalone labels whose whole text is ``(X eq.)``: the first reagent
      whose values contain X decides. The full range text goes into the first
      ``<s>`` run (later runs are emptied so it is not repeated), and the
      element's bounding box is recomputed.
    * Any other text (e.g. multi-line conditions): each inline ``(X eq.)`` is
      replaced using the first reagent that contains X *and* has a range.

    NOTE(review): value-only attribution is ambiguous when two reagents share
    an equivalent value; inline replacements also leave the text element's
    bounding box as it was.
    """
    # Collect equiv values per reagent name
    reagent_equivs: Dict[str, List[float]] = {}
    for s in schemes:
        for ei in s.equiv_values:
            reagent_equivs.setdefault(ei.reagent_name.lower(), []).append(
                ei.equiv_value)

    if not reagent_equivs:
        return

    def _range_text(vals: List[float]) -> str:
        # ':g' keeps up to six significant digits without switching to
        # exponent notation for ordinary values ('.2g' would round 1.25 to
        # "1.2" and print 100 as "1e+02").
        return f"({min(vals):g} - {max(vals):g} eq.)"

    def _replace_inline(m):
        val = float(m.group(1))
        for vals in reagent_equivs.values():
            if val in vals and min(vals) != max(vals):
                return _range_text(vals)
        return m.group(0)

    standalone_re = re.compile(r'^\s*\((\d+\.?\d*)\s*eq\.\)\s*$')
    inline_re = re.compile(r'\((\d+\.?\d*)\s*eq\.\)')

    for el in page:
        if el.tag != "t":
            continue
        text = _get_text_content(el)
        if not text:
            continue

        # Standalone equiv label, e.g. "(1.6 eq.)"
        standalone_m = standalone_re.match(text)
        if standalone_m:
            val = float(standalone_m.group(1))
            owner = next(
                (vals for vals in reagent_equivs.values() if val in vals),
                None,
            )
            if owner is not None and min(owner) != max(owner):
                new_text = _range_text(owner)
                log(f" Equiv range: '{text}' -> '{new_text}'")
                runs = list(el.iter("s"))
                if runs:
                    runs[0].text = new_text
                    for extra in runs[1:]:
                        extra.text = ""
                recompute_text_bbox(el)
            continue

        # Multi-line conditions: replace inline equiv values
        for s in el.iter("s"):
            if s.text:
                s.text = inline_re.sub(_replace_inline, s.text)
1444
+
1445
+
1446
+ # ============================================================================
1447
+ # Sequential merge
1448
+ # ============================================================================
1449
+
1450
+ def sequential_merge(schemes: List[ParsedScheme], *,
1451
+ ref_cdxml: str = None,
1452
+ log=None) -> ET.ElementTree:
1453
+ """Merge schemes where step N product = step N+1 starting material.
1454
+
1455
+ Creates a multi-step linear scheme.
1456
+
1457
+ Args:
1458
+ schemes: Parsed schemes in reaction order.
1459
+ ref_cdxml: Reference CDXML for final product alignment.
1460
+ log: Logging callback.
1461
+
1462
+ Returns:
1463
+ Merged ElementTree.
1464
+ """
1465
+ if log is None:
1466
+ log = lambda msg: None
1467
+
1468
+ log(f"Sequential merge: {len(schemes)} steps")
1469
+
1470
+ # Validate sequential linkage
1471
+ links = [] # (product_frag_id_in_step_i, reactant_frag_id_in_step_i+1)
1472
+ for i in range(len(schemes) - 1):
1473
+ s_cur = schemes[i]
1474
+ s_next = schemes[i + 1]
1475
+ found = False
1476
+ for pid in s_cur.product_ids:
1477
+ pfrag = s_cur.fragments.get(pid)
1478
+ psmiles = s_cur.fragment_smiles.get(pid, "")
1479
+ if pfrag is None:
1480
+ continue
1481
+ for rid in s_next.reactant_ids:
1482
+ rfrag = s_next.fragments.get(rid)
1483
+ rsmiles = s_next.fragment_smiles.get(rid, "")
1484
+ if rfrag is None:
1485
+ continue
1486
+ if _fragments_same_molecule(pfrag, rfrag, psmiles, rsmiles):
1487
+ links.append((pid, rid))
1488
+ log(f" Step {i+1} product {pid} matches "
1489
+ f"step {i+2} reactant {rid}")
1490
+ found = True
1491
+ break
1492
+ if found:
1493
+ break
1494
+ if not found:
1495
+ log(f"WARNING: No product/reactant match between "
1496
+ f"step {i+1} and step {i+2}")
1497
+ links.append(("", ""))
1498
+
1499
+ # --- Alignment cascade (backwards from final product) ---
1500
+ _alignment_cascade(schemes, links, ref_cdxml, log)
1501
+
1502
+ # --- Assemble the merged CDXML ---
1503
+ # Use first scheme's root as template for document settings
1504
+ base_root = schemes[0].root
1505
+ merged_root = ET.Element("CDXML")
1506
+ for attr_name in base_root.keys():
1507
+ merged_root.set(attr_name, base_root.get(attr_name))
1508
+
1509
+ # Copy colortable and fonttable
1510
+ for child_tag in ("colortable", "fonttable"):
1511
+ src = base_root.find(child_tag)
1512
+ if src is not None:
1513
+ merged_root.append(copy.deepcopy(src))
1514
+
1515
+ # Create page
1516
+ base_page = schemes[0].page
1517
+ merged_page = ET.SubElement(merged_root, "page")
1518
+ for attr_name in base_page.keys():
1519
+ merged_page.set(attr_name, base_page.get(attr_name))
1520
+
1521
+ # --- ID remapping and element collection ---
1522
+ next_id = 1
1523
+ next_z = 1
1524
+
1525
+ # For each step, collect the elements we need and remap their IDs
1526
+ step_data = [] # list of dicts per step
1527
+
1528
+ for step_idx, s in enumerate(schemes):
1529
+ sd = {
1530
+ "elements": [], # (element, role) tuples to add to page
1531
+ "reactant_ids": [], # remapped IDs
1532
+ "product_ids": [], # remapped IDs
1533
+ "above_ids": [], # remapped IDs
1534
+ "below_ids": [], # remapped IDs
1535
+ "arrow_id": "", # remapped main arrow ID
1536
+ "graphic_id": "", # remapped main graphic ID
1537
+ "other_ids": [], # remapped IDs of elements with no step role
1538
+ "skip_reactant_ids": set(), # IDs to skip (shared intermediate)
1539
+ "orig_arrow_cx": 0.0, # original arrow center for delta computation
1540
+ }
1541
+
1542
+ # Record original arrow center for computing layout deltas
1543
+ if s.main_arrow is not None:
1544
+ sd["orig_arrow_cx"] = (s.arrow_tail_x + s.arrow_head_x) / 2.0
1545
+
1546
+ # If this is not the first step, the shared intermediate (which is
1547
+ # step N-1's product) is already in the merged page. Skip the
1548
+ # duplicate reactant.
1549
+ if step_idx > 0 and links[step_idx - 1][1]:
1550
+ sd["skip_reactant_ids"].add(links[step_idx - 1][1])
1551
+
1552
+ # Deep copy all page elements and remap IDs
1553
+ old_to_new = {}
1554
+
1555
+ # Collect elements from this scheme's page
1556
+ for el in s.page:
1557
+ if el.tag in ("scheme",):
1558
+ continue # will rebuild scheme/step
1559
+
1560
+ eid = el.get("id", "")
1561
+
1562
+ # Skip elements belonging to run arrows (we'll rebuild them)
1563
+ if eid in {a.get("id", "") for a in s.run_arrows}:
1564
+ continue
1565
+ if eid in {g.get("id", "") for g in s.run_graphics}:
1566
+ continue
1567
+ if eid in set(s.run_arrow_text_ids):
1568
+ continue
1569
+
1570
+ # Skip the duplicate reactant
1571
+ if eid in sd["skip_reactant_ids"]:
1572
+ continue
1573
+
1574
+ # Skip standalone equiv labels — these are redundant in multi-step
1575
+ # schemes because the conditions text already contains equiv info
1576
+ # (e.g. "PPh₃ (1.6 eq.)"). ELN enrichment adds these as separate
1577
+ # text elements next to fragments, but they cause overlap when
1578
+ # the layout rearranges positions.
1579
+ if el.tag == "t":
1580
+ text = _get_text_content(el)
1581
+ if re.match(r'^\s*\(\d+\.?\d*\s*eq\.\)\s*$', text):
1582
+ log(f" Skipping standalone equiv label: '{text}'")
1583
+ continue
1584
+
1585
+ el_copy = copy.deepcopy(el)
1586
+
1587
+ # Remap all IDs in this element subtree
1588
+ _remap_element_ids(el_copy, old_to_new, next_id, next_z)
1589
+ next_id = max(next_id, _get_max_id_in_element(el_copy) + 1)
1590
+ next_z = max(next_z, _get_max_z_in_element(el_copy) + 1)
1591
+
1592
+ # Determine role
1593
+ role = "other"
1594
+ if eid in s.reactant_ids:
1595
+ role = "reactant"
1596
+ elif eid in s.product_ids:
1597
+ role = "product"
1598
+ elif eid in s.above_arrow_ids:
1599
+ role = "above"
1600
+ elif eid in s.below_arrow_ids:
1601
+ role = "below"
1602
+ elif el.tag == "arrow" and eid == s.main_arrow_id:
1603
+ role = "main_arrow"
1604
+ elif el.tag == "graphic" and s.main_arrow is not None and \
1605
+ el.get("SupersededBy") == s.main_arrow_id:
1606
+ role = "main_graphic"
1607
+
1608
+ sd["elements"].append((el_copy, role, eid))
1609
+
1610
+ # FIX 1: After all elements collected, fix cross-references with
1611
+ # the complete old_to_new mapping. When _remap_element_ids processes
1612
+ # elements individually, a graphic's SupersededBy might reference an
1613
+ # arrow that hasn't been assigned a new ID yet. Now old_to_new is
1614
+ # complete for this step, so we fix any stale references.
1615
+ for el_copy, role, old_id in sd["elements"]:
1616
+ for node in el_copy.iter():
1617
+ for attr in ("SupersededBy",):
1618
+ ref = node.get(attr)
1619
+ if ref and ref in old_to_new:
1620
+ node.set(attr, old_to_new[ref])
1621
+
1622
+ # Build remapped ID lists
1623
+ for el_copy, role, old_id in sd["elements"]:
1624
+ new_id = el_copy.get("id", "")
1625
+ if role == "reactant":
1626
+ sd["reactant_ids"].append(new_id)
1627
+ elif role == "product":
1628
+ sd["product_ids"].append(new_id)
1629
+ elif role == "above":
1630
+ sd["above_ids"].append(new_id)
1631
+ elif role == "below":
1632
+ sd["below_ids"].append(new_id)
1633
+ elif role == "main_arrow":
1634
+ sd["arrow_id"] = new_id
1635
+ elif role == "main_graphic":
1636
+ sd["graphic_id"] = new_id
1637
+ elif role == "other":
1638
+ sd["other_ids"].append(new_id)
1639
+
1640
+ # If shared intermediate was skipped, use the previous step's product ID
1641
+ if step_idx > 0 and links[step_idx - 1][1]:
1642
+ prev_sd = step_data[step_idx - 1]
1643
+ # The intermediate is the product of the previous step
1644
+ if prev_sd["product_ids"]:
1645
+ intermediate_id = prev_sd["product_ids"][0]
1646
+ sd["reactant_ids"].insert(0, intermediate_id)
1647
+
1648
+ step_data.append(sd)
1649
+
1650
+ # --- Horizontal layout ---
1651
+ bond_len = ACS_BOND_LENGTH
1652
+ frag_gap = bond_len * LAYOUT_FRAG_GAP_BONDS
1653
+ inter_gap = bond_len * LAYOUT_INTER_GAP_BONDS
1654
+
1655
+ # Add all elements to the merged page first
1656
+ all_elements = {} # id -> element
1657
+ for sd in step_data:
1658
+ for el_copy, role, old_id in sd["elements"]:
1659
+ merged_page.append(el_copy)
1660
+ new_id = el_copy.get("id", "")
1661
+ all_elements[new_id] = el_copy
1662
+
1663
+ # Now lay out step by step, left to right
1664
+ cursor_x = 100.0 # starting x position
1665
+ arrow_y = None # will be set from first step's reactant centroids
1666
+
1667
+ # Compute a common arrow_y from all reactant fragments
1668
+ all_reactant_bbs = []
1669
+ for sd in step_data:
1670
+ for rid in sd["reactant_ids"]:
1671
+ el = all_elements.get(rid)
1672
+ if el is not None and el.tag == "fragment":
1673
+ bb = fragment_bbox(el)
1674
+ if bb:
1675
+ all_reactant_bbs.append(bb)
1676
+ if all_reactant_bbs:
1677
+ arrow_y = sum((bb[1] + bb[3]) / 2.0 for bb in all_reactant_bbs) / len(all_reactant_bbs)
1678
+ else:
1679
+ arrow_y = 200.0 # fallback
1680
+
1681
+ placed_ids = set() # track already-placed fragment IDs (for shared intermediates)
1682
+
1683
+ for step_idx, sd in enumerate(step_data):
1684
+ # Place reactants (skip shared intermediate if already placed)
1685
+ for rid in sd["reactant_ids"]:
1686
+ if rid in placed_ids:
1687
+ # Already placed as previous step's product — account for its width
1688
+ el = all_elements.get(rid)
1689
+ if el is not None:
1690
+ bb = _get_element_bbox(el)
1691
+ if bb:
1692
+ cursor_x = bb[2] + frag_gap # right edge + gap
1693
+ continue
1694
+
1695
+ el = all_elements.get(rid)
1696
+ if el is None or el.tag != "fragment":
1697
+ continue
1698
+ # Use atom-only bbox for centering (avoids salt/counterion inflation)
1699
+ bb_atoms = fragment_bbox(el)
1700
+ bb_full = _get_element_bbox(el)
1701
+ bb = bb_atoms if bb_atoms else bb_full
1702
+ if bb is None:
1703
+ continue
1704
+ w = bb[2] - bb[0]
1705
+ cy = (bb[1] + bb[3]) / 2.0
1706
+ dx = (cursor_x + w / 2.0) - (bb[0] + bb[2]) / 2.0
1707
+ dy = arrow_y - cy
1708
+ _shift_element(el, dx, dy)
1709
+ placed_ids.add(rid)
1710
+ full_w = (bb_full[2] - bb_full[0]) if bb_full else w
1711
+ cursor_x += max(w, full_w) + inter_gap
1712
+
1713
+ # Replace last inter_gap with frag_gap before arrow
1714
+ if sd["reactant_ids"]:
1715
+ cursor_x = cursor_x - inter_gap + frag_gap
1716
+
1717
+ # Compute arrow length from above/below content
1718
+ above_els = [all_elements.get(i) for i in sd["above_ids"]
1719
+ if all_elements.get(i) is not None]
1720
+ below_els = [all_elements.get(i) for i in sd["below_ids"]
1721
+ if all_elements.get(i) is not None]
1722
+ min_arrow = bond_len * 5.0
1723
+ max_w = 0.0
1724
+ for el in above_els + below_els:
1725
+ if el is not None:
1726
+ bb = _get_element_bbox(el)
1727
+ if bb:
1728
+ w = bb[2] - bb[0]
1729
+ if w > max_w:
1730
+ max_w = w
1731
+ arrow_len = max(min_arrow, max_w + 10.0)
1732
+
1733
+ # Place arrow
1734
+ tail_x = cursor_x
1735
+ head_x = cursor_x + arrow_len
1736
+ arrow_cx = (tail_x + head_x) / 2.0
1737
+
1738
+ arrow_el = all_elements.get(sd["arrow_id"])
1739
+ if arrow_el is not None:
1740
+ arrow_el.set("Tail3D", f"{tail_x:.2f} {arrow_y:.2f} 0")
1741
+ arrow_el.set("Head3D", f"{head_x:.2f} {arrow_y:.2f} 0")
1742
+ bb_top = arrow_y - 1.64
1743
+ bb_bot = arrow_y + 1.52
1744
+ arrow_el.set("BoundingBox",
1745
+ f"{tail_x:.2f} {bb_top:.2f} {head_x:.2f} {bb_bot:.2f}")
1746
+ # Center3D etc.
1747
+ cx_3d = arrow_cx + 280.0
1748
+ cy_3d = arrow_y + 130.0
1749
+ half_len = arrow_len / 2.0
1750
+ arrow_el.set("Center3D", f"{cx_3d:.2f} {cy_3d:.2f} 0")
1751
+ arrow_el.set("MajorAxisEnd3D",
1752
+ f"{cx_3d + half_len:.2f} {cy_3d:.2f} 0")
1753
+ arrow_el.set("MinorAxisEnd3D",
1754
+ f"{cx_3d:.2f} {cy_3d + half_len:.2f} 0")
1755
+
1756
+ # Update graphic
1757
+ graphic_el = all_elements.get(sd["graphic_id"])
1758
+ if graphic_el is not None:
1759
+ graphic_el.set("BoundingBox",
1760
+ f"{head_x:.2f} {arrow_y:.2f} "
1761
+ f"{tail_x:.2f} {arrow_y:.2f}")
1762
+
1763
+ # Stack above/below arrow objects
1764
+ for el in above_els:
1765
+ if el is None:
1766
+ continue
1767
+ if el.tag == "t":
1768
+ # Text goes below arrow
1769
+ pass # handled in below stacking
1770
+ else:
1771
+ bb = _get_element_bbox(el)
1772
+ if bb is None:
1773
+ continue
1774
+ h = bb[3] - bb[1]
1775
+ if el.tag == "fragment" and fragment_bottom_has_hanging_label(el):
1776
+ gap = LAYOUT_HANGING_LABEL_GAP
1777
+ else:
1778
+ gap = LAYOUT_ABOVE_GAP
1779
+ target_bottom = arrow_y - gap
1780
+ target_cy = target_bottom - h / 2.0
1781
+ _move_element_to(el, arrow_cx, target_cy)
1782
+
1783
+ # Below arrow: collect all text (from above + below lists)
1784
+ above_texts = [all_elements.get(i) for i in sd["above_ids"]
1785
+ if all_elements.get(i) is not None
1786
+ and all_elements[i].tag == "t"]
1787
+ below_texts = [all_elements.get(i) for i in sd["below_ids"]
1788
+ if all_elements.get(i) is not None
1789
+ and all_elements[i].tag == "t"]
1790
+ all_below_text = above_texts + below_texts
1791
+
1792
+ BASELINE_OFFSET = 10.0
1793
+ TEXT_ELEMENT_GAP = 2.0 # gap between consecutive text elements
1794
+ y_cursor = arrow_y + LAYOUT_BELOW_GAP + BASELINE_OFFSET
1795
+ for el in all_below_text:
1796
+ if el is None:
1797
+ continue
1798
+ el.set("p", f"{arrow_cx:.2f} {y_cursor:.2f}")
1799
+ el.set("CaptionJustification", "Center")
1800
+ el.set("Justification", "Center")
1801
+ recompute_text_bbox(el)
1802
+ # Advance cursor past this element's full height
1803
+ bb_str = el.get("BoundingBox", "")
1804
+ if bb_str:
1805
+ bb_vals = [float(v) for v in bb_str.split()]
1806
+ if len(bb_vals) >= 4:
1807
+ y_cursor = bb_vals[3] + TEXT_ELEMENT_GAP
1808
+ else:
1809
+ y_cursor += 13.0
1810
+ else:
1811
+ y_cursor += 13.0
1812
+
1813
+ cursor_x = head_x + frag_gap
1814
+
1815
+ # Place products — use main-component bbox for centering to handle
1816
+ # salt products (e.g. amine + HCl where counterion is far from main)
1817
+ for pid in sd["product_ids"]:
1818
+ if pid in placed_ids:
1819
+ continue
1820
+ el = all_elements.get(pid)
1821
+ if el is None or el.tag != "fragment":
1822
+ continue
1823
+ # Use main-component bbox for centering (ignores distant counterions)
1824
+ bb_main = _fragment_main_component_bbox(el)
1825
+ bb_full = _get_element_bbox(el)
1826
+ bb = bb_main if bb_main else bb_full
1827
+ if bb is None:
1828
+ continue
1829
+ w = bb[2] - bb[0]
1830
+ cy = (bb[1] + bb[3]) / 2.0
1831
+ # Move using main-component center
1832
+ dx = (cursor_x + w / 2.0) - (bb[0] + bb[2]) / 2.0
1833
+ dy = arrow_y - cy
1834
+ _shift_element(el, dx, dy)
1835
+ placed_ids.add(pid)
1836
+ # Use full bbox width for cursor advance (accounts for labels)
1837
+ full_w = (bb_full[2] - bb_full[0]) if bb_full else w
1838
+ cursor_x += max(w, full_w) + inter_gap
1839
+
1840
+ # After products, replace inter_gap with frag_gap for next step
1841
+ cursor_x = cursor_x - inter_gap + frag_gap
1842
+
1843
+ # FIX 2: Shift "other" elements (standalone equiv labels etc.)
1844
+ # by the same horizontal delta as the step's arrow moved.
1845
+ new_arrow_cx = arrow_cx # arrow_cx was set above during arrow placement
1846
+ orig_cx = sd["orig_arrow_cx"]
1847
+ if orig_cx and sd["other_ids"]:
1848
+ delta_x = new_arrow_cx - orig_cx
1849
+ for oid in sd["other_ids"]:
1850
+ oel = all_elements.get(oid)
1851
+ if oel is not None:
1852
+ _shift_element(oel, delta_x, 0.0)
1853
+
1854
+ # --- Create multi-step <scheme> ---
1855
+ scheme_el = ET.SubElement(merged_page, "scheme")
1856
+ scheme_el.set("id", str(next_id))
1857
+ next_id += 1
1858
+
1859
+ for sd in step_data:
1860
+ step_el = ET.SubElement(scheme_el, "step")
1861
+ step_el.set("id", str(next_id))
1862
+ next_id += 1
1863
+ step_el.set("ReactionStepReactants",
1864
+ " ".join(sd["reactant_ids"]))
1865
+ step_el.set("ReactionStepProducts",
1866
+ " ".join(sd["product_ids"]))
1867
+ if sd["graphic_id"]:
1868
+ step_el.set("ReactionStepArrows", sd["graphic_id"])
1869
+ elif sd["arrow_id"]:
1870
+ step_el.set("ReactionStepArrows", sd["arrow_id"])
1871
+ step_el.set("ReactionStepObjectsAboveArrow",
1872
+ " ".join(sd["above_ids"]))
1873
+ step_el.set("ReactionStepObjectsBelowArrow",
1874
+ " ".join(sd["below_ids"]))
1875
+
1876
+ # --- Add run arrows per step ---
1877
+ for step_idx, (sd, s) in enumerate(zip(step_data, schemes)):
1878
+ if not s.run_arrow_data:
1879
+ continue
1880
+
1881
+ arrow_el = all_elements.get(sd["arrow_id"])
1882
+ if arrow_el is None:
1883
+ continue
1884
+ step_tail_x, step_head_x, step_y = _get_arrow_coords(arrow_el)
1885
+
1886
+ # Find content bottom for this step's column
1887
+ bottom = step_y
1888
+ for eid in (sd["above_ids"] + sd["below_ids"] +
1889
+ sd["reactant_ids"] + sd["product_ids"]):
1890
+ el = all_elements.get(eid)
1891
+ if el is not None:
1892
+ bb = _get_element_bbox(el)
1893
+ if bb and bb[3] > bottom:
1894
+ bottom = bb[3]
1895
+
1896
+ run_y = bottom + RUN_ARROW_GAP
1897
+ for i, rad in enumerate(s.run_arrow_data):
1898
+ if i > 0:
1899
+ run_y += RUN_ARROW_SPACING
1900
+
1901
+ g_id = next_id
1902
+ next_id += 1
1903
+ a_id = next_id
1904
+ next_id += 1
1905
+
1906
+ graphic = _create_graphic(
1907
+ g_id, next_z, a_id,
1908
+ step_tail_x, step_head_x, run_y,
1909
+ )
1910
+ next_z += 1
1911
+ merged_page.append(graphic)
1912
+
1913
+ arrow = _create_arrow(
1914
+ a_id, next_z,
1915
+ step_tail_x, step_head_x, run_y,
1916
+ )
1917
+ next_z += 1
1918
+ merged_page.append(arrow)
1919
+
1920
+ text_y = run_y + 2.25
1921
+ if rad.sm_mass_text:
1922
+ merged_page.append(_create_text_element(
1923
+ next_id, next_z,
1924
+ step_tail_x - 4.0, text_y,
1925
+ rad.sm_mass_text, justify="Right",
1926
+ ))
1927
+ next_id += 1
1928
+ next_z += 1
1929
+
1930
+ if rad.yield_text:
1931
+ merged_page.append(_create_text_element(
1932
+ next_id, next_z,
1933
+ step_head_x + 4.0, text_y,
1934
+ rad.yield_text, justify="Left",
1935
+ ))
1936
+ next_id += 1
1937
+ next_z += 1
1938
+
1939
+ # Increase page width if needed
1940
+ merged_page.set("WidthPages", "4")
1941
+
1942
+ _update_document_bbox(merged_root, merged_page)
1943
+
1944
+ merged_tree = ET.ElementTree(merged_root)
1945
+ return merged_tree
1946
+
1947
+
1948
+ def _remap_element_ids(el: ET.Element, old_to_new: Dict[str, str],
1949
+ start_id: int, start_z: int):
1950
+ """Remap all id/Z attributes and cross-references in an element subtree.
1951
+
1952
+ Populates old_to_new mapping as a side effect.
1953
+ """
1954
+ # Phase 1: assign new IDs
1955
+ counter = [start_id, start_z]
1956
+ for node in el.iter():
1957
+ old_id = node.get("id")
1958
+ if old_id and old_id not in old_to_new:
1959
+ new_id = str(counter[0])
1960
+ old_to_new[old_id] = new_id
1961
+ counter[0] += 1
1962
+
1963
+ # Phase 2: apply mapping
1964
+ for node in el.iter():
1965
+ # Remap id
1966
+ old_id = node.get("id")
1967
+ if old_id and old_id in old_to_new:
1968
+ node.set("id", old_to_new[old_id])
1969
+
1970
+ # Remap Z (just increment to avoid conflicts)
1971
+ z = node.get("Z")
1972
+ if z:
1973
+ node.set("Z", str(counter[1]))
1974
+ counter[1] += 1
1975
+
1976
+ # Remap bond references
1977
+ for attr in ("B", "E"):
1978
+ ref = node.get(attr)
1979
+ if ref and ref in old_to_new:
1980
+ node.set(attr, old_to_new[ref])
1981
+
1982
+ # Remap SupersededBy
1983
+ sup = node.get("SupersededBy")
1984
+ if sup and sup in old_to_new:
1985
+ node.set("SupersededBy", old_to_new[sup])
1986
+
1987
+ # Remap BondCircularOrdering
1988
+ bco = node.get("BondCircularOrdering")
1989
+ if bco:
1990
+ parts = bco.split()
1991
+ new_parts = [old_to_new.get(p, p) for p in parts]
1992
+ node.set("BondCircularOrdering", " ".join(new_parts))
1993
+
1994
+
1995
+ def _get_max_id_in_element(el: ET.Element) -> int:
1996
+ """Get max id in an element subtree."""
1997
+ max_id = 0
1998
+ for node in el.iter():
1999
+ eid = node.get("id", "")
2000
+ if eid:
2001
+ try:
2002
+ max_id = max(max_id, int(eid))
2003
+ except ValueError:
2004
+ pass
2005
+ return max_id
2006
+
2007
+
2008
+ def _get_max_z_in_element(el: ET.Element) -> int:
2009
+ """Get max Z in an element subtree."""
2010
+ max_z = 0
2011
+ for node in el.iter():
2012
+ z = node.get("Z", "")
2013
+ if z:
2014
+ try:
2015
+ max_z = max(max_z, int(z))
2016
+ except ValueError:
2017
+ pass
2018
+ return max_z
2019
+
2020
+
2021
+ # ============================================================================
2022
+ # Alignment cascade for sequential merge
2023
+ # ============================================================================
2024
+
2025
def _alignment_cascade(schemes: List[ParsedScheme],
                       links: List[Tuple[str, str]],
                       ref_cdxml: str, log):
    """Align all structures backwards from the final product.

    Steps are processed from last to first. Each earlier step can therefore
    be oriented to match the already-aligned copy of the shared molecule in
    the following step:

    * last step: if *ref_cdxml* is given, its product is aligned to that
      reference CDXML file;
    * earlier step ``i``: its product is aligned to fragment ``links[i][1]``
      of step ``i + 1``, when that fragment exists. The fragment is written
      to a temporary CDXML file, which is always deleted afterwards;
    * every step: reactants/reagents are then aligned to the step's product.

    Individual alignment failures are logged as warnings and do not abort the
    cascade. If the optional ``.alignment`` module cannot be imported, a
    warning is logged and nothing is aligned.

    Args:
        schemes: parsed step schemes in reaction order. Their trees
            (``.root``) are modified in place.
        links: one ``(product_id, reactant_id)`` pair per step boundary.
            ``links[i][1]`` is looked up in ``schemes[i + 1].fragments``.
            The first element is presumably step ``i``'s product id; it is
            not used here.
        ref_cdxml: optional path to a reference CDXML for the final product.
        log: callable taking a single message string.
    """
    try:
        from .alignment import (
            align_product_to_reference,
            rdkit_align_to_product,
        )
    except ImportError:
        log("WARNING: alignment module not available — skipping alignment")
        return

    last_idx = len(schemes) - 1

    # Process backwards from last step
    for step_idx in range(last_idx, -1, -1):
        s = schemes[step_idx]
        root = s.root

        if step_idx == last_idx and ref_cdxml:
            # Last step: align product to external reference
            log(f" Step {step_idx+1}: aligning product to reference {ref_cdxml}")
            try:
                align_product_to_reference(root, ref_cdxml, verbose=False)
            except Exception as e:
                log(f" WARNING: align_product_to_reference failed: {e}")

        elif step_idx < last_idx:
            # Earlier step: align this step's product to the already-aligned
            # version of the same molecule from the next step's reactant side.
            next_s = schemes[step_idx + 1]
            _link_product_id, link_reactant_id = links[step_idx]

            if link_reactant_id and link_reactant_id in next_s.fragments:
                # Write the aligned reactant from next step to a temp CDXML
                aligned_frag = next_s.fragments[link_reactant_id]
                ref_path = None
                try:
                    ref_path = _write_temp_fragment_cdxml(aligned_frag, s.root)
                    log(f" Step {step_idx+1}: aligning product to "
                        f"step {step_idx+2}'s aligned reactant")
                    align_product_to_reference(root, ref_path, verbose=False)
                except Exception as e:
                    log(f" WARNING: cross-step alignment failed: {e}")
                finally:
                    # Remove the temp reference even when alignment raised.
                    # Previously a failure here leaked the file.
                    if ref_path is not None:
                        try:
                            os.unlink(ref_path)
                        except OSError as e:
                            log(f" WARNING: could not remove temp file "
                                f"{ref_path}: {e}")

        # Align all structures within this step to the product
        log(f" Step {step_idx+1}: aligning reactants/reagents to product")
        try:
            rdkit_align_to_product(root, verbose=False)
        except Exception as e:
            log(f" WARNING: within-step alignment failed: {e}")
2080
+
2081
+
2082
def _write_temp_fragment_cdxml(frag: ET.Element,
                               source_root: ET.Element) -> str:
    """Write a single fragment to a temporary CDXML file for use as alignment ref.

    The fragment is serialized between ``CDXML_MINIMAL_HEADER`` and
    ``CDXML_FOOTER`` from the package constants. The result is written as
    UTF-8 to a new ``.cdxml`` file created with :func:`tempfile.mkstemp`.

    Args:
        frag: fragment element to serialize. A deep copy is serialized, so
            *frag* itself is not modified.
        source_root: root of the document the fragment came from. It is not
            used by this function and is kept for interface compatibility.

    Returns:
        Path of the temporary file. The caller owns the file and is
        responsible for deleting it.

    Raises:
        Any error raised while writing is re-raised after the partially
        written temporary file has been removed.
    """
    from ..constants import CDXML_MINIMAL_HEADER, CDXML_FOOTER

    frag_copy = copy.deepcopy(frag)
    # Wrap in minimal CDXML
    content = (CDXML_MINIMAL_HEADER
               + ET.tostring(frag_copy, encoding="unicode")
               + CDXML_FOOTER)

    fd, path = tempfile.mkstemp(suffix=".cdxml")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(content)
    except BaseException:
        # The caller never receives *path* on failure, so it could not clean
        # up the stray file; remove it here before propagating the error.
        try:
            os.unlink(path)
        except OSError:
            pass
        raise
    return path
2097
+
2098
+
2099
+ # ============================================================================
2100
+ # CLI
2101
+ # ============================================================================
2102
+
2103
def _build_cli_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the merge command-line interface."""
    parser = argparse.ArgumentParser(
        description="Merge ELN-enriched reaction schemes (auto-detects mode).",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    parser.add_argument("inputs", nargs="+", metavar="CDXML",
                        help="Input CDXML files (2 or more)")
    parser.add_argument("-o", "--output", default=None,
                        help="Output CDXML path "
                             "(default: auto-generated from input names)")

    # Mode: auto-detect by default, explicit override available
    parser.add_argument("--mode", choices=["auto", "parallel", "sequential"],
                        default="auto",
                        help="Merge mode (default: auto-detect)")
    # Backward-compat aliases (deprecated; main() maps them onto the mode)
    parser.add_argument("--parallel", action="store_true",
                        help=argparse.SUPPRESS)  # deprecated
    parser.add_argument("--sequential", action="store_true",
                        help=argparse.SUPPRESS)  # deprecated

    # Parallel options
    parser.add_argument("--no-equiv", action="store_true",
                        help="Remove all equivalents labels")
    parser.add_argument("--equiv-range", action="store_true",
                        help="Show equiv range when values differ")

    # Sequential options
    parser.add_argument("--ref-cdxml", default=None,
                        help="Reference CDXML for final product alignment")

    # Unrelated handling. NOTE: --adjacent is store_true with default=True,
    # so passing it changes nothing; only --no-adjacent alters behaviour.
    parser.add_argument("--adjacent", action="store_true", default=True,
                        help="Place unrelated reactions side by side (default)")
    parser.add_argument("--no-adjacent", action="store_true",
                        help="Error if any reactions are unrelated")

    # Common
    parser.add_argument("-v", "--verbose", action="store_true",
                        help="Print progress to stderr")
    parser.add_argument("--render", action="store_true",
                        help="Render output to PNG via cdxml_to_image.py")
    return parser


def _auto_output_path(inputs: List[str]) -> str:
    """Derive ``<prefix>-<suffix>+<suffix>-merged.cdxml`` next to the first input.

    Stems (file names without extension and without a trailing ``-scheme``)
    are split into a shared '-'-delimited prefix and per-file suffixes.
    """
    stems = [re.sub(r'-scheme$', '', os.path.splitext(os.path.basename(p))[0])
             for p in inputs]

    prefix = os.path.commonprefix(stems)
    # commonprefix is character-wise and can stop mid-token, e.g.
    # "KL-7001-004" / "KL-7001-005" -> "KL-7001-00". Back off to the last '-'
    # so that suffixes are whole tokens ("004", "005"), not "4", "5".
    if not all(s[len(prefix):len(prefix) + 1] in ("", "-") for s in stems):
        prefix = prefix[:prefix.rfind("-") + 1]  # no '-' at all -> ""
    if prefix.endswith("-"):
        prefix = prefix[:-1]

    if prefix:
        # Use prefix + unique suffixes
        suffixes = [s[len(prefix):].lstrip('-') for s in stems]
        suffixes = [suffix for suffix in suffixes if suffix]
        if suffixes:
            out_name = f"{prefix}-{'+'.join(suffixes)}-merged.cdxml"
        else:
            out_name = f"{prefix}-merged.cdxml"
    else:
        out_name = f"{stems[0]}-merged.cdxml"
    return os.path.join(os.path.dirname(inputs[0]), out_name)


def _render_output_png(out_path: str, log) -> None:
    """Render *out_path* to PNG with the ``cdxml_to_image.py`` script next to this file."""
    script_dir = os.path.dirname(os.path.abspath(__file__))
    render_script = os.path.join(script_dir, "cdxml_to_image.py")
    if not os.path.isfile(render_script):
        # Previously this case was silently ignored although --render was requested.
        print(f"WARNING: --render requested but {render_script} was not found",
              file=sys.stderr)
        return

    result = subprocess.run(
        [sys.executable, render_script, out_path],
        capture_output=True, text=True,
    )
    if result.returncode != 0:
        # Previously "Rendered" was logged regardless of the exit status.
        detail = (result.stderr or "").strip()
        message = f"WARNING: rendering failed (exit code {result.returncode})"
        if detail:
            message += f": {detail}"
        print(message, file=sys.stderr)
        return
    log(f"Rendered: {os.path.splitext(out_path)[0]}.png")


def main():
    """Command-line entry point for merging reaction-scheme CDXML files.

    Parses arguments and loads every input with ``parse_scheme``. The inputs
    are then merged: in auto mode via ``auto_detect`` + ``execute_merge_plan``,
    otherwise with the forced parallel or sequential merge. The merged CDXML
    is written and its path printed to stdout. The process exits with status
    1 when an input fails to parse or the merge raises ``ValueError``;
    argument errors are reported through ``parser.error``.
    """
    parser = _build_cli_parser()
    args = parser.parse_args()

    # Handle deprecated flags (they override --mode)
    mode = args.mode
    if args.parallel:
        print("WARNING: --parallel is deprecated, use --mode parallel",
              file=sys.stderr)
        mode = "parallel"
    elif args.sequential:
        print("WARNING: --sequential is deprecated, use --mode sequential",
              file=sys.stderr)
        mode = "sequential"

    if len(args.inputs) < 2:
        parser.error("Need at least 2 input files")

    for path in args.inputs:
        if not os.path.isfile(path):
            parser.error(f"File not found: {path}")

    log = (lambda msg: print(msg, file=sys.stderr)) if args.verbose else (lambda msg: None)

    # Parse all input schemes; any failure is fatal
    schemes = []
    for path in args.inputs:
        try:
            schemes.append(parse_scheme(path, log=log))
        except Exception as e:
            print(f"ERROR: Failed to parse {path}: {e}", file=sys.stderr)
            sys.exit(1)

    # Determine equiv mode (--no-equiv takes precedence over --equiv-range)
    equiv_mode = "default"
    if args.no_equiv:
        equiv_mode = "no-equiv"
    elif args.equiv_range:
        equiv_mode = "equiv-range"

    allow_adjacent = not args.no_adjacent

    # Merge
    try:
        if mode == "auto":
            plan = auto_detect(schemes, log=log)
            log(f"Detected: {plan.describe()}")
            merged_tree = execute_merge_plan(
                schemes, plan,
                equiv_mode=equiv_mode,
                ref_cdxml=args.ref_cdxml,
                allow_adjacent=allow_adjacent,
                log=log,
            )
        elif mode == "parallel":
            merged_tree = parallel_merge(
                schemes, equiv_mode=equiv_mode, strict=True, log=log)
        elif mode == "sequential":
            merged_tree = sequential_merge(
                schemes, ref_cdxml=args.ref_cdxml, log=log)
    except ValueError as e:
        print(f"ERROR: {e}", file=sys.stderr)
        sys.exit(1)

    # Determine output path (an empty -o value also falls back to auto naming)
    out_path = args.output or _auto_output_path(args.inputs)

    # Write output
    write_cdxml(merged_tree, out_path)
    log(f"Written: {out_path}")
    print(out_path)

    # Optional render
    if args.render:
        _render_output_png(out_path, log)
2257
+
2258
+
2259
if __name__ == "__main__":
    # Run the command-line interface when this module is executed as a script.
    main()