cdxml-toolkit 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91) hide show
  1. cdxml_toolkit/__init__.py +18 -0
  2. cdxml_toolkit/_jre/__init__.py +2 -0
  3. cdxml_toolkit/_jre/temurin-21-jre-win-x64.zip +0 -0
  4. cdxml_toolkit/analysis/__init__.py +35 -0
  5. cdxml_toolkit/analysis/deterministic/__init__.py +12 -0
  6. cdxml_toolkit/analysis/deterministic/discover_experiment_files.py +413 -0
  7. cdxml_toolkit/analysis/deterministic/lab_book_formatter.py +701 -0
  8. cdxml_toolkit/analysis/deterministic/lcms_file_categorizer.py +928 -0
  9. cdxml_toolkit/analysis/deterministic/lcms_identifier.py +598 -0
  10. cdxml_toolkit/analysis/deterministic/mass_resolver.py +654 -0
  11. cdxml_toolkit/analysis/deterministic/multi_lcms_analyzer.py +1412 -0
  12. cdxml_toolkit/analysis/deterministic/procedure_writer.py +446 -0
  13. cdxml_toolkit/analysis/extract_nmr.py +47 -0
  14. cdxml_toolkit/analysis/format_procedure_entry.py +479 -0
  15. cdxml_toolkit/analysis/lcms_analyzer.py +1299 -0
  16. cdxml_toolkit/analysis/parse_analysis_file.py +134 -0
  17. cdxml_toolkit/cdxml_builder.py +920 -0
  18. cdxml_toolkit/cdxml_utils.py +342 -0
  19. cdxml_toolkit/chemdraw/__init__.py +5 -0
  20. cdxml_toolkit/chemdraw/_chemscript_server.py +562 -0
  21. cdxml_toolkit/chemdraw/cdx_converter.py +527 -0
  22. cdxml_toolkit/chemdraw/cdxml_to_image.py +262 -0
  23. cdxml_toolkit/chemdraw/cdxml_to_image_rdkit.py +296 -0
  24. cdxml_toolkit/chemdraw/chemscript_bridge.py +901 -0
  25. cdxml_toolkit/constants.py +304 -0
  26. cdxml_toolkit/coord_normalizer.py +438 -0
  27. cdxml_toolkit/deterministic_pipeline/__init__.py +6 -0
  28. cdxml_toolkit/deterministic_pipeline/legacy/__init__.py +5 -0
  29. cdxml_toolkit/deterministic_pipeline/legacy/eln_cdx_cleanup.py +509 -0
  30. cdxml_toolkit/deterministic_pipeline/legacy/eln_enrichment.py +1394 -0
  31. cdxml_toolkit/deterministic_pipeline/legacy/scheme_aligner.py +428 -0
  32. cdxml_toolkit/deterministic_pipeline/legacy/scheme_polisher.py +1337 -0
  33. cdxml_toolkit/deterministic_pipeline/legacy/scheme_polisher_v2.py +1340 -0
  34. cdxml_toolkit/deterministic_pipeline/scheme_reader_audit.py +931 -0
  35. cdxml_toolkit/deterministic_pipeline/scheme_reader_verify.py +1160 -0
  36. cdxml_toolkit/image/__init__.py +15 -0
  37. cdxml_toolkit/image/reaction_from_image.py +2103 -0
  38. cdxml_toolkit/image/structure_from_image.py +1711 -0
  39. cdxml_toolkit/layout/__init__.py +5 -0
  40. cdxml_toolkit/layout/alignment.py +1642 -0
  41. cdxml_toolkit/layout/reaction_cleanup.py +1002 -0
  42. cdxml_toolkit/layout/scheme_merger.py +2260 -0
  43. cdxml_toolkit/mcp_server/__init__.py +0 -0
  44. cdxml_toolkit/mcp_server/__main__.py +5 -0
  45. cdxml_toolkit/mcp_server/server.py +1567 -0
  46. cdxml_toolkit/naming/__init__.py +6 -0
  47. cdxml_toolkit/naming/aligned_namer.py +2342 -0
  48. cdxml_toolkit/naming/mol_builder.py +3722 -0
  49. cdxml_toolkit/naming/name_decomposer.py +2843 -0
  50. cdxml_toolkit/naming/reactions_datamol.json +2414 -0
  51. cdxml_toolkit/office/__init__.py +5 -0
  52. cdxml_toolkit/office/doc_from_template.py +722 -0
  53. cdxml_toolkit/office/ole_embedder.py +808 -0
  54. cdxml_toolkit/office/ole_extractor.py +272 -0
  55. cdxml_toolkit/perception/__init__.py +10 -0
  56. cdxml_toolkit/perception/compound_search.py +229 -0
  57. cdxml_toolkit/perception/eln_csv_parser.py +240 -0
  58. cdxml_toolkit/perception/rdf_parser.py +664 -0
  59. cdxml_toolkit/perception/reactant_heuristic.py +1045 -0
  60. cdxml_toolkit/perception/reaction_parser.py +2150 -0
  61. cdxml_toolkit/perception/scheme_reader.py +2948 -0
  62. cdxml_toolkit/perception/scheme_refine.py +1404 -0
  63. cdxml_toolkit/perception/scheme_segmenter.py +619 -0
  64. cdxml_toolkit/perception/spatial_assignment.py +1013 -0
  65. cdxml_toolkit/rdkit_utils.py +605 -0
  66. cdxml_toolkit/render/__init__.py +17 -0
  67. cdxml_toolkit/render/auto_layout.py +229 -0
  68. cdxml_toolkit/render/compact_parser.py +632 -0
  69. cdxml_toolkit/render/parser.py +706 -0
  70. cdxml_toolkit/render/render_scheme.py +267 -0
  71. cdxml_toolkit/render/renderer.py +2387 -0
  72. cdxml_toolkit/render/schema.py +90 -0
  73. cdxml_toolkit/render/scheme_maker.py +1043 -0
  74. cdxml_toolkit/render/scheme_yaml_writer.py +1487 -0
  75. cdxml_toolkit/resolve/__init__.py +13 -0
  76. cdxml_toolkit/resolve/cas_resolver.py +430 -0
  77. cdxml_toolkit/resolve/chemscanner_abbreviations.json +28813 -0
  78. cdxml_toolkit/resolve/condensed_formula.py +493 -0
  79. cdxml_toolkit/resolve/jre_manager.py +195 -0
  80. cdxml_toolkit/resolve/reagent_abbreviations.json +1046 -0
  81. cdxml_toolkit/resolve/reagent_db.py +285 -0
  82. cdxml_toolkit/resolve/superatom_data.json +2856 -0
  83. cdxml_toolkit/resolve/superatom_table.py +146 -0
  84. cdxml_toolkit/text_formatting.py +298 -0
  85. cdxml_toolkit-0.5.0.dist-info/METADATA +318 -0
  86. cdxml_toolkit-0.5.0.dist-info/RECORD +91 -0
  87. cdxml_toolkit-0.5.0.dist-info/WHEEL +5 -0
  88. cdxml_toolkit-0.5.0.dist-info/entry_points.txt +17 -0
  89. cdxml_toolkit-0.5.0.dist-info/licenses/LICENSE +21 -0
  90. cdxml_toolkit-0.5.0.dist-info/licenses/NOTICE.md +37 -0
  91. cdxml_toolkit-0.5.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1013 @@
