cdxml-toolkit 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91) hide show
  1. cdxml_toolkit/__init__.py +18 -0
  2. cdxml_toolkit/_jre/__init__.py +2 -0
  3. cdxml_toolkit/_jre/temurin-21-jre-win-x64.zip +0 -0
  4. cdxml_toolkit/analysis/__init__.py +35 -0
  5. cdxml_toolkit/analysis/deterministic/__init__.py +12 -0
  6. cdxml_toolkit/analysis/deterministic/discover_experiment_files.py +413 -0
  7. cdxml_toolkit/analysis/deterministic/lab_book_formatter.py +701 -0
  8. cdxml_toolkit/analysis/deterministic/lcms_file_categorizer.py +928 -0
  9. cdxml_toolkit/analysis/deterministic/lcms_identifier.py +598 -0
  10. cdxml_toolkit/analysis/deterministic/mass_resolver.py +654 -0
  11. cdxml_toolkit/analysis/deterministic/multi_lcms_analyzer.py +1412 -0
  12. cdxml_toolkit/analysis/deterministic/procedure_writer.py +446 -0
  13. cdxml_toolkit/analysis/extract_nmr.py +47 -0
  14. cdxml_toolkit/analysis/format_procedure_entry.py +479 -0
  15. cdxml_toolkit/analysis/lcms_analyzer.py +1299 -0
  16. cdxml_toolkit/analysis/parse_analysis_file.py +134 -0
  17. cdxml_toolkit/cdxml_builder.py +920 -0
  18. cdxml_toolkit/cdxml_utils.py +342 -0
  19. cdxml_toolkit/chemdraw/__init__.py +5 -0
  20. cdxml_toolkit/chemdraw/_chemscript_server.py +562 -0
  21. cdxml_toolkit/chemdraw/cdx_converter.py +527 -0
  22. cdxml_toolkit/chemdraw/cdxml_to_image.py +262 -0
  23. cdxml_toolkit/chemdraw/cdxml_to_image_rdkit.py +296 -0
  24. cdxml_toolkit/chemdraw/chemscript_bridge.py +901 -0
  25. cdxml_toolkit/constants.py +304 -0
  26. cdxml_toolkit/coord_normalizer.py +438 -0
  27. cdxml_toolkit/deterministic_pipeline/__init__.py +6 -0
  28. cdxml_toolkit/deterministic_pipeline/legacy/__init__.py +5 -0
  29. cdxml_toolkit/deterministic_pipeline/legacy/eln_cdx_cleanup.py +509 -0
  30. cdxml_toolkit/deterministic_pipeline/legacy/eln_enrichment.py +1394 -0
  31. cdxml_toolkit/deterministic_pipeline/legacy/scheme_aligner.py +428 -0
  32. cdxml_toolkit/deterministic_pipeline/legacy/scheme_polisher.py +1337 -0
  33. cdxml_toolkit/deterministic_pipeline/legacy/scheme_polisher_v2.py +1340 -0
  34. cdxml_toolkit/deterministic_pipeline/scheme_reader_audit.py +931 -0
  35. cdxml_toolkit/deterministic_pipeline/scheme_reader_verify.py +1160 -0
  36. cdxml_toolkit/image/__init__.py +15 -0
  37. cdxml_toolkit/image/reaction_from_image.py +2103 -0
  38. cdxml_toolkit/image/structure_from_image.py +1711 -0
  39. cdxml_toolkit/layout/__init__.py +5 -0
  40. cdxml_toolkit/layout/alignment.py +1642 -0
  41. cdxml_toolkit/layout/reaction_cleanup.py +1002 -0
  42. cdxml_toolkit/layout/scheme_merger.py +2260 -0
  43. cdxml_toolkit/mcp_server/__init__.py +0 -0
  44. cdxml_toolkit/mcp_server/__main__.py +5 -0
  45. cdxml_toolkit/mcp_server/server.py +1567 -0
  46. cdxml_toolkit/naming/__init__.py +6 -0
  47. cdxml_toolkit/naming/aligned_namer.py +2342 -0
  48. cdxml_toolkit/naming/mol_builder.py +3722 -0
  49. cdxml_toolkit/naming/name_decomposer.py +2843 -0
  50. cdxml_toolkit/naming/reactions_datamol.json +2414 -0
  51. cdxml_toolkit/office/__init__.py +5 -0
  52. cdxml_toolkit/office/doc_from_template.py +722 -0
  53. cdxml_toolkit/office/ole_embedder.py +808 -0
  54. cdxml_toolkit/office/ole_extractor.py +272 -0
  55. cdxml_toolkit/perception/__init__.py +10 -0
  56. cdxml_toolkit/perception/compound_search.py +229 -0
  57. cdxml_toolkit/perception/eln_csv_parser.py +240 -0
  58. cdxml_toolkit/perception/rdf_parser.py +664 -0
  59. cdxml_toolkit/perception/reactant_heuristic.py +1045 -0
  60. cdxml_toolkit/perception/reaction_parser.py +2150 -0
  61. cdxml_toolkit/perception/scheme_reader.py +2948 -0
  62. cdxml_toolkit/perception/scheme_refine.py +1404 -0
  63. cdxml_toolkit/perception/scheme_segmenter.py +619 -0
  64. cdxml_toolkit/perception/spatial_assignment.py +1013 -0
  65. cdxml_toolkit/rdkit_utils.py +605 -0
  66. cdxml_toolkit/render/__init__.py +17 -0
  67. cdxml_toolkit/render/auto_layout.py +229 -0
  68. cdxml_toolkit/render/compact_parser.py +632 -0
  69. cdxml_toolkit/render/parser.py +706 -0
  70. cdxml_toolkit/render/render_scheme.py +267 -0
  71. cdxml_toolkit/render/renderer.py +2387 -0
  72. cdxml_toolkit/render/schema.py +90 -0
  73. cdxml_toolkit/render/scheme_maker.py +1043 -0
  74. cdxml_toolkit/render/scheme_yaml_writer.py +1487 -0
  75. cdxml_toolkit/resolve/__init__.py +13 -0
  76. cdxml_toolkit/resolve/cas_resolver.py +430 -0
  77. cdxml_toolkit/resolve/chemscanner_abbreviations.json +28813 -0
  78. cdxml_toolkit/resolve/condensed_formula.py +493 -0
  79. cdxml_toolkit/resolve/jre_manager.py +195 -0
  80. cdxml_toolkit/resolve/reagent_abbreviations.json +1046 -0
  81. cdxml_toolkit/resolve/reagent_db.py +285 -0
  82. cdxml_toolkit/resolve/superatom_data.json +2856 -0
  83. cdxml_toolkit/resolve/superatom_table.py +146 -0
  84. cdxml_toolkit/text_formatting.py +298 -0
  85. cdxml_toolkit-0.5.0.dist-info/METADATA +318 -0
  86. cdxml_toolkit-0.5.0.dist-info/RECORD +91 -0
  87. cdxml_toolkit-0.5.0.dist-info/WHEEL +5 -0
  88. cdxml_toolkit-0.5.0.dist-info/entry_points.txt +17 -0
  89. cdxml_toolkit-0.5.0.dist-info/licenses/LICENSE +21 -0
  90. cdxml_toolkit-0.5.0.dist-info/licenses/NOTICE.md +37 -0
  91. cdxml_toolkit-0.5.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1642 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ alignment.py -- Shared alignment functions for reaction scheme polishing.