1
+ """
2
+ spatial_assignment.py — Geometry-first spatial assignment of scheme elements to arrows.
3
+
4
+ Replaces the naive x-band assignment in scheme_reader._parse_from_geometry()
5
+ with a rotation-invariant, distance-based approach that handles arbitrary arrow
6
+ orientations, multi-row layouts, branching, and cycles.
7
+
8
+ Algorithmic influences:
9
+ - ReactionDataExtractor (Cambridge, 2021/2023): arrow-centric equidistant scan
10
+ - RxnIM (HKUST, 2025): layout pattern taxonomy (single/multi-line/branch/cycle)
11
+ - CDXML advantage: exact coordinates from XML, no detection/OCR needed
12
+
13
+ Key design decisions:
14
+ - Arrow-relative projection: every point is transformed into (parallel, perp)
15
+ coordinates relative to each arrow. This makes role assignment
16
+ rotation-invariant.
17
+ - Distance-based assignment: fragments go to the nearest arrow by a combined
18
+ distance score, not by hard x-coordinate bands.
19
+ - Layout classifier: detects the scheme pattern first, then delegates to a
20
+ pattern-specific strategy that handles edge cases for that layout type.
21
+ - Confidence scoring: every assignment carries a 0-1 confidence based on
22
+ the ratio of nearest to second-nearest arrow distance.
23
+
24
+ API:
25
+ from cdxml_toolkit.perception.spatial_assignment import (
26
+ build_arrow_vectors, classify_layout, assign_elements,
27
+ )
28
+ arrows = build_arrow_vectors(page_element)
29
+ layout = classify_layout(arrows)
30
+ steps, results = assign_elements(arrows, page_element, layout)
31
+ """
32
+
33
+ from __future__ import annotations
34
+
35
+ import math
36
+ from dataclasses import dataclass, field
37
+ from enum import Enum
38
+ from typing import Dict, List, Optional, Set, Tuple
39
+ from xml.etree import ElementTree as ET
40
+
41
+ from ..constants import ACS_BOND_LENGTH
42
+
43
+
44
+ # ---------------------------------------------------------------------------
45
+ # Data structures
46
+ # ---------------------------------------------------------------------------
47
+
48
@dataclass
class ArrowVector:
    """Fully characterised arrow with direction, type, and spatial metadata.

    All coordinates are CDXML page coordinates, in which y grows downward.
    """
    # ``id`` attribute of the source CDXML element.
    element_id: str
    # The source <arrow> or <graphic> element.
    element: ET.Element
    tail: Tuple[float, float]
    head: Tuple[float, float]
    midpoint: Tuple[float, float]
    # Unit vector tail -> head.
    direction: Tuple[float, float]
    # Direction rotated to (-dy, dx). It points to the "below" side (positive
    # perpendicular in project_onto_arrow), e.g. (0, 1) = screen-down for an
    # L->R arrow. The "above" side is the opposite direction.
    normal: Tuple[float, float]
    length: float
    # 0=right, 90=down, 180=left, 270=up (y-down screen convention).
    angle_deg: float
    # "solid", "dashed", "failed", "equilibrium"
    arrow_type: str
61
+
62
+
63
class LayoutPattern(Enum):
    """Scheme layout categories recognised by :func:`classify_layout`.

    The string values are the stable external names of each pattern.
    """

    SINGLE_LINE = "single_line"  # one row of arrows
    MULTI_LINE = "multi_line"    # several stacked rows
    BRANCH = "branch"            # arrows sharing a tail or head
    CYCLE = "cycle"              # arrows closing a loop
    SERPENTINE = "serpentine"    # rows joined by vertical connectors
    MIXED = "mixed"              # anything else
70
+
71
+
72
@dataclass
class FragmentInfo:
    """Location data for one CDXML ``<fragment>`` element."""

    # ``id`` attribute of the fragment.
    element_id: str
    # The fragment element itself.
    element: ET.Element
    # Centre point (x, y) in page coordinates.
    centroid: Tuple[float, float]
    # 4-tuple bounding box as produced by cdxml_utils.fragment_bbox
    # (presumably (x1, y1, x2, y2) — confirm against that helper).
    bbox: Tuple[float, float, float, float]
79
+
80
+
81
@dataclass
class TextInfo:
    """Location data for one free-standing CDXML ``<t>`` element."""

    # ``id`` attribute of the text element.
    element_id: str
    # The ``<t>`` element itself.
    element: ET.Element
    # Anchor point (x, y) in page coordinates.
    position: Tuple[float, float]
87
+
88
+
89
@dataclass
class AssignmentResult:
    """Outcome of attaching one element to one arrow."""

    # Id of the fragment or text that was assigned.
    element_id: str
    # Id of the arrow it was attached to.
    arrow_id: str
    # One of "reactant", "product", "above", "below".
    role: str
    # Certainty in the range 0.0 to 1.0.
    confidence: float
    # Absolute perpendicular distance from the arrow's axis.
    distance: float
97
+
98
+
99
@dataclass
class RawStep:
    """A single reaction step built from spatial assignment.

    The id lists hold element ids grouped by the role they received
    relative to this step's arrow.
    """

    # Id of the arrow that defines the step.
    arrow_id: str
    # The arrow's XML element, when available.
    arrow_element: Optional[ET.Element] = None
    # Elements behind the tail.
    reactant_ids: List[str] = field(default_factory=list)
    # Elements past the head.
    product_ids: List[str] = field(default_factory=list)
    # Elements alongside the shaft on the negative-perpendicular side.
    above_arrow_ids: List[str] = field(default_factory=list)
    # Elements alongside the shaft on the non-negative-perpendicular side.
    below_arrow_ids: List[str] = field(default_factory=list)
    # Aggregate certainty for the step (1.0 until set).
    confidence: float = 1.0
    # Row index for multi-line layouts.
    layout_row: int = 0
110
+
111
+
112
+ # ---------------------------------------------------------------------------
113
+ # Geometry primitives
114
+ # ---------------------------------------------------------------------------
115
+
116
+ def _unit_vector(dx: float, dy: float) -> Tuple[float, float]:
117
+ """Normalise (dx, dy) to a unit vector. Returns (0, 0) for zero-length."""
118
+ mag = math.hypot(dx, dy)
119
+ if mag < 1e-9:
120
+ return (0.0, 0.0)
121
+ return (dx / mag, dy / mag)
122
+
123
+
124
def project_onto_arrow(
    point: Tuple[float, float],
    tail: Tuple[float, float],
    head: Tuple[float, float],
) -> Tuple[float, float]:
    """Express *point* in the frame of the arrow ``tail -> head``.

    Returns ``(parallel, perpendicular)``:

    - *parallel* is the signed distance along the arrow axis measured from
      the tail. It is negative behind the tail and exceeds the arrow length
      past the head.
    - *perpendicular* is the signed offset from the axis. It is negative on
      the "above" side and positive on the "below" side.

    With CDXML's downward-growing y axis this means:

    - for a left-to-right horizontal arrow, points with smaller y (visually
      higher) get ``perpendicular < 0``;
    - for a downward vertical arrow, points to the right (larger x) get
      ``perpendicular < 0``, i.e. the right side counts as "above".

    A zero-length arrow has no axis. In that case ``(0.0, d)`` is returned,
    where ``d`` is the distance from the tail.
    """
    axis_x = head[0] - tail[0]
    axis_y = head[1] - tail[1]
    span = math.hypot(axis_x, axis_y)

    rel_x = point[0] - tail[0]
    rel_y = point[1] - tail[1]

    if span < 1e-9:
        return (0.0, math.hypot(rel_x, rel_y))

    ux, uy = axis_x / span, axis_y / span
    along = rel_x * ux + rel_y * uy
    # 2-D cross product u x rel; identical to the dot product of rel with the
    # normal (-uy, ux), which points "below" (screen-down for an L->R arrow).
    across = rel_y * ux - rel_x * uy
    return (along, across)
166
+
167
+
168
def point_to_segment_distance(
    point: Tuple[float, float],
    seg_start: Tuple[float, float],
    seg_end: Tuple[float, float],
) -> float:
    """Euclidean distance from *point* to the closed segment ``seg_start``–``seg_end``.

    A segment whose squared length is below ``1e-18`` is treated as the
    single point *seg_start*.
    """
    ax, ay = seg_start
    bx, by = seg_end
    px, py = point[0], point[1]
    abx, aby = bx - ax, by - ay
    seg_len_sq = abx * abx + aby * aby

    if seg_len_sq < 1e-18:
        return math.hypot(px - ax, py - ay)

    # Position of the perpendicular foot as a fraction of the segment,
    # clamped to [0, 1] so the closest point stays on the segment.
    frac = ((px - ax) * abx + (py - ay) * aby) / seg_len_sq
    frac = max(0.0, min(1.0, frac))

    nearest_x = ax + frac * abx
    nearest_y = ay + frac * aby
    return math.hypot(px - nearest_x, py - nearest_y)
189
+
190
+
191
+ # ---------------------------------------------------------------------------
192
+ # Arrow vector construction
193
+ # ---------------------------------------------------------------------------
194
+
195
+ def _classify_arrow_type(arrow: ET.Element) -> str:
196
+ """Classify arrow type from CDXML attributes."""
197
+ if arrow.get("NoGo") == "Cross":
198
+ return "failed"
199
+ line_type = (arrow.get("LineType") or "").lower()
200
+ if line_type in ("dash", "dashed", "dot"):
201
+ return "dashed"
202
+ if (arrow.get("ArrowheadType") or "").lower() == "dashed":
203
+ return "dashed"
204
+ # Check for equilibrium arrows (double-headed)
205
+ arrow_type_attr = (arrow.get("ArrowType") or "").lower()
206
+ if "equilibrium" in arrow_type_attr:
207
+ return "equilibrium"
208
+ return "solid"
209
+
210
+
211
def build_arrow_vector(arrow: ET.Element) -> ArrowVector:
    """Create an :class:`ArrowVector` for a CDXML ``<arrow>`` or ``<graphic>``.

    Endpoints come from ``cdxml_utils.arrow_endpoints``, unpacked as
    ``(tail_x, tail_y, head_x, head_y)``. From them the function derives:

    - the midpoint and length;
    - the unit direction tail -> head;
    - the normal ``(-uy, ux)`` (screen-down for a rightward arrow, i.e. the
      "below" side; negative perpendicular is "above");
    - the angle in degrees in [0, 360), with 90 pointing down the page;
    - the arrow type from :func:`_classify_arrow_type`.
    """
    from ..cdxml_utils import arrow_endpoints

    tail_x, tail_y, head_x, head_y = arrow_endpoints(arrow)
    delta_x = head_x - tail_x
    delta_y = head_y - tail_y
    ux, uy = _unit_vector(delta_x, delta_y)

    return ArrowVector(
        element_id=arrow.get("id", ""),
        element=arrow,
        tail=(tail_x, tail_y),
        head=(head_x, head_y),
        midpoint=((tail_x + head_x) / 2, (tail_y + head_y) / 2),
        direction=(ux, uy),
        normal=(-uy, ux),
        length=math.hypot(delta_x, delta_y),
        angle_deg=math.degrees(math.atan2(delta_y, delta_x)) % 360,
        arrow_type=_classify_arrow_type(arrow),
    )