4
+
5
+ Provides two independent product-alignment strategies:
6
+
7
+ * **Kabsch** -- rigid-body 2D rotation computed from matched atom
8
+ coordinates (requires ChemScript for MOL export + RDKit for MCS).
9
+ * **RDKit MCS** -- per-bond re-depiction via
10
+ ``GenerateDepictionMatching2DStructure`` (RDKit only, no ChemScript).
11
+
12
+ Both are consumed by ``scheme_polisher.py`` and ``scheme_polisher_v2.py``.
13
+ All geometry primitives (centroid, rotation, coordinate helpers) live here
14
+ so there is exactly one copy of each.
15
+
16
+ Layers
17
+ ------
18
+ 1. Geometry primitives (stdlib only -- no RDKit / ChemScript)
19
+ 2. Kabsch alignment (ChemScript + RDKit, lazy imports)
20
+ 3. RDKit MCS alignment (RDKit only, lazy imports)
21
+ """
22
+
23
+ import argparse
24
+ import copy
25
+ import math
26
+ import os
27
+ import sys
28
+ import tempfile
29
+ import xml.etree.ElementTree as ET
30
+ from typing import Dict, List, Optional, Set, Tuple
31
+
32
+ from ..constants import ACS_BOND_LENGTH, CDXML_MINIMAL_HEADER
33
+ from ..cdxml_utils import write_cdxml
34
+
35
+
36
+ # ============================================================================
37
+ # LAYER 1 -- Geometry primitives (stdlib only)
38
+ # ============================================================================
39
+
40
+ # ---------------------------------------------------------------------------
41
+ # Fragment <-> CDXML wrapping
42
+ # ---------------------------------------------------------------------------
43
+
44
def sp_fragment_to_cdxml(frag: ET.Element) -> str:
    """Embed one ``<fragment>`` element in a bare-bones CDXML document.

    The returned string is the shared minimal CDXML header, a single
    ``<page id="1">`` holding the serialized fragment, and the closing
    ``</page>`` / ``</CDXML>`` tags, joined with newlines.  It is used to
    hand individual fragments to ChemScript for MOL block export.
    """
    serialized = ET.tostring(frag, encoding="unicode")
    document_parts = (
        CDXML_MINIMAL_HEADER,
        '<page id="1">',
        serialized,
        '</page>',
        '</CDXML>',
    )
    return "\n".join(document_parts)
58
+
59
+
60
def filtered_atom_nodes(frag: ET.Element) -> List[ET.Element]:
    """Collect the real atom ``<n>`` nodes found anywhere under *frag*.

    Every descendant ``<n>`` is visited in document order (nodes nested in
    abbreviation inner fragments included).  Nodes whose ``NodeType`` is
    ``ExternalConnectionPoint``, ``Fragment`` or ``Unspecified`` are
    dropped; every other node is returned.
    """
    pseudo_types = ("ExternalConnectionPoint", "Fragment", "Unspecified")
    real_nodes = []
    for node in frag.iter("n"):
        if node.get("NodeType") in pseudo_types:
            continue
        real_nodes.append(node)
    return real_nodes
66
+
67
+
68
+ # ---------------------------------------------------------------------------
69
+ # Centroid / position helpers
70
+ # ---------------------------------------------------------------------------
71
+
72
def fragment_centroid(frag: ET.Element) -> Tuple[float, float]:
    """Return the mean (x, y) of the direct-child ``<n>`` positions of *frag*.

    Nodes with no ``p`` attribute, or whose ``p`` has fewer than two
    fields, are ignored; nodes nested deeper (e.g. inside abbreviation
    inner fragments) are not considered.  Returns ``(0.0, 0.0)`` when no
    usable position exists.
    """
    coords = []
    for node in frag.findall("n"):
        fields = (node.get("p") or "").split()
        if len(fields) < 2:
            continue
        coords.append((float(fields[0]), float(fields[1])))

    if not coords:
        return 0.0, 0.0

    count = len(coords)
    mean_x = sum(pt[0] for pt in coords) / count
    mean_y = sum(pt[1] for pt in coords) / count
    return mean_x, mean_y
85
+
86
+
87
def get_visible_carbon_positions(frag: ET.Element) -> List[Tuple[float, float]]:
    """Return the (x, y) positions of plain carbon atoms for Kabsch alignment.

    A node qualifies when it is a direct child ``<n>`` of *frag* with an
    empty/absent ``NodeType`` and an empty/absent ``Element`` attribute
    (CDXML's implicit-carbon convention).  Consequently the following are
    all excluded:
      - heteroatoms and any node with an explicit ``Element`` (note: an
        explicit ``Element="6"`` carbon is excluded as well)
      - abbreviation group nodes (``NodeType="Fragment"``)
      - external connection points and any other typed node
      - atoms inside abbreviation inner fragments (not direct children)

    Nodes lacking a two-field ``p`` attribute are skipped.  Carbon
    backbone positions are stable reference points for orientation
    matching, unaffected by label rendering or abbreviation expansion.
    """
    def _is_plain_carbon(node: ET.Element) -> bool:
        return not node.get("NodeType") and not node.get("Element")

    result = []
    for node in frag.findall("n"):
        if not _is_plain_carbon(node):
            continue
        fields = (node.get("p") or "").split()
        if len(fields) >= 2:
            result.append((float(fields[0]), float(fields[1])))
    return result
114
+
115
+
116
+ # ---------------------------------------------------------------------------
117
+ # Kabsch 2D rotation
118
+ # ---------------------------------------------------------------------------
119
+
120
def compute_rigid_rotation_2d(
    old_pts: List[Tuple[float, float]],
    new_pts: List[Tuple[float, float]],
) -> Tuple[float, float]:
    """Return (cos_a, sin_a) of the 2D rotation best mapping old_pts onto new_pts.

    Both clouds are centred on their own centroid, then the angle is
    ``atan2(sum(ox*ny - oy*nx), sum(ox*nx + oy*ny))`` over the paired,
    centred points (2D Kabsch).  Translation is discarded so fragments
    stay in place.  Fewer than two points yield the identity ``(1.0, 0.0)``.

    Both centroids divide by ``len(old_pts)``, so the two lists are
    expected to have equal length.
    """
    count = len(old_pts)
    if count < 2:
        return (1.0, 0.0)

    old_cx = sum(pt[0] for pt in old_pts) / count
    old_cy = sum(pt[1] for pt in old_pts) / count
    new_cx = sum(pt[0] for pt in new_pts) / count
    new_cy = sum(pt[1] for pt in new_pts) / count

    # Cross-covariance terms of the centred clouds, accumulated pairwise.
    sxx = syy = sxy = syx = 0.0
    for (old_x, old_y), (new_x, new_y) in zip(old_pts, new_pts):
        ox = old_x - old_cx
        oy = old_y - old_cy
        nx = new_x - new_cx
        ny = new_y - new_cy
        sxx += ox * nx
        syy += oy * ny
        sxy += ox * ny
        syx += oy * nx

    theta = math.atan2(sxy - syx, sxx + syy)
    return (math.cos(theta), math.sin(theta))
149
+
150
+
151
def match_and_compute_rotation(
    src_positions: List[Tuple[float, float]],
    tgt_positions: List[Tuple[float, float]],
) -> Tuple[float, float, float]:
    """Pair two point sets by nearest neighbour, then compute a Kabsch rotation.

    Intended for two coordinate sets of the same molecule (e.g. before and
    after ChemScript cleanup).  Each set is centred and scaled by its
    largest absolute coordinate so the sets are comparable, then every
    source point is greedily paired with the closest still-unused target
    point (first index wins on ties).  The rotation from the paired
    *original* src coordinates to the paired tgt coordinates is returned as
    ``(cos_a, sin_a, angle_degrees)``.

    The identity ``(1.0, 0.0, 0.0)`` is returned when there are fewer than
    three source points, the two lists differ in length, or fewer than
    three pairs are formed.
    """
    identity = (1.0, 0.0, 0.0)
    count = len(src_positions)
    if count < 3 or len(tgt_positions) != count:
        return identity

    def _normalised(points):
        mean_x = sum(pt[0] for pt in points) / len(points)
        mean_y = sum(pt[1] for pt in points) / len(points)
        shifted = [(x - mean_x, y - mean_y) for x, y in points]
        extent = max(max(abs(x), abs(y)) for x, y in shifted) or 1.0
        return [(x / extent, y / extent) for x, y in shifted]

    src_unit = _normalised(src_positions)
    tgt_unit = _normalised(tgt_positions)

    taken = set()
    paired_src = []
    paired_tgt = []
    for src_idx, (sx, sy) in enumerate(src_unit):
        nearest = -1
        nearest_d2 = float("inf")
        for tgt_idx, (tx, ty) in enumerate(tgt_unit):
            if tgt_idx in taken:
                continue
            dist2 = (sx - tx) ** 2 + (sy - ty) ** 2
            if dist2 < nearest_d2:
                nearest, nearest_d2 = tgt_idx, dist2
        if nearest < 0:
            continue
        taken.add(nearest)
        paired_src.append(src_positions[src_idx])
        paired_tgt.append(tgt_positions[nearest])

    if len(paired_src) < 3:
        return identity

    cos_a, sin_a = compute_rigid_rotation_2d(paired_src, paired_tgt)
    return (cos_a, sin_a, math.degrees(math.atan2(sin_a, cos_a)))
205
+
206
+
207
+ # ---------------------------------------------------------------------------
208
+ # In-place coordinate rotation
209
+ # ---------------------------------------------------------------------------
210
+
211
def rotate_fragment_in_place(
    frag: ET.Element,
    cos_a: float, sin_a: float,
    cx: float, cy: float,
) -> None:
    """Rotate the coordinate attributes of *frag* about (cx, cy), in place.

    Rewritten attributes (two-decimal formatting):
      - ``p`` of every descendant ``<n>`` and ``<t>``;
      - ``BoundingBox`` of every descendant ``<t>``;
      - ``BoundingBox`` of *frag* itself and of every nested ``<fragment>``
        (abbreviation groups).
    A BoundingBox becomes the axis-aligned box enclosing its four rotated
    corners; boxes with fewer than four numbers are left unchanged.
    """
    def _turn(x: float, y: float) -> Tuple[float, float]:
        ox, oy = x - cx, y - cy
        return (cos_a * ox - sin_a * oy + cx,
                sin_a * ox + cos_a * oy + cy)

    def _turn_box(box: str) -> str:
        nums = [float(v) for v in box.split()]
        if len(nums) < 4:
            return box
        left, top, right, bottom = nums[0], nums[1], nums[2], nums[3]
        turned = [_turn(x, y) for x, y in ((left, top), (right, top),
                                           (left, bottom), (right, bottom))]
        xs = [pt[0] for pt in turned]
        ys = [pt[1] for pt in turned]
        return f"{min(xs):.2f} {min(ys):.2f} {max(xs):.2f} {max(ys):.2f}"

    def _turn_position(elem: ET.Element) -> None:
        fields = (elem.get("p") or "").split()
        if len(fields) >= 2:
            new_x, new_y = _turn(float(fields[0]), float(fields[1]))
            elem.set("p", f"{new_x:.2f} {new_y:.2f}")

    # Atom positions, including atoms nested in abbreviation groups.
    for node in frag.iter("n"):
        _turn_position(node)

    # Text labels: anchor position, then bounding box.
    for label in frag.iter("t"):
        _turn_position(label)
        label_box = label.get("BoundingBox")
        if label_box:
            label.set("BoundingBox", _turn_box(label_box))

    # The fragment's own box, then every nested fragment's box.
    own_box = frag.get("BoundingBox")
    if own_box:
        frag.set("BoundingBox", _turn_box(own_box))
    for sub in frag.iter("fragment"):
        if sub is frag:
            continue
        sub_box = sub.get("BoundingBox")
        if sub_box:
            sub.set("BoundingBox", _turn_box(sub_box))
271
+
272
+
273
+ # ---------------------------------------------------------------------------
274
+ # Abbreviation dummy copy
275
+ # ---------------------------------------------------------------------------
276
+
277
def make_abbrev_dummy_copy(frag: ET.Element) -> ET.Element:
    """Return a deep copy of *frag* whose abbreviation nodes are dummy iodines.

    Every direct-child ``<n>`` with ``NodeType="Fragment"`` in the copy
    loses all of its child elements (inner fragment and label) and its
    ``NodeType``, ``LabelDisplay``, ``NeedsClean``, ``AS`` and ``Warning``
    attributes.  It then gets ``Element="53"`` (iodine) and
    ``NumHydrogens="0"``, keeping its original position.  This keeps the
    ChemScript MOL export and ``filtered_atom_nodes`` looking at the same
    atom set.  *frag* itself is not modified.
    """
    clone = copy.deepcopy(frag)
    stripped_attrs = ("NodeType", "LabelDisplay", "NeedsClean", "AS", "Warning")
    abbrev_nodes = [node for node in clone.findall("n")
                    if node.get("NodeType") == "Fragment"]
    for node in abbrev_nodes:
        del node[:]  # drop inner fragment + label children
        for attr in stripped_attrs:
            node.attrib.pop(attr, None)
        node.set("Element", "53")  # Iodine dummy
        node.set("NumHydrogens", "0")
    return clone
300
+
301
+
302
+ # ---------------------------------------------------------------------------
303
+ # Coordinate translation
304
+ # ---------------------------------------------------------------------------
305
+
306
def translate_subtree(elem: ET.Element, dx: float, dy: float) -> None:
    """Shift the ``p`` and ``BoundingBox`` attributes of *elem* and all descendants.

    ``p`` is rewritten from its first two numbers when it has at least
    two; ``BoundingBox`` is rewritten only when it has exactly four
    numbers.  Values are formatted with two decimals.  Works in place.
    """
    for node in elem.iter():
        coords = (node.get("p") or "").split()
        if len(coords) >= 2:
            node.set("p",
                     f"{float(coords[0])+dx:.2f} {float(coords[1])+dy:.2f}")

        box = (node.get("BoundingBox") or "").split()
        if len(box) == 4:
            left, top, right, bottom = (float(v) for v in box)
            node.set("BoundingBox",
                     f"{left+dx:.2f} {top+dy:.2f} "
                     f"{right+dx:.2f} {bottom+dy:.2f}")
325
+
326
+
327
+ # ============================================================================
328
+ # LAYER 2 -- Kabsch product alignment (ChemScript + RDKit, lazy imports)
329
+ # ============================================================================
330
+
331
def kabsch_align_fragment_to_product(
    reagent_frag: ET.Element,
    product_frag: ET.Element,
    cs_bridge,
    verbose: bool = False,
) -> bool:
    """Align a reagent fragment's orientation to match its substructure in
    the product using rigid-body Kabsch rotation.

    Strategy:
      1. Create work copies with abbreviations replaced by dummy atoms
         (Iodine) so that ChemScript MOL export and the CDXML atom list
         have identical atom counts and consistent coordinates.
      2. Export both work copies to MOL blocks (via ChemScript) for RDKit.
      3. Map MOL atom indices to CDXML node indices by nearest-neighbour
         matching of centred, unit-scaled coordinates (ChemScript may
         reorder atoms on export).
      4. Find the atom correspondence via substructure match or MCS.
      5. Pair up the matched CDXML coordinates (both y-down), counting
         non-carbon atoms three times.
      6. Compute the optimal rigid rotation via Kabsch.
      7. If the rotation is at least 5 degrees, apply it in-place to the
         **original** reagent fragment (including abbreviation inner
         fragments) around the centroid of its real atoms.

    Parameters
    ----------
    reagent_frag : ET.Element
        Fragment to rotate; modified in place only on success.
    product_frag : ET.Element
        Reference fragment; not modified.
    cs_bridge
        Object whose ``write_data(path, mime_type)`` returns a MOL block
        for the CDXML file at *path* (a ChemScript bridge).
    verbose : bool
        Print progress messages to stderr.

    Returns
    -------
    bool
        True if the reagent was rotated.  False if alignment was skipped:
        RDKit unavailable, MOL export/parse/sanitize failure, no usable
        atom correspondence, or a rotation below 5 degrees.
    """
    def log(msg: str):
        if verbose:
            print(f"[alignment] {msg}", file=sys.stderr)

    try:
        from rdkit import Chem
        from rdkit.Chem import rdFMCS
    except ImportError:
        log(" RDKit not available, skipping Kabsch alignment")
        return False

    label = f"frag {reagent_frag.get('id', '?')}"

    # --- Create work copies with abbreviations replaced by dummies ---
    reagent_work = make_abbrev_dummy_copy(reagent_frag)
    product_work = make_abbrev_dummy_copy(product_frag)

    # --- Get MOL blocks from ChemScript ---
    reagent_cdxml = sp_fragment_to_cdxml(reagent_work)
    product_cdxml = sp_fragment_to_cdxml(product_work)

    r_tmp = p_tmp = None
    try:
        # Record each file name *before* writing so the finally-clause can
        # remove it even if the write itself fails (delete=False).
        with tempfile.NamedTemporaryFile(
            suffix=".cdxml", mode="w", delete=False, encoding="utf-8"
        ) as f:
            r_tmp = f.name
            f.write(reagent_cdxml)
        with tempfile.NamedTemporaryFile(
            suffix=".cdxml", mode="w", delete=False, encoding="utf-8"
        ) as f:
            p_tmp = f.name
            f.write(product_cdxml)

        try:
            r_mol_block = cs_bridge.write_data(r_tmp, "chemical/x-mdl-molfile")
            p_mol_block = cs_bridge.write_data(p_tmp, "chemical/x-mdl-molfile")
        except Exception as exc:
            log(f" {label}: ChemScript MOL export failed: {exc}")
            return False
    finally:
        for tmp in (r_tmp, p_tmp):
            if tmp:
                try:
                    os.unlink(tmp)
                except OSError:
                    pass

    # MolFromMolBlock cannot take None; treat an empty export as failure.
    if not r_mol_block or not p_mol_block:
        log(f" {label}: ChemScript returned an empty MOL block")
        return False

    # --- Parse in RDKit ---
    partial_flags = (Chem.SanitizeFlags.SANITIZE_ALL
                     ^ Chem.SanitizeFlags.SANITIZE_KEKULIZE)

    reagent_mol = Chem.MolFromMolBlock(r_mol_block, sanitize=False)
    if reagent_mol is not None:
        try:
            Chem.SanitizeMol(reagent_mol)
        except Exception:
            # Fall back to sanitizing without kekulization; if even that
            # fails, report failure instead of raising to the caller.
            try:
                Chem.SanitizeMol(reagent_mol, partial_flags)
            except Exception as exc:
                log(f" {label}: RDKit couldn't sanitize reagent: {exc}")
                return False

    product_mol = Chem.MolFromMolBlock(p_mol_block, sanitize=False)
    if product_mol is not None:
        try:
            Chem.SanitizeMol(product_mol, partial_flags)
        except Exception:
            pass  # best-effort: matching can still work unsanitized

    if reagent_mol is None or product_mol is None:
        log(f" {label}: RDKit couldn't parse MOL blocks")
        return False

    # --- Build MOL-to-CDXML index mapping ---
    # ChemScript may reorder atoms when exporting to MOL block, so RDKit
    # atom indices may not follow CDXML <n> order.  Use the work copies
    # (abbreviations replaced with dummies) so the CDXML atom list matches
    # the MOL block atom list.
    r_real = filtered_atom_nodes(reagent_work)
    p_real = filtered_atom_nodes(product_work)

    def _build_mol_to_cdxml_map(mol, cdxml_nodes):
        """Map MOL atom index -> CDXML filtered-node index, or None.

        MOL coords are in Angstroms (y-up), CDXML coords in points
        (y-down).  Both sets are centred and scaled to unit extent, then
        paired greedily by nearest neighbour.
        """
        n = mol.GetNumAtoms()
        # Zero atoms would divide by zero below; differing counts mean
        # the lists cannot correspond one-to-one.
        if n == 0 or n != len(cdxml_nodes):
            return None

        conf = mol.GetConformer()
        mol_pts = []
        for i in range(n):
            pos = conf.GetAtomPosition(i)
            mol_pts.append((pos.x, -pos.y))  # flip y to y-down

        cdxml_pts = []
        for node in cdxml_nodes:
            parts = node.get("p", "").split()
            if len(parts) >= 2:
                cdxml_pts.append((float(parts[0]), float(parts[1])))
            else:
                cdxml_pts.append((0.0, 0.0))

        def _center_and_normalise(pts):
            cx = sum(p[0] for p in pts) / n
            cy = sum(p[1] for p in pts) / n
            centred = [(x - cx, y - cy) for x, y in pts]
            scale = max(max(abs(x), abs(y)) for x, y in centred) or 1.0
            return [(x / scale, y / scale) for x, y in centred]

        mol_n = _center_and_normalise(mol_pts)
        cdxml_n = _center_and_normalise(cdxml_pts)

        used = set()
        mapping = {}  # mol_idx -> cdxml_idx
        for mi, (mx, my) in enumerate(mol_n):
            best_ci = -1
            best_d2 = float("inf")
            for ci, (ux, uy) in enumerate(cdxml_n):
                if ci in used:
                    continue
                d2 = (mx - ux) ** 2 + (my - uy) ** 2
                if d2 < best_d2:
                    best_d2 = d2
                    best_ci = ci
            if best_ci >= 0:
                mapping[mi] = best_ci
                used.add(best_ci)
        return mapping if len(mapping) == n else None

    r_mol_to_cdxml = _build_mol_to_cdxml_map(reagent_mol, r_real)
    p_mol_to_cdxml = _build_mol_to_cdxml_map(product_mol, p_real)

    if r_mol_to_cdxml is None or p_mol_to_cdxml is None:
        log(f" {label}: couldn't build MOL-to-CDXML atom mapping")
        return False

    # --- Find atom correspondence via RDKit ---
    # GetSubstructMatch returns an empty tuple when there is no match, so
    # a single call replaces HasSubstructMatch + GetSubstructMatch.
    p_match = product_mol.GetSubstructMatch(reagent_mol)
    if p_match:
        r_match = tuple(range(reagent_mol.GetNumAtoms()))
        log(f" {label}: full substructure match ({len(r_match)} atoms)")
    else:
        log(f" {label}: no full substructure match, trying MCS...")
        mcs = rdFMCS.FindMCS(
            [reagent_mol, product_mol],
            threshold=1.0,
            ringMatchesRingOnly=True,
            completeRingsOnly=True,
            timeout=5,
        )
        if mcs.canceled or mcs.numAtoms < 3:
            log(f" {label}: MCS too small ({mcs.numAtoms} atoms), skipping")
            return False

        mcs_mol = Chem.MolFromSmarts(mcs.smartsString)
        if mcs_mol is None:
            log(f" {label}: couldn't parse MCS SMARTS, skipping")
            return False

        r_match = reagent_mol.GetSubstructMatch(mcs_mol)
        p_match = product_mol.GetSubstructMatch(mcs_mol)
        if not r_match or not p_match:
            log(f" {label}: MCS match failed, skipping")
            return False
        log(f" {label}: using MCS ({mcs.numAtoms} atoms) for alignment")

    # --- Build matched coordinate pairs from CDXML (y-down) ---
    def _node_pos(nodes, idx):
        parts = nodes[idx].get("p", "").split()
        if len(parts) >= 2:
            return (float(parts[0]), float(parts[1]))
        return None

    reagent_pts = []
    product_pts = []
    for r_mol_idx, p_mol_idx in zip(r_match, p_match):
        r_cdxml_idx = r_mol_to_cdxml.get(r_mol_idx)
        p_cdxml_idx = p_mol_to_cdxml.get(p_mol_idx)
        if r_cdxml_idx is None or p_cdxml_idx is None:
            continue
        rp = _node_pos(r_real, r_cdxml_idx)
        pp = _node_pos(p_real, p_cdxml_idx)
        if rp and pp:
            # Weight non-carbon atoms 3x so they dominate the rotation for
            # symmetric rings like morpholine, where the 4 near-equivalent
            # carbons can outvote 2 heteroatoms.  CDXML stores Element as
            # an atomic number: absent or "6" means carbon.
            element = r_real[r_cdxml_idx].get("Element", "")
            copies = 1 if element in ("", "6") else 3
            for _ in range(copies):
                reagent_pts.append(rp)
                product_pts.append(pp)

    if len(reagent_pts) < 3:
        log(f" {label}: too few matched points ({len(reagent_pts)}), skipping")
        return False

    # --- Compute rotation via Kabsch ---
    cos_a, sin_a = compute_rigid_rotation_2d(reagent_pts, product_pts)
    angle_deg = math.degrees(math.atan2(sin_a, cos_a))

    if abs(angle_deg) < 5.0:
        log(f" {label}: rotation {angle_deg:.1f} deg < 5 deg, already aligned")
        return False

    # --- Apply rotation around the (unweighted) reagent centroid ---
    all_r_pts = []
    for node in r_real:
        parts = node.get("p", "").split()
        if len(parts) >= 2:
            all_r_pts.append((float(parts[0]), float(parts[1])))
    cx = sum(p[0] for p in all_r_pts) / len(all_r_pts)
    cy = sum(p[1] for p in all_r_pts) / len(all_r_pts)

    rotate_fragment_in_place(reagent_frag, cos_a, sin_a, cx, cy)
    log(f" {label}: rotated {angle_deg:.1f} deg around "
        f"({cx:.1f}, {cy:.1f})")
    return True
592
+
593
+
594
def kabsch_align_to_product(
    root: ET.Element,
    cs_bridge=None,
    verbose: bool = False,
    frag_ids: Optional[Set[str]] = None,
) -> List[str]:
    """Align fragments to product orientation using Kabsch rigid rotation.

    Reads the first ``<step>`` element to identify the product: the first
    ID listed in ``ReactionStepProducts`` that names a ``<fragment>``
    directly on the page.  Each eligible non-product fragment is then
    passed to ``kabsch_align_fragment_to_product`` in sorted-ID order.

    Parameters
    ----------
    root : ET.Element
        Parsed CDXML root element (modified in-place).
    cs_bridge : ChemScriptBridge or None
        If supplied, reuses an already-open bridge (avoids spinning up a
        second subprocess).  If *None*, creates one and closes it when
        done.
    verbose : bool
        Print progress to stderr.
    frag_ids : set of str or None
        Restrict alignment to these fragment IDs (any iterable of str is
        accepted).  If *None*, the step's reactant and above/below-arrow
        objects are eligible.  Product IDs and IDs that are not
        ``<fragment>`` children of the page are ignored in both cases.

    Returns
    -------
    list of str
        IDs of fragments that were actually rotated.  Empty when there is
        no page, no step, no product fragment, no eligible fragment, or
        ChemScript cannot be started.
    """
    def log(msg: str):
        if verbose:
            print(f"[alignment] {msg}", file=sys.stderr)

    page = root.find("page")
    if page is None:
        return []

    # Build lookup of direct page children by id
    id_to_el: Dict[str, ET.Element] = {}
    for el in page:
        eid = el.get("id", "")
        if eid:
            id_to_el[eid] = el

    def _is_fragment(fid: str) -> bool:
        el = id_to_el.get(fid)
        return el is not None and el.tag == "fragment"

    # Parse <scheme><step> metadata
    steps = root.findall(".//step")
    if not steps:
        log("No reaction steps found, skipping Kabsch alignment")
        return []

    # Use first step
    step = steps[0]
    # Keep the product IDs as an ordered list: iterating a set of strings
    # would pick a different product from run to run (str hashing is
    # randomised per process) when several products are listed.
    product_id_list = step.get("ReactionStepProducts", "").split()
    product_ids = set(product_id_list)
    reactant_ids = set(step.get("ReactionStepReactants", "").split())
    above_ids = set(step.get("ReactionStepObjectsAboveArrow", "").split())
    below_ids = set(step.get("ReactionStepObjectsBelowArrow", "").split())

    product_frag = next(
        (id_to_el[pid] for pid in product_id_list if _is_fragment(pid)),
        None,
    )
    if product_frag is None:
        log("No product fragment found, skipping Kabsch alignment")
        return []

    # Determine which fragments to align
    if frag_ids is not None:
        candidates = set(frag_ids) - product_ids
    else:
        candidates = (reactant_ids | above_ids | below_ids) - product_ids
    eligible = {fid for fid in candidates if _is_fragment(fid)}

    if not eligible:
        log("No eligible fragments to align")
        return []

    log(f"Kabsch aligning {len(eligible)} fragment(s) to product...")

    # Ensure ChemScript bridge
    owns_bridge = cs_bridge is None
    if owns_bridge:
        try:
            from ..chemdraw.chemscript_bridge import ChemScriptBridge
            cs_bridge = ChemScriptBridge()
        except Exception as exc:
            log(f"WARNING: ChemScript unavailable ({exc}), "
                f"skipping Kabsch alignment")
            return []

    aligned = []
    try:
        for fid in sorted(eligible):
            frag_el = id_to_el[fid]
            try:
                success = kabsch_align_fragment_to_product(
                    frag_el, product_frag, cs_bridge, verbose)
                if success:
                    aligned.append(fid)
            except Exception as exc:
                log(f" Fragment {fid}: alignment error: {exc}")
    finally:
        if owns_bridge and cs_bridge is not None:
            try:
                cs_bridge.close()
            except Exception:
                pass

    log(f"Kabsch aligned {len(aligned)} fragment(s)")
    return aligned
712
+
713
+
714
+ # ============================================================================
715
+ # LAYER 3 -- RDKit MCS alignment (RDKit only, lazy imports)
716
+ # ============================================================================
717
+
718
+ _HAS_RDKIT: Optional[bool] = None # lazy detection
719
+
720
+
721
+ def _check_rdkit() -> bool:
722
+ """Check if RDKit is available. Caches result."""
723
+ global _HAS_RDKIT
724
+ if _HAS_RDKIT is None:
725
+ try:
726
+ from rdkit import Chem # noqa: F401
727
+ _HAS_RDKIT = True
728
+ except ImportError:
729
+ _HAS_RDKIT = False
730
+ return _HAS_RDKIT
731
+
732
+
733
def _frag_to_mol(frag_elem: ET.Element):
    """Convert a CDXML <fragment> to an RDKit Mol with atom metadata.

    Returns (mol, atoms_data) where atoms_data is a list of dicts with
    keys: id, idx, x, y, elem, num_h, is_abbrev, xml.

    Only direct-child ``<n>`` and ``<b>`` elements are read, and
    ExternalConnectionPoint nodes are skipped.  Abbreviation groups
    (NodeType="Fragment") become dummy atoms (element 0), so they take
    part in connectivity but not in MCS element matching.  A node with a
    ``NumHydrogens`` attribute gets exactly that many explicit Hs and no
    implicit ones.

    Bond ``Order`` values other than "1", "2" and "3" (e.g. CDXML's "1.5"
    or "dative") are treated as single bonds.  Self-loops and repeated
    bonds between the same atom pair are dropped, because RDKit rejects
    them.  Sanitization is best-effort: if both full and property-free
    sanitization fail, the unsanitized mol is returned.

    Returns (None, None) if conversion fails.  This happens when an
    ``id``/``B``/``E``/``Element``/``NumHydrogens`` value is missing or
    non-integer, or a ``p`` has fewer than two numbers.
    """
    from rdkit import Chem

    def _parse_order(raw: str) -> int:
        # Fractional or symbolic CDXML orders fall back to single (1),
        # matching the SINGLE default used for unknown integer orders.
        try:
            return int(raw)
        except ValueError:
            return 1

    atoms: List[dict] = []
    id_map: Dict[int, int] = {}
    bonds = []

    try:
        for n in frag_elem.findall("n"):
            nid = int(n.get("id"))
            if n.get("NodeType") == "ExternalConnectionPoint":
                continue

            fields = n.get("p", "0 0").split()
            px, py = float(fields[0]), float(fields[1])
            elem = int(n.get("Element", "6"))
            num_h_attr = n.get("NumHydrogens")
            num_h = int(num_h_attr) if num_h_attr is not None else None
            is_abbrev = n.get("NodeType") == "Fragment"

            idx = len(atoms)
            id_map[nid] = idx
            atoms.append({
                "id": nid, "idx": idx,
                "x": px, "y": py,
                "elem": elem, "num_h": num_h,
                "is_abbrev": is_abbrev,
                "xml": n,
            })

        for b in frag_elem.findall("b"):
            bi, ei = int(b.get("B")), int(b.get("E"))
            if bi in id_map and ei in id_map:
                bonds.append((id_map[bi], id_map[ei],
                              _parse_order(b.get("Order", "1"))))
    except (TypeError, ValueError, IndexError):
        return None, None

    em = Chem.RWMol()
    for a in atoms:
        ra = Chem.Atom(0 if a["is_abbrev"] else a["elem"])
        if a["num_h"] is not None:
            ra.SetNoImplicit(True)
            ra.SetNumExplicitHs(a["num_h"])
        em.AddAtom(ra)

    bond_types = {1: Chem.BondType.SINGLE, 2: Chem.BondType.DOUBLE,
                  3: Chem.BondType.TRIPLE}
    for bi, ei, order in bonds:
        # RDKit raises on a self-bond or a second bond between the same
        # pair of atoms; skip them rather than fail the whole fragment.
        if bi == ei or em.GetBondBetweenAtoms(bi, ei) is not None:
            continue
        em.AddBond(bi, ei, bond_types.get(order, Chem.BondType.SINGLE))

    mol = em.GetMol()
    try:
        Chem.SanitizeMol(mol)
    except Exception:
        try:
            Chem.SanitizeMol(
                mol,
                Chem.SanitizeFlags.SANITIZE_ALL
                ^ Chem.SanitizeFlags.SANITIZE_PROPERTIES,
            )
        except Exception:
            pass  # best-effort: callers still get the raw connectivity

    return mol, atoms
803
+
804
+
805
_rdk_bl_cache: Optional[float] = None  # memoised RDKit 2D bond length


def _rdkit_default_bond_length() -> float:
    """Return the bond length RDKit uses for 2D depictions.

    Measured once by laying out ethane with ``Compute2DCoords`` and taking the
    distance between its two carbons; the value is cached in
    ``_rdk_bl_cache`` and reused on later calls.
    """
    global _rdk_bl_cache
    if _rdk_bl_cache is not None:
        return _rdk_bl_cache

    from rdkit import Chem
    from rdkit.Chem import AllChem

    ethane = Chem.MolFromSmiles("CC")
    AllChem.Compute2DCoords(ethane)
    conformer = ethane.GetConformer()
    first = conformer.GetAtomPosition(0)
    second = conformer.GetAtomPosition(1)
    dx = second.x - first.x
    dy = second.y - first.y
    _rdk_bl_cache = math.sqrt(dx ** 2 + dy ** 2)
    return _rdk_bl_cache
821
+
822
+
823
def _avg_bond_length_from_atoms(
    atoms_data: List[dict], mol, fallback: Optional[float] = None
) -> float:
    """Return the mean 2D bond length of *mol* measured in CDXML coordinates.

    Each bond's begin/end atom indices are looked up in *atoms_data* (the
    list built by ``_frag_to_mol``, indexed by RDKit atom index). The
    Euclidean distances between their ``x``/``y`` values are then averaged.

    Callers divide by the result to derive a scale factor, so zero is never
    returned. When *mol* has no bonds, or every bond has zero length,
    *fallback* is returned instead. Zero-length bonds arise, for example,
    when nodes lack a ``p`` attribute and ``_frag_to_mol`` places them all
    at the origin.

    Parameters
    ----------
    atoms_data : list of dict
        Atom records with ``x`` and ``y`` keys.
    mol : RDKit Mol
        Molecule whose bonds are measured.
    fallback : float, optional
        Value for the degenerate cases above; ``ACS_BOND_LENGTH`` if None.
    """
    total, count = 0.0, 0
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        dx = atoms_data[i]["x"] - atoms_data[j]["x"]
        dy = atoms_data[i]["y"] - atoms_data[j]["y"]
        total += math.sqrt(dx * dx + dy * dy)
        count += 1
    if count and total > 0.0:
        return total / count
    return ACS_BOND_LENGTH if fallback is None else fallback
833
+
834
+
835
def _set_cdxml_conformer(mol, atoms_data: List[dict], scale: float = 1.0):
    """Replace every conformer of *mol* with one built from CDXML coordinates.

    CDXML's y axis grows downward while RDKit's grows upward, so y is
    negated. Both axes are multiplied by *scale* (callers pass RDKit bond
    length / CDXML bond length), and z is fixed at 0.
    """
    from rdkit import Chem
    from rdkit.Geometry import Point3D

    conformer = Chem.Conformer(mol.GetNumAtoms())
    for atom in atoms_data:
        x_rdk = atom["x"] * scale
        y_rdk = -atom["y"] * scale
        conformer.SetAtomPosition(atom["idx"], Point3D(x_rdk, y_rdk, 0.0))

    mol.RemoveAllConformers()
    mol.AddConformer(conformer, assignId=True)
851
+
852
+
853
def _find_mcs_match(ref_mol, target_mol, timeout: int = 30):
    """Pair up atoms of *ref_mol* and *target_mol* via their MCS.

    The maximum common substructure search compares elements and bond
    orders. Ring bonds only match ring bonds and partial rings are
    discarded. *timeout* (seconds) is handed to ``rdFMCS.FindMCS``.

    Returns
    -------
    tuple
        ``(atom_map, n_atoms)``: *atom_map* is a list of
        ``(ref_idx, tgt_idx)`` pairs and *n_atoms* the MCS size. Returns
        ``(None, 0)`` in any of these cases:

        * the MCS has fewer than 3 atoms;
        * its SMARTS cannot be parsed;
        * either molecule fails to match it.
    """
    from rdkit import Chem
    from rdkit.Chem import rdFMCS

    result = rdFMCS.FindMCS(
        [ref_mol, target_mol],
        timeout=timeout,
        atomCompare=rdFMCS.AtomCompare.CompareElements,
        bondCompare=rdFMCS.BondCompare.CompareOrder,
        ringMatchesRingOnly=True,
        completeRingsOnly=True,
    )
    no_match = (None, 0)
    if result.numAtoms < 3:
        return no_match

    query = Chem.MolFromSmarts(result.smartsString)
    if query is None:
        return no_match

    # Both molecules are matched against the same query, so position k in
    # each hit tuple refers to the same query atom.
    ref_hits = ref_mol.GetSubstructMatch(query)
    tgt_hits = target_mol.GetSubstructMatch(query)
    if ref_hits and tgt_hits:
        return list(zip(ref_hits, tgt_hits)), result.numAtoms
    return no_match
884
+
885
+
886
def _rdkit_align_and_write(
    frag_elem: ET.Element,
    ref_mol,
    tgt_mol,
    tgt_atoms: List[dict],
    atom_map: list,
    scale: float,
    original_center: Tuple[float, float],
) -> float:
    """Re-depict *tgt_mol* onto *ref_mol*'s layout and store it in *frag_elem*.

    The steps are:

    1. RDKit's ``GenerateDepictionMatching2DStructure`` lays out *tgt_mol*
       so that every ``(ref_idx, tgt_idx)`` pair in *atom_map* is pinned to
       the reference conformer.
    2. The new coordinates are converted back to CDXML space (divided by
       *scale*, y negated).
    3. They are shifted so their centroid equals *original_center*.
    4. They are written into each node's ``p`` attribute. Each node's child
       elements are moved by the same displacement via
       ``translate_subtree``.
    5. The fragment ``BoundingBox`` is rebuilt from all non-external node
       positions with a 15-point margin.

    Parameters
    ----------
    frag_elem : ET.Element
        The CDXML <fragment> element to update (modified in-place).
    ref_mol : RDKit Mol
        Reference molecule (product) with conformer set at RDKit scale.
    tgt_mol : RDKit Mol
        Target molecule (reactant/reagent) -- conformer will be generated.
    tgt_atoms : list of dict
        Atom metadata from _frag_to_mol (includes 'xml' references).
    atom_map : list of (ref_idx, tgt_idx) tuples
        Atom correspondence used as layout constraints.
    scale : float
        CDXML-to-RDKit scale factor used to convert back to CDXML units.
    original_center : (float, float)
        Centroid, in CDXML space, that the written atoms will share.

    Returns
    -------
    float
        RMSD in RDKit units between mapped reference and target atoms
        (0.0 when *atom_map* is empty).
    """
    from rdkit.Chem import rdDepictor

    rdDepictor.GenerateDepictionMatching2DStructure(
        tgt_mol, ref_mol, atom_map)

    ref_conf = ref_mol.GetConformer()
    tgt_conf = tgt_mol.GetConformer()
    squared = 0
    for ref_idx, tgt_idx in atom_map:
        ref_pos = ref_conf.GetAtomPosition(ref_idx)
        tgt_pos = tgt_conf.GetAtomPosition(tgt_idx)
        squared += ((ref_pos.x - tgt_pos.x) ** 2 +
                    (ref_pos.y - tgt_pos.y) ** 2)
    rmsd = math.sqrt(squared / len(atom_map)) if atom_map else 0.0

    # Undo the scale and flip y back (CDXML y grows downward).
    inv = 1.0 / scale
    positions = []
    for atom in tgt_atoms:
        pos = tgt_conf.GetAtomPosition(atom["idx"])
        positions.append((pos.x * inv, -pos.y * inv))

    count = len(positions)
    shift_x = original_center[0] - sum(p[0] for p in positions) / count
    shift_y = original_center[1] - sum(p[1] for p in positions) / count

    for atom, (x_aligned, y_aligned) in zip(tgt_atoms, positions):
        new_x = x_aligned + shift_x
        new_y = y_aligned + shift_y
        dx = new_x - atom["x"]
        dy = new_y - atom["y"]

        node = atom["xml"]
        node.set("p", f"{new_x:.2f} {new_y:.2f}")
        # Labels and nested abbreviation fragments follow their node.
        for child in node:
            translate_subtree(child, dx, dy)

    node_xs, node_ys = [], []
    for n in frag_elem.findall("n"):
        if n.get("NodeType") == "ExternalConnectionPoint":
            continue
        coords = n.get("p")
        if coords:
            fields = coords.split()
            node_xs.append(float(fields[0]))
            node_ys.append(float(fields[1]))
    if node_xs and node_ys:
        margin = 15.0
        frag_elem.set(
            "BoundingBox",
            f"{min(node_xs)-margin:.2f} {min(node_ys)-margin:.2f} "
            f"{max(node_xs)+margin:.2f} {max(node_ys)+margin:.2f}")

    return rmsd
982
+
983
+
984
def align_product_to_reference(
    root: ET.Element,
    ref_cdxml_path: str,
    verbose: bool = False,
    timeout: int = 30,
) -> bool:
    """Align the product fragment to the best-matching structure in a
    reference CDXML file.

    The reference file should contain one or more "known good" structures
    drawn with the desired orientation (e.g. from a group meeting slide).
    The product (first fragment listed in the first ``<step>``'s
    ``ReactionStepProducts``) is matched to whichever reference structure
    has the largest MCS overlap. It is then aligned via
    GenerateDepictionMatching2DStructure, rewritten at ``ACS_BOND_LENGTH``
    scale and re-centred on its original centroid.

    Call this BEFORE rdkit_align_to_product() so that reactant alignment
    uses the correctly-oriented product as its reference.

    Parameters
    ----------
    root : ET.Element
        Parsed CDXML root element (modified in-place).
    ref_cdxml_path : str
        Path to reference CDXML with known-good structure(s). It is read as
        UTF-8 (undecodable bytes replaced) and stripped of XML-illegal
        control characters before parsing.
    verbose : bool
        Print progress to stderr.
    timeout : int
        MCS timeout in seconds per comparison (default 30).

    Returns
    -------
    bool
        True if the product was successfully aligned. False, with a log
        message, if any of the following holds:

        * RDKit is missing;
        * there is no page, step or product fragment;
        * the reference file cannot be read or parsed;
        * no reference fragment shares at least 3 MCS atoms;
        * coordinates are degenerate;
        * RDKit fails during alignment.
    """
    if not _check_rdkit():
        if verbose:
            print("[alignment] RDKit not available, skipping reference alignment",
                  file=sys.stderr)
        return False

    import re as _re

    def log(msg):
        if verbose:
            print(f"[alignment] {msg}", file=sys.stderr)

    page = root.find("page")
    if page is None:
        return False

    fragments = {}
    for f in page.findall("fragment"):
        fid = f.get("id")
        if fid:
            fragments[fid] = f

    steps = root.findall(".//step")
    if not steps:
        log("No reaction steps found")
        return False

    product_ids = steps[0].get("ReactionStepProducts", "").split()
    prod_id = next((pid for pid in product_ids if pid in fragments), None)
    if prod_id is None:
        log("No product fragment found")
        return False
    prod_frag = fragments[prod_id]

    prod_result = _frag_to_mol(prod_frag)
    if prod_result is None or prod_result[0] is None:
        log("Product fragment conversion failed")
        return False
    prod_mol, prod_atoms = prod_result
    if prod_mol.GetNumAtoms() < 3:
        log("Product too small for reference alignment")
        return False

    # Findmolecule embeds binary "Molecule ID" values (\x01, \x12, ...)
    # that are illegal in XML, so strip control characters before parsing.
    try:
        with open(ref_cdxml_path, "r", encoding="utf-8", errors="replace") as _f:
            raw = _f.read()
    except OSError as exc:
        log(f"Cannot read reference CDXML: {exc}")
        return False
    raw = _re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f]', '', raw)
    try:
        ref_root = ET.fromstring(raw)
    except (ET.ParseError, ValueError) as exc:
        log(f"Cannot parse reference CDXML: {exc}")
        return False

    ref_page = ref_root.find(".//page")
    if ref_page is None:
        log("No page in reference CDXML")
        return False

    ref_entries = []
    for rf in ref_page.findall("fragment"):
        rf_result = _frag_to_mol(rf)
        if rf_result is None or rf_result[0] is None:
            continue
        rf_mol, rf_atoms = rf_result
        if rf_mol.GetNumAtoms() < 3:
            continue
        ref_entries.append((rf, rf_mol, rf_atoms))

    if not ref_entries:
        log("No usable fragments in reference CDXML")
        return False

    log(f"Reference file: {len(ref_entries)} fragment(s)")

    # Pick the reference fragment with the largest MCS against the product.
    best_ref = None
    best_map = None
    best_n_mcs = 0
    for rf_elem, rf_mol, rf_atoms in ref_entries:
        atom_map, n_mcs = _find_mcs_match(rf_mol, prod_mol, timeout)
        rf_id = rf_elem.get("id", "?")
        if atom_map is not None and n_mcs > best_n_mcs:
            best_n_mcs = n_mcs
            best_map = atom_map
            best_ref = (rf_elem, rf_mol, rf_atoms)
            log(f" Ref fragment {rf_id}: MCS = {n_mcs} atoms (new best)")
        else:
            log(f" Ref fragment {rf_id}: MCS = {n_mcs} atoms")

    if best_ref is None:
        log("No reference with MCS >= 3 atoms")
        return False

    rf_elem, rf_mol, rf_atoms = best_ref
    rf_id = rf_elem.get("id", "?")

    # Reference conformer at RDKit scale, preserving its drawn orientation.
    rf_bl = _avg_bond_length_from_atoms(rf_atoms, rf_mol)
    if rf_bl <= 0:
        log(f"Reference fragment {rf_id} has degenerate coordinates")
        return False
    rdk_bl = _rdkit_default_bond_length()
    _set_cdxml_conformer(rf_mol, rf_atoms, rdk_bl / rf_bl)

    n_prod = len(prod_atoms)
    original_center = (
        sum(a["x"] for a in prod_atoms) / n_prod,
        sum(a["y"] for a in prod_atoms) / n_prod,
    )

    # After alignment the product coords are at RDKit scale; write them back
    # at the ACS standard bond length.
    write_scale = rdk_bl / ACS_BOND_LENGTH

    try:
        rmsd = _rdkit_align_and_write(
            prod_frag, rf_mol, prod_mol, prod_atoms,
            best_map, write_scale, original_center)
    except Exception as exc:  # RDKit raises assorted exception types
        log(f"Product {prod_id}: alignment error: {exc}")
        return False

    log(f"Product {prod_id} aligned to reference fragment {rf_id} "
        f"(MCS {best_n_mcs} atoms, RMSD {rmsd:.4f})")
    return True
1146
+
1147
+
1148
def rdkit_align_to_product(
    root: ET.Element,
    verbose: bool = False,
    timeout: int = 30,
) -> int:
    """Align all non-product fragments to the product's orientation.

    Uses RDKit MCS + GenerateDepictionMatching2DStructure to align each
    reactant/reagent fragment to match the product's drawn orientation.
    Reads <scheme>/<step> metadata to identify roles: for every step, the
    first product fragment is the reference and every other reactant /
    above-arrow / below-arrow fragment is aligned to it.

    Fragments with an MCS below 3 atoms are skipped. A fragment on which
    RDKit raises is logged and skipped; the remaining fragments are still
    processed, mirroring the Kabsch layer. Steps whose product has
    degenerate coordinates (zero average bond length) are skipped.

    Modifies CDXML coordinates in-place.

    Returns the number of fragments successfully aligned.
    """
    if not _check_rdkit():
        if verbose:
            print("[alignment] RDKit not available, skipping RDKit alignment",
                  file=sys.stderr)
        return 0

    def log(msg):
        if verbose:
            print(f"[alignment] {msg}", file=sys.stderr)

    page = root.find("page")
    if page is None:
        return 0

    fragments: Dict[str, ET.Element] = {}
    for f in page.findall("fragment"):
        fid = f.get("id")
        if fid:
            fragments[fid] = f

    def _ids(step_elem, attr_name):
        return [x for x in step_elem.get(attr_name, "").split() if x]

    steps = [
        {
            "reactants": _ids(s, "ReactionStepReactants"),
            "products": _ids(s, "ReactionStepProducts"),
            "above": _ids(s, "ReactionStepObjectsAboveArrow"),
            "below": _ids(s, "ReactionStepObjectsBelowArrow"),
        }
        for s in root.findall(".//step")
    ]

    if not steps:
        log("No reaction steps found, skipping RDKit alignment")
        return 0

    aligned_count = 0

    for si, step in enumerate(steps):
        if not step["products"]:
            continue

        prod_id = step["products"][0]
        if prod_id not in fragments:
            continue

        prod_mol, prod_atoms = _frag_to_mol(fragments[prod_id])
        if prod_mol is None or prod_mol.GetNumAtoms() < 2:
            continue

        cdxml_bl = _avg_bond_length_from_atoms(prod_atoms, prod_mol)
        if cdxml_bl <= 0:
            log(f"Step {si+1}: product {prod_id} has degenerate "
                f"coordinates, skipping")
            continue
        scale = _rdkit_default_bond_length() / cdxml_bl

        # Product conformer at RDKit scale is the reference orientation.
        _set_cdxml_conformer(prod_mol, prod_atoms, scale)

        log(f"Step {si+1}: product = fragment {prod_id} "
            f"({prod_mol.GetNumAtoms()} atoms, "
            f"bond length {cdxml_bl:.1f} pts)")

        other_ids = []
        for fid in step["reactants"] + step["above"] + step["below"]:
            if fid in fragments and fid != prod_id and fid not in other_ids:
                other_ids.append(fid)

        for fid in other_ids:
            frag_elem = fragments[fid]
            frag_mol, frag_atoms = _frag_to_mol(frag_elem)
            if frag_mol is None or frag_mol.GetNumAtoms() < 2:
                continue

            n_atoms = len(frag_atoms)
            original_center = (
                sum(a["x"] for a in frag_atoms) / n_atoms,
                sum(a["y"] for a in frag_atoms) / n_atoms,
            )

            try:
                atom_map, n_mcs = _find_mcs_match(prod_mol, frag_mol, timeout)
                if atom_map is None:
                    log(f"Fragment {fid}: MCS < 3 atoms, "
                        f"skipping RDKit alignment")
                    continue
                rmsd = _rdkit_align_and_write(
                    frag_elem, prod_mol, frag_mol, frag_atoms,
                    atom_map, scale, original_center)
            except Exception as exc:  # RDKit raises assorted exception types
                log(f"Fragment {fid}: alignment error: {exc}")
                continue

            aligned_count += 1
            log(f"Fragment {fid}: aligned to product "
                f"(MCS {n_mcs} atoms, RMSD {rmsd:.4f})")

    return aligned_count
1268
+
1269
+
1270
+ # ============================================================================
1271
+ # LAYER 4 -- RXNMapper-based alignment (RDKit + rxn-experiments subprocess)
1272
+ # ============================================================================
1273
+
1274
def rxnmapper_align_to_product(
    root: ET.Element,
    verbose: bool = False,
    timeout: int = 120,
) -> int:
    """Align non-product fragments using RXNMapper atom maps.

    Like rdkit_align_to_product(), but uses transformer-based atom mapping
    instead of RDKit MCS to find atom correspondence. RXNMapper understands
    reaction chemistry, so the atom correspondence reflects actual bond
    formation/breaking rather than purely structural overlap.

    For every ``<step>`` the first product fragment is the reference. The
    other fragments of the step (reactants, above/below-arrow objects) are
    joined into ``frag1.frag2...>>product`` and passed to ``map_reaction``.
    Atom-map numbers shared between a fragment and the product become the
    layout constraints.

    MCS alignment is used instead:

    * for all fragments of a step when the mapper is not importable, raises,
      returns nothing, or its output cannot be parsed / matched back onto
      the CDXML product;
    * for a single fragment whose SMILES contains wildcards (abbreviation
      groups), that is not found in the mapped output, cannot be matched
      back onto its CDXML atoms, or shares fewer than 3 map numbers with the
      product.

    Fallbacks are applied per step/fragment, so alignments already made for
    other steps are kept and counted.

    Parameters
    ----------
    root : ET.Element
        Parsed CDXML root (modified in-place).
    verbose : bool
        Print progress to stderr.
    timeout : int
        Timeout (seconds) passed to ``map_reaction``. MCS fallbacks use a
        fixed 30 s timeout.

    Returns
    -------
    int
        Number of fragments successfully aligned.
    """
    if not _check_rdkit():
        if verbose:
            print("[alignment] RDKit not available, skipping RXNMapper alignment",
                  file=sys.stderr)
        return 0

    from rdkit import Chem

    mcs_timeout = 30  # seconds, per-fragment MCS fallback

    def log(msg):
        if verbose:
            print(f"[alignment] {msg}", file=sys.stderr)

    page = root.find("page")
    if page is None:
        return 0

    fragments: Dict[str, ET.Element] = {}
    for f in page.findall("fragment"):
        fid = f.get("id")
        if fid:
            fragments[fid] = f

    def _ids(step_elem, attr_name):
        return [x for x in step_elem.get(attr_name, "").split() if x]

    steps = [
        {
            "reactants": _ids(s, "ReactionStepReactants"),
            "products": _ids(s, "ReactionStepProducts"),
            "above": _ids(s, "ReactionStepObjectsAboveArrow"),
            "below": _ids(s, "ReactionStepObjectsBelowArrow"),
        }
        for s in root.findall(".//step")
    ]

    if not steps:
        log("No reaction steps found")
        return 0

    try:
        from experiments.atom_mapping.rxn_atom_mapper import map_reaction
    except ImportError:
        map_reaction = None
        log(" rxn_atom_mapper not importable, falling back to MCS")

    # -- helpers -----------------------------------------------------------

    def centroid(atoms):
        """Mean (x, y) of CDXML atom records."""
        n = len(atoms)
        return (sum(a["x"] for a in atoms) / n,
                sum(a["y"] for a in atoms) / n)

    def write_alignment(fid, frag_mol, frag_atoms, prod_mol, atom_map, scale):
        """Apply *atom_map* to fragment *fid*; return RMSD, or None on error."""
        try:
            return _rdkit_align_and_write(
                fragments[fid], prod_mol, frag_mol, frag_atoms,
                atom_map, scale, centroid(frag_atoms))
        except Exception as exc:  # RDKit raises assorted exception types
            log(f" Fragment {fid}: alignment error: {exc}")
            return None

    def align_by_mcs(fid, frag_mol, frag_atoms, prod_mol, scale):
        """MCS fallback for one fragment; return True when it was aligned."""
        atom_map, n_mcs = _find_mcs_match(prod_mol, frag_mol, mcs_timeout)
        if atom_map is None:
            log(f" Fragment {fid}: MCS also < 3 atoms, skipping")
            return False
        rmsd = write_alignment(fid, frag_mol, frag_atoms, prod_mol,
                               atom_map, scale)
        if rmsd is None:
            return False
        log(f" Fragment {fid}: aligned to product "
            f"(MCS {n_mcs} atoms, RMSD {rmsd:.4f})")
        return True

    def strip_map_numbers(mol):
        """Copy of *mol* with every atom-map number cleared."""
        clean = Chem.RWMol(mol)
        for atom in clean.GetAtoms():
            atom.SetAtomMapNum(0)
        return clean

    def mapnum_to_idx(mol):
        """{atom-map number: atom index} for mapped atoms of *mol*."""
        return {a.GetAtomMapNum(): a.GetIdx()
                for a in mol.GetAtoms() if a.GetAtomMapNum() > 0}

    def cdxml_indices_for(cdxml_mol, mapped_mol):
        """Map each atom index of *mapped_mol* to an atom of *cdxml_mol*.

        Returns a list indexed by mapped-atom index (entries may be None for
        atoms with no CDXML counterpart), or None if neither molecule is a
        substructure of the other.
        """
        clean = strip_map_numbers(mapped_mol)
        try:
            Chem.SanitizeMol(clean)
        except Exception:
            pass  # best effort: an unsanitized query can still match
        forward = cdxml_mol.GetSubstructMatch(clean)
        if forward:
            return list(forward)
        reverse = Chem.Mol(clean).GetSubstructMatch(cdxml_mol)
        if not reverse:
            return None
        # cdxml_mol is the smaller graph: invert the match. Unmatched mapped
        # atoms stay None rather than silently pointing at CDXML atom 0.
        inverted = [None] * clean.GetNumAtoms()
        for cdxml_idx, mapped_idx in enumerate(reverse):
            if mapped_idx < len(inverted):
                inverted[mapped_idx] = cdxml_idx
        return inverted

    def run_mapper(rxn_smi):
        """Return (mapped reactant SMILES list, mapped product mol) or None."""
        if map_reaction is None:
            return None
        try:
            result = map_reaction(rxn_smi, timeout=timeout)
        except Exception as exc:
            log(f" RXNMapper error: {exc}, falling back to MCS")
            return None
        if not result or not result.get("mapped_rxn"):
            log(" RXNMapper returned no results, falling back to MCS")
            return None
        confidence = result.get("confidence") or 0
        log(f" RXNMapper confidence: {confidence:.4f}")

        # "reactants>agents>products"; agents are usually empty (">>").
        parts = result["mapped_rxn"].split(">")
        if len(parts) != 3:
            log(" Could not parse mapped reaction SMILES, falling back to MCS")
            return None
        # Drawn fragments may come back as reactants or as agents.
        mapped_reactants = [s for side in parts[:2]
                            for s in side.split(".") if s]
        mapped_products = [s for s in parts[2].split(".") if s]
        mapped_prod_mol = (Chem.MolFromSmiles(mapped_products[0])
                           if mapped_products else None)
        if mapped_prod_mol is None:
            log(" Could not parse mapped product SMILES")
            return None
        return mapped_reactants, mapped_prod_mol

    def find_mapped_reactant(frag_smi, mapped_reactants):
        """Mapped reactant mol whose unmapped canonical SMILES equals *frag_smi*'s."""
        frag_parsed = Chem.MolFromSmiles(frag_smi)
        if frag_parsed is None:
            return None
        target = Chem.MolToSmiles(frag_parsed)
        # RXNMapper reorders reactants, so compare canonical forms.
        for mr_smi in mapped_reactants:
            mr_mol = Chem.MolFromSmiles(mr_smi)
            if mr_mol is None:
                continue
            if Chem.MolToSmiles(strip_map_numbers(mr_mol)) == target:
                return mr_mol
        return None

    # -- per-step alignment -----------------------------------------------

    aligned_count = 0

    for si, step in enumerate(steps):
        if not step["products"]:
            continue

        prod_id = step["products"][0]
        if prod_id not in fragments:
            continue

        prod_result = _frag_to_mol(fragments[prod_id])
        if prod_result is None:
            continue
        prod_mol, prod_atoms = prod_result
        if prod_mol is None or prod_mol.GetNumAtoms() < 2:
            continue

        cdxml_bl = _avg_bond_length_from_atoms(prod_atoms, prod_mol)
        if cdxml_bl <= 0:
            log(f"Step {si+1}: product {prod_id} has degenerate "
                f"coordinates, skipping")
            continue
        scale = _rdkit_default_bond_length() / cdxml_bl
        _set_cdxml_conformer(prod_mol, prod_atoms, scale)

        prod_smi = Chem.MolToSmiles(prod_mol)
        log(f"Step {si+1}: product = fragment {prod_id} "
            f"({prod_mol.GetNumAtoms()} atoms, SMILES={prod_smi[:50]})")

        other_ids = []
        for fid in step["reactants"] + step["above"] + step["below"]:
            if fid in fragments and fid != prod_id and fid not in other_ids:
                other_ids.append(fid)

        frag_data = {}  # fid -> (mol, atoms, smiles)
        for fid in other_ids:
            frag_result = _frag_to_mol(fragments[fid])
            if frag_result is None:
                continue
            frag_mol, frag_atoms = frag_result
            if frag_mol is None or frag_mol.GetNumAtoms() < 2:
                continue
            frag_smi = Chem.MolToSmiles(frag_mol)
            if not frag_smi or "*" in frag_smi:
                # Dummy atoms (abbreviation groups) cannot go to RXNMapper.
                log(f" Fragment {fid}: SMILES has wildcards, skipping RXNMapper")
                if align_by_mcs(fid, frag_mol, frag_atoms, prod_mol, scale):
                    aligned_count += 1
                continue
            frag_data[fid] = (frag_mol, frag_atoms, frag_smi)

        if not frag_data:
            continue

        reactant_side = ".".join(d[2] for d in frag_data.values())
        rxn_smi = f"{reactant_side}>>{prod_smi}"
        log(f" RXN SMILES: {rxn_smi[:100]}...")

        mapped = run_mapper(rxn_smi)
        prod_match = None
        if mapped is not None:
            mapped_reactants, mapped_prod_mol = mapped
            prod_match = cdxml_indices_for(prod_mol, mapped_prod_mol)
            if prod_match is None:
                log(" Product substructure match failed, falling back to MCS")

        if prod_match is None:
            # Step-level fallback: MCS for this step only.
            for fid, (frag_mol, frag_atoms, _smi) in frag_data.items():
                if align_by_mcs(fid, frag_mol, frag_atoms, prod_mol, scale):
                    aligned_count += 1
            continue

        prod_mapnums = mapnum_to_idx(mapped_prod_mol)

        for fid, (frag_mol, frag_atoms, frag_smi) in frag_data.items():
            mapped_frag_mol = find_mapped_reactant(frag_smi, mapped_reactants)
            if mapped_frag_mol is None:
                log(f" Fragment {fid}: not found in mapped output, "
                    "falling back to MCS")
                if align_by_mcs(fid, frag_mol, frag_atoms, prod_mol, scale):
                    aligned_count += 1
                continue

            frag_match = cdxml_indices_for(frag_mol, mapped_frag_mol)
            if frag_match is None:
                log(f" Fragment {fid}: substruct match failed, "
                    "falling back to MCS")
                if align_by_mcs(fid, frag_mol, frag_atoms, prod_mol, scale):
                    aligned_count += 1
                continue

            # Build atom_map: (prod_mol_idx, frag_mol_idx)
            frag_mapnums = mapnum_to_idx(mapped_frag_mol)
            atom_map = []
            for mn in sorted(frag_mapnums.keys() & prod_mapnums.keys()):
                mapped_frag_idx = frag_mapnums[mn]
                mapped_prod_idx = prod_mapnums[mn]
                if (mapped_frag_idx >= len(frag_match) or
                        mapped_prod_idx >= len(prod_match)):
                    continue
                cdxml_frag_idx = frag_match[mapped_frag_idx]
                cdxml_prod_idx = prod_match[mapped_prod_idx]
                if cdxml_frag_idx is None or cdxml_prod_idx is None:
                    continue
                atom_map.append((cdxml_prod_idx, cdxml_frag_idx))

            if len(atom_map) < 3:
                log(f" Fragment {fid}: only {len(atom_map)} shared maps, "
                    "falling back to MCS")
                if align_by_mcs(fid, frag_mol, frag_atoms, prod_mol, scale):
                    aligned_count += 1
                continue

            rmsd = write_alignment(fid, frag_mol, frag_atoms, prod_mol,
                                   atom_map, scale)
            if rmsd is None:
                continue
            aligned_count += 1
            log(f" Fragment {fid}: aligned to product "
                f"(RXNMapper {len(atom_map)} atom maps, RMSD {rmsd:.4f})")

    return aligned_count
1587
+
1588
+
1589
+ # ---------------------------------------------------------------------------
1590
+ # CLI
1591
+ # ---------------------------------------------------------------------------
1592
+
1593
def main(argv: Optional[List[str]] = None) -> int:
    """Align reaction scheme fragments to the product orientation.

    Parses *argv* (``sys.argv[1:]`` when None) and loads the input CDXML.
    When ``--ref`` is given, it first aligns the product to the reference
    file. It then aligns the remaining fragments with the chosen method
    and writes the result to ``--output`` (default: ``<input>-aligned<ext>``).

    Returns 0 on success. Returns 1, with a message on stderr, when the
    input or reference file is missing or is not well-formed XML.
    """
    parser = argparse.ArgumentParser(
        description="Align CDXML reaction scheme fragments to match product orientation.",
    )
    parser.add_argument("input", help="Input CDXML file with a reaction scheme")
    parser.add_argument("-o", "--output", help="Output CDXML path (default: input-aligned.cdxml)")
    parser.add_argument("--ref", help="Reference CDXML for product orientation (optional)")
    parser.add_argument("--method", choices=["rdkit", "kabsch"], default="rdkit",
                        help="Alignment method (default: rdkit)")
    parser.add_argument("--timeout", type=int, default=30,
                        help="MCS timeout in seconds (default: 30)")
    parser.add_argument("-v", "--verbose", action="store_true")

    args = parser.parse_args(argv)

    if not os.path.isfile(args.input):
        print(f"Error: file not found: {args.input}", file=sys.stderr)
        return 1

    try:
        tree = ET.parse(args.input)
    except ET.ParseError as exc:
        print(f"Error: could not parse {args.input}: {exc}", file=sys.stderr)
        return 1
    root = tree.getroot()

    # Optional: align product to reference first
    if args.ref:
        if not os.path.isfile(args.ref):
            print(f"Error: reference file not found: {args.ref}", file=sys.stderr)
            return 1
        try:
            align_product_to_reference(root, args.ref, verbose=args.verbose,
                                       timeout=args.timeout)
        except (ET.ParseError, OSError) as exc:
            print(f"Error: could not read reference file {args.ref}: {exc}",
                  file=sys.stderr)
            return 1

    # Align fragments to product
    if args.method == "rdkit":
        count = rdkit_align_to_product(root, verbose=args.verbose,
                                       timeout=args.timeout)
    else:
        count = len(kabsch_align_to_product(root, verbose=args.verbose))

    base, ext = os.path.splitext(args.input)
    out_path = args.output or f"{base}-aligned{ext}"

    write_cdxml(tree, out_path)

    print(f"Aligned {count} fragment(s) -> {out_path}")
    return 0
1639
+
1640
+
1641
if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the exit status.
    raise SystemExit(main())