241
+
242
+
243
def build_arrow_vectors(page: ET.Element) -> List[ArrowVector]:
    """Find all arrows among the page's direct children and build ArrowVectors.

    The function looks at ``<arrow>`` elements first. It then looks at
    ``<graphic>`` elements with ``GraphicType="Line"`` and a non-empty
    ``ArrowType`` attribute.

    De-duplication rules:

    - Elements with the same non-empty ``id`` are built only once.
      Elements without an ``id`` are never merged with each other. (Keying
      on ``""`` used to drop every id-less arrow after the first.)
    - A ``<graphic>`` whose ``SupersededBy`` attribute names one of the
      ``<arrow>`` elements collected here is skipped. It is the legacy
      duplicate of that arrow and would otherwise be counted twice.

    Returns:
        ArrowVectors, with all ``<arrow>``-derived entries first, each group
        in document order.
    """
    arrows: List[ArrowVector] = []
    seen: Set[str] = set()

    def _add(el: ET.Element) -> None:
        """Build *el* unless its non-empty id was already built."""
        eid = el.get("id", "")
        if eid:
            if eid in seen:
                return
            seen.add(eid)
        arrows.append(build_arrow_vector(el))

    for el in page:
        if el.tag == "arrow":
            _add(el)

    # Snapshot of <arrow> ids, used to recognise superseded legacy graphics.
    arrow_ids = set(seen)

    for el in page:
        if el.tag != "graphic":
            continue
        if el.get("GraphicType") != "Line" or not el.get("ArrowType"):
            continue
        superseded_by = el.get("SupersededBy")
        if superseded_by and superseded_by in arrow_ids:
            continue
        _add(el)

    return arrows
268
+
269
+
270
+ # ---------------------------------------------------------------------------
271
+ # Fragment and text collection
272
+ # ---------------------------------------------------------------------------
273
+
274
def collect_fragments(page: ET.Element) -> List[FragmentInfo]:
    """Return a :class:`FragmentInfo` for each direct ``<fragment>`` child of *page*.

    Fragments for which ``cdxml_utils.fragment_centroid`` or
    ``cdxml_utils.fragment_bbox`` returns ``None`` are left out.
    """
    from ..cdxml_utils import fragment_bbox, fragment_centroid

    found: List[FragmentInfo] = []
    for node in (child for child in page if child.tag == "fragment"):
        centre = fragment_centroid(node)
        box = fragment_bbox(node)
        if centre is not None and box is not None:
            found.append(FragmentInfo(
                element_id=node.get("id", ""),
                element=node,
                centroid=centre,
                bbox=box,
            ))
    return found
292
+
293
+
294
def collect_texts(page: ET.Element) -> List[TextInfo]:
    """Return a :class:`TextInfo` for each direct ``<t>`` child of *page*.

    The anchor position is chosen as follows:

    - the first two numbers of the ``p`` attribute, when ``p`` has at least
      two tokens;
    - otherwise the centre of the first four numbers of ``BoundingBox``.

    Text elements with neither usable attribute are skipped. Non-numeric
    tokens propagate the ``ValueError`` raised by ``float``.
    """
    collected: List[TextInfo] = []
    for node in page:
        if node.tag != "t":
            continue

        anchor: Optional[Tuple[float, float]] = None

        p_attr = node.get("p")
        if p_attr:
            tokens = p_attr.split()
            if len(tokens) >= 2:
                anchor = (float(tokens[0]), float(tokens[1]))

        if anchor is None:
            bbox_attr = node.get("BoundingBox", "")
            if bbox_attr:
                numbers = [float(tok) for tok in bbox_attr.split()]
                if len(numbers) >= 4:
                    anchor = (
                        (numbers[0] + numbers[2]) / 2,
                        (numbers[1] + numbers[3]) / 2,
                    )

        if anchor is not None:
            collected.append(TextInfo(
                element_id=node.get("id", ""),
                element=node,
                position=anchor,
            ))
    return collected
322
+
323
+
324
+ # ---------------------------------------------------------------------------
325
+ # Layout classification
326
+ # ---------------------------------------------------------------------------
327
+
328
+ _HORIZONTAL_ANGLE_TOLERANCE = 30.0 # degrees from horizontal (0 or 180)
329
+ _VERTICAL_ANGLE_TOLERANCE = 30.0 # degrees from vertical (90 or 270)
330
+
331
+
332
def _is_horizontal(arrow: ArrowVector) -> bool:
    """True when the arrow points roughly left-to-right or right-to-left.

    Roughly means strictly within ``_HORIZONTAL_ANGLE_TOLERANCE`` degrees of
    0/360 or of 180.
    """
    angle = arrow.angle_deg
    near_zero = (angle < _HORIZONTAL_ANGLE_TOLERANCE
                 or angle > 360 - _HORIZONTAL_ANGLE_TOLERANCE)
    near_half_turn = abs(angle - 180) < _HORIZONTAL_ANGLE_TOLERANCE
    return near_zero or near_half_turn
338
+
339
+
340
def _is_vertical(arrow: ArrowVector) -> bool:
    """True when the arrow points roughly down (90°) or up (270°).

    Roughly means strictly within ``_VERTICAL_ANGLE_TOLERANCE`` degrees.
    """
    angle = arrow.angle_deg
    return any(
        abs(angle - target) < _VERTICAL_ANGLE_TOLERANCE for target in (90, 270)
    )
345
+
346
+
347
def cluster_arrows_into_rows(
    arrows: List[ArrowVector],
    gap_threshold: Optional[float] = None,
) -> List[List[ArrowVector]]:
    """Group arrows into horizontal rows by the y-coordinate of their midpoints.

    Arrows are visited in order of increasing midpoint y. Each arrow joins
    the most recently opened row if its midpoint y lies within
    *gap_threshold* of that row's *mean* midpoint y; otherwise it opens a
    new row. This is a running-mean sweep, not single-linkage clustering.

    Args:
        arrows: Arrows to group.
        gap_threshold: Maximum vertical distance between an arrow's midpoint
            and the current row's mean midpoint y. Defaults to ``0.5 *``
            the median positive arrow length (the upper median for an even
            count). If no arrow has a positive length, the default is
            ``0.5 * 3 * ACS_BOND_LENGTH``.

    Returns:
        Rows ordered top-to-bottom (increasing y). Within each row, arrows
        are ordered left-to-right by midpoint x. Empty input yields ``[]``.
    """
    if not arrows:
        return []

    if gap_threshold is None:
        lengths = sorted(a.length for a in arrows if a.length > 0)
        median_len = lengths[len(lengths) // 2] if lengths else ACS_BOND_LENGTH * 3
        # Half an arrow length tolerates small vertical jitter between arrows
        # in one row without merging rows that are clearly separate.
        gap_threshold = 0.5 * median_len

    sorted_arrows = sorted(arrows, key=lambda a: a.midpoint[1])

    rows: List[List[ArrowVector]] = [[sorted_arrows[0]]]
    for arrow in sorted_arrows[1:]:
        current = rows[-1]
        row_y_center = sum(a.midpoint[1] for a in current) / len(current)
        if abs(arrow.midpoint[1] - row_y_center) <= gap_threshold:
            current.append(arrow)
        else:
            rows.append([arrow])

    for row in rows:
        row.sort(key=lambda a: a.midpoint[0])

    return rows
387
+
388
+
389
+ def _arrows_form_cycle(arrows: List[ArrowVector],
390
+ proximity_threshold: Optional[float] = None) -> bool:
391
+ """Check if arrows form a closed cycle (head of each -> tail of next, closing loop).
392
+
393
+ Uses proximity matching: arrow i's head must be near arrow j's tail for
394
+ some permutation that forms a cycle.
395
+ """
396
+ if len(arrows) < 2:
397
+ return False
398
+
399
+ if proximity_threshold is None:
400
+ avg_length = sum(a.length for a in arrows) / len(arrows)
401
+ proximity_threshold = avg_length * 0.8
402
+
403
+ # Build a directed graph: arrow i -> arrow j if head_i is near tail_j
404
+ n = len(arrows)
405
+ adj: Dict[int, List[int]] = {i: [] for i in range(n)}
406
+ for i in range(n):
407
+ for j in range(n):
408
+ if i == j:
409
+ continue
410
+ dist = math.hypot(
411
+ arrows[i].head[0] - arrows[j].tail[0],
412
+ arrows[i].head[1] - arrows[j].tail[1],
413
+ )
414
+ if dist < proximity_threshold:
415
+ adj[i].append(j)
416
+
417
+ # DFS cycle detection from each node
418
+ WHITE, GRAY, BLACK = 0, 1, 2
419
+ color = [WHITE] * n
420
+
421
+ def dfs(u: int) -> bool:
422
+ color[u] = GRAY
423
+ for v in adj[u]:
424
+ if color[v] == GRAY:
425
+ return True # back edge -> cycle
426
+ if color[v] == WHITE and dfs(v):
427
+ return True
428
+ color[u] = BLACK
429
+ return False
430
+
431
+ for i in range(n):
432
+ if color[i] == WHITE and dfs(i):
433
+ return True
434
+ return False
435
+
436
+
437
+ def _arrows_share_endpoint(
438
+ arrows: List[ArrowVector],
439
+ proximity_threshold: Optional[float] = None,
440
+ ) -> bool:
441
+ """Check if any arrows share a tail or head region (branch indicator)."""
442
+ if len(arrows) < 2:
443
+ return False
444
+
445
+ if proximity_threshold is None:
446
+ avg_length = sum(a.length for a in arrows) / len(arrows)
447
+ proximity_threshold = avg_length * 0.5
448
+
449
+ # Check for shared tails (divergent) or shared heads (convergent)
450
+ for i in range(len(arrows)):
451
+ for j in range(i + 1, len(arrows)):
452
+ # Shared tail = divergent
453
+ tail_dist = math.hypot(
454
+ arrows[i].tail[0] - arrows[j].tail[0],
455
+ arrows[i].tail[1] - arrows[j].tail[1],
456
+ )
457
+ if tail_dist < proximity_threshold:
458
+ return True
459
+ # Shared head = convergent
460
+ head_dist = math.hypot(
461
+ arrows[i].head[0] - arrows[j].head[0],
462
+ arrows[i].head[1] - arrows[j].head[1],
463
+ )
464
+ if head_dist < proximity_threshold:
465
+ return True
466
+ return False
467
+
468
+
469
def classify_layout(arrows: List[ArrowVector]) -> LayoutPattern:
    """Infer the overall scheme layout from arrow geometry (RxnIM taxonomy).

    The decision order is:

    1. Zero or one arrow -> ``SINGLE_LINE``.
    2. Head-to-tail loop (:func:`_arrows_form_cycle`) -> ``CYCLE``.
    3. Shared tail or head (:func:`_arrows_share_endpoint`) -> ``BRANCH``.
    4. At least one vertical arrow and at least two horizontal arrows
       spread over two or more rows -> ``SERPENTINE``.
    5. At least 70 % horizontal arrows -> ``SINGLE_LINE`` if all arrows
       cluster into one row, otherwise ``MULTI_LINE``.
    6. Anything else -> ``MIXED``.
    """
    if len(arrows) <= 1:
        return LayoutPattern.SINGLE_LINE

    if _arrows_form_cycle(arrows):
        return LayoutPattern.CYCLE
    if _arrows_share_endpoint(arrows):
        return LayoutPattern.BRANCH

    horizontal = [a for a in arrows if _is_horizontal(a)]
    has_vertical = any(_is_vertical(a) for a in arrows)

    if (has_vertical and len(horizontal) >= 2
            and len(cluster_arrows_into_rows(horizontal)) >= 2):
        return LayoutPattern.SERPENTINE

    # Entirely or mostly horizontal: the number of rows decides.
    if len(horizontal) >= len(arrows) * 0.7:
        if len(cluster_arrows_into_rows(arrows)) == 1:
            return LayoutPattern.SINGLE_LINE
        return LayoutPattern.MULTI_LINE

    return LayoutPattern.MIXED
517
+
518
+
519
+ # ---------------------------------------------------------------------------
520
+ # Distance scoring
521
+ # ---------------------------------------------------------------------------
522
+
523
+ # Penalty factor for parallel overshoot (fragment is past arrow tip).
524
+ # Higher values make fragments prefer arrows whose span they fall within.
525
+ _PARALLEL_OVERSHOOT_PENALTY = 0.5
526
+
527
+
528
def _distance_score(
    point: Tuple[float, float],
    arrow: ArrowVector,
) -> float:
    """Proximity score of *point* to *arrow*; smaller means closer.

    The score is ``|perpendicular| + _PARALLEL_OVERSHOOT_PENALTY * overshoot``.
    *overshoot* is how far the point's projection lies behind the tail or
    past the head, and is zero anywhere within the shaft's span.
    """
    along, across = project_onto_arrow(point, arrow.tail, arrow.head)
    overshoot = max(0.0, -along, along - arrow.length)
    return abs(across) + _PARALLEL_OVERSHOOT_PENALTY * overshoot
548
+
549
+
550
+ def _compute_confidence(dist_nearest: float, dist_second: float) -> float:
551
+ """Confidence from ratio of nearest to second-nearest distances.
552
+
553
+ Returns 1.0 when nearest is far from alternatives, ~0.5 when equidistant.
554
+ """
555
+ if dist_second <= 0:
556
+ return 1.0
557
+ total = dist_nearest + dist_second
558
+ if total < 1e-9:
559
+ return 0.5
560
+ return 1.0 - (dist_nearest / total)
561
+
562
+
563
+ # ---------------------------------------------------------------------------
564
+ # Role assignment from projection
565
+ # ---------------------------------------------------------------------------
566
+
567
+ def _role_from_projection(
568
+ parallel: float,
569
+ perpendicular: float,
570
+ arrow_length: float,
571
+ ) -> str:
572
+ """Determine role from arrow-relative coordinates.
573
+
574
+ - parallel < 0 → reactant (behind tail)
575
+ - parallel > arrow_length → product (past head)
576
+ - 0 ≤ parallel ≤ arrow_length:
577
+ perpendicular < 0 → above (reagent/condition)
578
+ perpendicular ≥ 0 → below (condition/yield)
579
+ """
580
+ if parallel < 0:
581
+ return "reactant"
582
+ if parallel > arrow_length:
583
+ return "product"
584
+ if perpendicular < 0:
585
+ return "above"
586
+ return "below"
587
+
588
+
589
+ # ---------------------------------------------------------------------------
590
+ # Core assignment engine
591
+ # ---------------------------------------------------------------------------
592
+
593
def _assign_to_nearest_arrow(
    point: Tuple[float, float],
    arrows: List[ArrowVector],
    element_id: str,
) -> AssignmentResult:
    """Attach *point* to the best-scoring arrow and report role and confidence.

    Behaviour:

    - With exactly one arrow, that arrow is used with confidence 1.0.
    - Otherwise arrows are ranked by :func:`_distance_score`. Ties keep list
      order. Confidence compares the best score with the runner-up's.
    - The role comes from the point's projection onto the chosen arrow.
    - ``distance`` is the absolute perpendicular offset from that arrow.
    - An empty *arrows* list raises ``IndexError``.
    """
    if len(arrows) == 1:
        chosen = arrows[0]
        confidence = 1.0
    else:
        ranked = sorted(
            ((_distance_score(point, candidate), candidate) for candidate in arrows),
            key=lambda pair: pair[0],
        )
        best_score, chosen = ranked[0]
        runner_up = ranked[1][0] if len(ranked) > 1 else best_score * 10
        confidence = _compute_confidence(best_score, runner_up)

    along, across = project_onto_arrow(point, chosen.tail, chosen.head)
    return AssignmentResult(
        element_id=element_id,
        arrow_id=chosen.element_id,
        role=_role_from_projection(along, across, chosen.length),
        confidence=confidence,
        distance=abs(across),
    )
632
+
633
+
634
def _assign_text_to_arrow(
    text: TextInfo,
    arrows: List[ArrowVector],
    margin: float = 30.0,
    max_perp: float = ACS_BOND_LENGTH * 2.5,
) -> Optional[AssignmentResult]:
    """Attach a text label to an arrow, or return ``None`` if none is close.

    An arrow is a candidate when both hold:

    - the text's projection lies within ``[-margin, length + margin]``
      along the arrow;
    - its perpendicular offset is at most *max_perp*.

    Among candidates, the lowest :func:`_distance_score` wins, with ties
    going to the earlier arrow. The role is ``"above"`` for a negative
    perpendicular offset and ``"below"`` otherwise. Confidence is 1.0 unless
    at least two arrows qualify, in which case it compares the two best
    scores.
    """
    # (score, arrow, perpendicular offset) for every qualifying arrow.
    candidates: List[Tuple[float, ArrowVector, float]] = []
    for arrow in arrows:
        along, across = project_onto_arrow(text.position, arrow.tail, arrow.head)
        within_span = -margin <= along <= arrow.length + margin
        if within_span and abs(across) <= max_perp:
            candidates.append((_distance_score(text.position, arrow), arrow, across))

    if not candidates:
        return None

    # min() returns the first minimal entry, so earlier arrows win ties.
    _, winner, offset = min(candidates, key=lambda entry: entry[0])

    confidence = 1.0
    if len(candidates) >= 2:
        first, second = sorted(entry[0] for entry in candidates)[:2]
        confidence = _compute_confidence(first, second)

    return AssignmentResult(
        element_id=text.element_id,
        arrow_id=winner.element_id,
        role="above" if offset < 0 else "below",
        confidence=confidence,
        distance=abs(offset),
    )
683
+
684
+
685
+ # ---------------------------------------------------------------------------
686
+ # Layout-specific strategies
687
+ # ---------------------------------------------------------------------------
688
+
689
+ def _assign_single_line(
690
+ arrows: List[ArrowVector],
691
+ fragments: List[FragmentInfo],
692
+ texts: List[TextInfo],
693
+ ) -> Tuple[List[RawStep], List[AssignmentResult]]:
694
+ """Assignment for single-row layouts (horizontal or any orientation)."""
695
+ steps: List[RawStep] = []
696
+ results: List[AssignmentResult] = []
697
+
698
+ # Create a step per arrow
699
+ for arrow in arrows:
700
+ steps.append(RawStep(
701
+ arrow_id=arrow.element_id,
702
+ arrow_element=arrow.element,
703
+ ))
704
+
705
+ # Assign fragments
706
+ for frag in fragments:
707
+ result = _assign_to_nearest_arrow(frag.centroid, arrows, frag.element_id)
708
+ results.append(result)
709
+ # Find the corresponding step
710
+ for step in steps:
711
+ if step.arrow_id == result.arrow_id:
712
+ if result.role == "reactant":
713
+ step.reactant_ids.append(frag.element_id)
714
+ elif result.role == "product":
715
+ step.product_ids.append(frag.element_id)
716
+ elif result.role == "above":
717
+ step.above_arrow_ids.append(frag.element_id)
718
+ elif result.role == "below":
719
+ step.below_arrow_ids.append(frag.element_id)
720
+ break
721
+
722
+ # Assign texts
723
+ for text in texts:
724
+ result = _assign_text_to_arrow(text, arrows)
725
+ if result is not None:
726
+ results.append(result)
727
+ for step in steps:
728
+ if step.arrow_id == result.arrow_id:
729
+ if result.role == "above":
730
+ step.above_arrow_ids.append(text.element_id)
731
+ else:
732
+ step.below_arrow_ids.append(text.element_id)
733
+ break
734
+
735
+ # Compute step-level confidence
736
+ for step in steps:
737
+ step_results = [r for r in results if r.arrow_id == step.arrow_id]
738
+ if step_results:
739
+ step.confidence = sum(r.confidence for r in step_results) / len(step_results)
740
+
741
+ return steps, results
742
+
743
+
744
def _assign_multi_line(
    arrows: List[ArrowVector],
    fragments: List[FragmentInfo],
    texts: List[TextInfo],
) -> Tuple[List[RawStep], List[AssignmentResult]]:
    """Assignment for multi-row horizontal layouts.

    1. Cluster arrows into rows (:func:`cluster_arrows_into_rows`).
    2. Give each fragment and text to exactly one row.
       - A row's window is its mean arrow-midpoint y, ± the longest arrow
         length in that row.
       - Among rows whose window contains the element's y, the one with the
         nearest mean wins; ties go to the upper row.
       - Elements outside every window are ignored.
    3. Run :func:`_assign_single_line` per row and record ``layout_row`` on
       each resulting step.
    4. Link cross-row intermediates (last product row N -> first reactant
       row N+1) via ``_link_cross_row_intermediates``.
    """
    rows = cluster_arrows_into_rows(arrows)

    centres = [sum(a.midpoint[1] for a in row) / len(row) for row in rows]
    windows = [max((a.length for a in row), default=ACS_BOND_LENGTH * 3) for row in rows]

    def _owning_row(y: float) -> Optional[int]:
        """Index of the nearest row whose window contains *y*, or None."""
        best_idx: Optional[int] = None
        best_gap = math.inf
        for idx, (centre, window) in enumerate(zip(centres, windows)):
            gap = abs(y - centre)
            if gap < window and gap < best_gap:
                best_idx, best_gap = idx, gap
        return best_idx

    # Each element goes to one row only. Offering it to every overlapping
    # window assigned the same fragment/text in two rows when rows were
    # closer than twice the arrow length.
    frags_by_row: List[List[FragmentInfo]] = [[] for _ in rows]
    for frag in fragments:
        idx = _owning_row(frag.centroid[1])
        if idx is not None:
            frags_by_row[idx].append(frag)

    texts_by_row: List[List[TextInfo]] = [[] for _ in rows]
    for text in texts:
        idx = _owning_row(text.position[1])
        if idx is not None:
            texts_by_row[idx].append(text)

    all_steps: List[RawStep] = []
    all_results: List[AssignmentResult] = []

    for row_idx, row_arrows in enumerate(rows):
        row_steps, row_results = _assign_single_line(
            row_arrows, frags_by_row[row_idx], texts_by_row[row_idx]
        )
        for step in row_steps:
            step.layout_row = row_idx
        all_steps.extend(row_steps)
        all_results.extend(row_results)

    _link_cross_row_intermediates(all_steps, rows)

    return all_steps, all_results
784
+
785
+
786
def _assign_serpentine(
    arrows: List[ArrowVector],
    fragments: List[FragmentInfo],
    texts: List[TextInfo],
) -> Tuple[List[RawStep], List[AssignmentResult]]:
    """Assignment for serpentine layouts (horizontal arrows + vertical connectors).

    Only arrows for which ``_is_horizontal`` is true become reaction steps.
    They are assigned row by row via :func:`_assign_multi_line`, which also
    links the last product of each row to the first step of the next row.

    Vertical arrows are treated as connectors between rows, not as
    independent reaction steps with their own reactants/products, so they
    are not passed to the assignment.  Any other arrow that
    ``_is_horizontal`` rejects is likewise excluded.

    Returns:
        ``(steps, results)`` as produced by :func:`_assign_multi_line`.
    """
    # Vertical connectors carry no extra information for assignment: the
    # row N -> row N+1 link they represent is already made inside
    # _assign_multi_line, so no separate pass over them is needed.
    horizontal = [a for a in arrows if _is_horizontal(a)]
    return _assign_multi_line(horizontal, fragments, texts)
817
+
818
+
819
def _union_preserving_order(first, second) -> list:
    """Return the de-duplicated union of two id sequences in first-seen order.

    Ids from *first* come first, followed by ids from *second* that were not
    already present.  Unlike ``list(set(a) | set(b))``, the order does not
    depend on hash randomization, so repeated runs give identical output.
    """
    return list(dict.fromkeys([*first, *second]))


def _share_ids_at_common_endpoint(
    arrows: List[ArrowVector],
    step_by_arrow: Dict,
    endpoint: str,
    ids_attr: str,
    proximity: float,
) -> None:
    """Merge the ``ids_attr`` lists of steps whose arrows share an endpoint.

    For every pair of arrows (i < j) whose ``endpoint`` points ("tail" or
    "head") are closer than *proximity*, both corresponding steps receive
    the union of their ``ids_attr`` lists as two independent list objects.
    Arrows without a step in *step_by_arrow* are skipped.
    """
    for i, arrow_i in enumerate(arrows):
        point_i = getattr(arrow_i, endpoint)
        for arrow_j in arrows[i + 1:]:
            point_j = getattr(arrow_j, endpoint)
            gap = math.hypot(point_i[0] - point_j[0], point_i[1] - point_j[1])
            if gap < proximity:
                step_i = step_by_arrow.get(arrow_i.element_id)
                step_j = step_by_arrow.get(arrow_j.element_id)
                if step_i is not None and step_j is not None:
                    merged = _union_preserving_order(
                        getattr(step_i, ids_attr), getattr(step_j, ids_attr)
                    )
                    setattr(step_i, ids_attr, merged)
                    # Separate copy so the two steps never alias one list.
                    setattr(step_j, ids_attr, list(merged))


def _assign_branch(
    arrows: List[ArrowVector],
    fragments: List[FragmentInfo],
    texts: List[TextInfo],
) -> Tuple[List[RawStep], List[AssignmentResult]]:
    """Assignment for divergent/convergent branching layouts.

    Fragments and texts are first assigned with the nearest-arrow
    (single-line) strategy.  A post-pass then looks for arrows whose tails
    (divergent) or heads (convergent) lie within half the average arrow
    length of each other.  Such arrows get a shared reactant list (shared
    tails) or a shared product list (shared heads).  A fragment near a shared
    endpoint may be assigned to only one of those arrows by the first pass;
    the post-pass copies it to all of them.

    The ``results`` list from the first pass is returned unchanged; the
    sharing only edits the steps' ``reactant_ids`` / ``product_ids``.

    Returns:
        ``(steps, results)``.
    """
    steps, results = _assign_single_line(arrows, fragments, texts)
    if not arrows:
        # No arrows means there are no endpoints to compare (and no average length).
        return steps, results

    avg_length = sum(a.length for a in arrows) / len(arrows)
    proximity = avg_length * 0.5

    # First step per arrow id, matching the former linear next(...) lookup.
    step_by_arrow: Dict = {}
    for step in steps:
        step_by_arrow.setdefault(step.arrow_id, step)

    # Shared tails -> divergent: share reactants.
    _share_ids_at_common_endpoint(
        arrows, step_by_arrow, "tail", "reactant_ids", proximity
    )
    # Shared heads -> convergent: share products.
    _share_ids_at_common_endpoint(
        arrows, step_by_arrow, "head", "product_ids", proximity
    )

    return steps, results
872
+
873
+
874
def _assign_cycle(
    arrows: List[ArrowVector],
    fragments: List[FragmentInfo],
    texts: List[TextInfo],
) -> Tuple[List[RawStep], List[AssignmentResult]]:
    """Assignment for cyclic reaction networks (catalytic cycles).

    The arrows are first put in cycle order (each arrow followed by the one
    whose tail is nearest its head, see :func:`_order_arrows_cyclic`). The
    reordered arrows are then handed to the single-line proximity strategy.

    The closing edge (last step's product feeding the first step's reactant)
    is deliberately not added here. The original design notes that the
    topology detector in ``scheme_reader`` is expected to recognize the cycle.

    Returns:
        ``(steps, results)`` from :func:`_assign_single_line`.
    """
    cycle_order = _order_arrows_cyclic(arrows)
    return _assign_single_line(cycle_order, fragments, texts)
894
+
895
+
896
def _order_arrows_cyclic(arrows: List[ArrowVector]) -> List[ArrowVector]:
    """Greedily chain arrows into cycle order: head of one -> nearest tail.

    The walk starts at ``arrows[0]``. Each later position is filled by
    the not-yet-used arrow whose tail is closest (Euclidean, x/y) to the head
    of the previously placed arrow. On equal distances the candidate that
    comes first in the input order wins. The input list is not modified;
    a new list is always returned.
    """
    if len(arrows) <= 1:
        return list(arrows)

    pool = list(arrows)
    chain = [pool.pop(0)]

    while pool:
        head = chain[-1].head
        nearest_pos = 0
        nearest_gap = float("inf")
        for pos, candidate in enumerate(pool):
            tail = candidate.tail
            gap = math.hypot(head[0] - tail[0], head[1] - tail[1])
            if gap < nearest_gap:
                nearest_pos, nearest_gap = pos, gap
        chain.append(pool.pop(nearest_pos))

    return chain
920
+
921
+
922
+ # ---------------------------------------------------------------------------
923
+ # Shared intermediate handling
924
+ # ---------------------------------------------------------------------------
925
+
926
def _link_shared_intermediates(steps: List[RawStep]) -> None:
    """Propagate products forward into empty reactant lists, in place.

    In a sequential scheme, the product of step *i* is the starting material
    of step *i + 1*. The steps are walked as consecutive pairs from left to
    right. When the later step has no reactants and the earlier step has
    products, the later step gets a fresh copy of the earlier step's
    ``product_ids``. Steps that already list reactants are left alone.

    Args:
        steps: ordered steps; mutated in place.
    """
    for prev_step, next_step in zip(steps, steps[1:]):
        if next_step.reactant_ids:
            continue
        if prev_step.product_ids:
            next_step.reactant_ids = list(prev_step.product_ids)
935
+
936
+
937
def _link_cross_row_intermediates(
    steps: List[RawStep],
    rows: List[List[ArrowVector]],
) -> None:
    """Link steps within each row, then across consecutive rows, in place.

    With fewer than two rows, this is just sequential linking over all steps.
    Otherwise, steps are grouped by ``layout_row`` and each group is linked
    with :func:`_link_shared_intermediates`. Then, for each pair of
    consecutive rows (sorted by row index), if the last step of the upper
    row has products and the first step of the lower row has no reactants,
    the lower step gets a copy of those products. This models the last
    product of row N wrapping around to become the first reactant of row N+1.

    Args:
        steps: steps carrying ``layout_row``; mutated in place.
        rows: the arrow rows; only their count is consulted here.
    """
    if len(rows) < 2:
        _link_shared_intermediates(steps)
        return

    by_row: Dict[int, List[RawStep]] = {}
    for step in steps:
        if step.layout_row not in by_row:
            by_row[step.layout_row] = []
        by_row[step.layout_row].append(step)

    ordered_groups = [by_row[key] for key in sorted(by_row)]

    # Within each row first ...
    for group in ordered_groups:
        _link_shared_intermediates(group)

    # ... then across the boundary between consecutive rows.
    for upper, lower in zip(ordered_groups, ordered_groups[1:]):
        if not (upper and lower):
            continue
        row_end, row_start = upper[-1], lower[0]
        if row_end.product_ids and not row_start.reactant_ids:
            row_start.reactant_ids = list(row_end.product_ids)
969
+
970
+
971
+ # ---------------------------------------------------------------------------
972
+ # Public API
973
+ # ---------------------------------------------------------------------------
974
+
975
def assign_elements(
    arrows: List[ArrowVector],
    page: ET.Element,
    layout: Optional[LayoutPattern] = None,
) -> Tuple[List[RawStep], List[AssignmentResult]]:
    """Assign all fragments and texts on the page to arrows.

    This is the main entry point. With no arrows it returns ``([], [])``
    immediately. Otherwise the layout is auto-detected via
    ``classify_layout`` when *layout* is None. Fragments and texts are
    collected from *page*, and the work is routed to the layout-specific
    strategy. Unknown layouts and ``MIXED`` fall back to the single-line
    strategy. ``SINGLE_LINE`` and ``MIXED`` results additionally get
    sequential product -> reactant propagation.

    Returns:
        - steps: list of :class:`RawStep`, one per arrow
        - results: list of :class:`AssignmentResult`, one per assigned element
    """
    if not arrows:
        return [], []

    chosen_layout = classify_layout(arrows) if layout is None else layout

    fragments = collect_fragments(page)
    texts = collect_texts(page)

    strategy_for = {
        LayoutPattern.SINGLE_LINE: _assign_single_line,
        LayoutPattern.MULTI_LINE: _assign_multi_line,
        LayoutPattern.SERPENTINE: _assign_serpentine,
        LayoutPattern.BRANCH: _assign_branch,
        LayoutPattern.CYCLE: _assign_cycle,
        LayoutPattern.MIXED: _assign_single_line,  # no dedicated strategy yet
    }
    run_strategy = strategy_for.get(chosen_layout, _assign_single_line)
    steps, results = run_strategy(arrows, fragments, texts)

    # Sequential layouts need product -> next-reactant propagation here.
    # The multi-row strategies already link inside _assign_multi_line.
    sequential_layouts = (LayoutPattern.SINGLE_LINE, LayoutPattern.MIXED)
    if chosen_layout in sequential_layouts:
        _link_shared_intermediates(steps)

    return steps, results