cdxml-toolkit 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cdxml_toolkit/__init__.py +18 -0
- cdxml_toolkit/_jre/__init__.py +2 -0
- cdxml_toolkit/_jre/temurin-21-jre-win-x64.zip +0 -0
- cdxml_toolkit/analysis/__init__.py +35 -0
- cdxml_toolkit/analysis/deterministic/__init__.py +12 -0
- cdxml_toolkit/analysis/deterministic/discover_experiment_files.py +413 -0
- cdxml_toolkit/analysis/deterministic/lab_book_formatter.py +701 -0
- cdxml_toolkit/analysis/deterministic/lcms_file_categorizer.py +928 -0
- cdxml_toolkit/analysis/deterministic/lcms_identifier.py +598 -0
- cdxml_toolkit/analysis/deterministic/mass_resolver.py +654 -0
- cdxml_toolkit/analysis/deterministic/multi_lcms_analyzer.py +1412 -0
- cdxml_toolkit/analysis/deterministic/procedure_writer.py +446 -0
- cdxml_toolkit/analysis/extract_nmr.py +47 -0
- cdxml_toolkit/analysis/format_procedure_entry.py +479 -0
- cdxml_toolkit/analysis/lcms_analyzer.py +1299 -0
- cdxml_toolkit/analysis/parse_analysis_file.py +134 -0
- cdxml_toolkit/cdxml_builder.py +920 -0
- cdxml_toolkit/cdxml_utils.py +342 -0
- cdxml_toolkit/chemdraw/__init__.py +5 -0
- cdxml_toolkit/chemdraw/_chemscript_server.py +562 -0
- cdxml_toolkit/chemdraw/cdx_converter.py +527 -0
- cdxml_toolkit/chemdraw/cdxml_to_image.py +262 -0
- cdxml_toolkit/chemdraw/cdxml_to_image_rdkit.py +296 -0
- cdxml_toolkit/chemdraw/chemscript_bridge.py +901 -0
- cdxml_toolkit/constants.py +304 -0
- cdxml_toolkit/coord_normalizer.py +438 -0
- cdxml_toolkit/deterministic_pipeline/__init__.py +6 -0
- cdxml_toolkit/deterministic_pipeline/legacy/__init__.py +5 -0
- cdxml_toolkit/deterministic_pipeline/legacy/eln_cdx_cleanup.py +509 -0
- cdxml_toolkit/deterministic_pipeline/legacy/eln_enrichment.py +1394 -0
- cdxml_toolkit/deterministic_pipeline/legacy/scheme_aligner.py +428 -0
- cdxml_toolkit/deterministic_pipeline/legacy/scheme_polisher.py +1337 -0
- cdxml_toolkit/deterministic_pipeline/legacy/scheme_polisher_v2.py +1340 -0
- cdxml_toolkit/deterministic_pipeline/scheme_reader_audit.py +931 -0
- cdxml_toolkit/deterministic_pipeline/scheme_reader_verify.py +1160 -0
- cdxml_toolkit/image/__init__.py +15 -0
- cdxml_toolkit/image/reaction_from_image.py +2103 -0
- cdxml_toolkit/image/structure_from_image.py +1711 -0
- cdxml_toolkit/layout/__init__.py +5 -0
- cdxml_toolkit/layout/alignment.py +1642 -0
- cdxml_toolkit/layout/reaction_cleanup.py +1002 -0
- cdxml_toolkit/layout/scheme_merger.py +2260 -0
- cdxml_toolkit/mcp_server/__init__.py +0 -0
- cdxml_toolkit/mcp_server/__main__.py +5 -0
- cdxml_toolkit/mcp_server/server.py +1567 -0
- cdxml_toolkit/naming/__init__.py +6 -0
- cdxml_toolkit/naming/aligned_namer.py +2342 -0
- cdxml_toolkit/naming/mol_builder.py +3722 -0
- cdxml_toolkit/naming/name_decomposer.py +2843 -0
- cdxml_toolkit/naming/reactions_datamol.json +2414 -0
- cdxml_toolkit/office/__init__.py +5 -0
- cdxml_toolkit/office/doc_from_template.py +722 -0
- cdxml_toolkit/office/ole_embedder.py +808 -0
- cdxml_toolkit/office/ole_extractor.py +272 -0
- cdxml_toolkit/perception/__init__.py +10 -0
- cdxml_toolkit/perception/compound_search.py +229 -0
- cdxml_toolkit/perception/eln_csv_parser.py +240 -0
- cdxml_toolkit/perception/rdf_parser.py +664 -0
- cdxml_toolkit/perception/reactant_heuristic.py +1045 -0
- cdxml_toolkit/perception/reaction_parser.py +2150 -0
- cdxml_toolkit/perception/scheme_reader.py +2948 -0
- cdxml_toolkit/perception/scheme_refine.py +1404 -0
- cdxml_toolkit/perception/scheme_segmenter.py +619 -0
- cdxml_toolkit/perception/spatial_assignment.py +1013 -0
- cdxml_toolkit/rdkit_utils.py +605 -0
- cdxml_toolkit/render/__init__.py +17 -0
- cdxml_toolkit/render/auto_layout.py +229 -0
- cdxml_toolkit/render/compact_parser.py +632 -0
- cdxml_toolkit/render/parser.py +706 -0
- cdxml_toolkit/render/render_scheme.py +267 -0
- cdxml_toolkit/render/renderer.py +2387 -0
- cdxml_toolkit/render/schema.py +90 -0
- cdxml_toolkit/render/scheme_maker.py +1043 -0
- cdxml_toolkit/render/scheme_yaml_writer.py +1487 -0
- cdxml_toolkit/resolve/__init__.py +13 -0
- cdxml_toolkit/resolve/cas_resolver.py +430 -0
- cdxml_toolkit/resolve/chemscanner_abbreviations.json +28813 -0
- cdxml_toolkit/resolve/condensed_formula.py +493 -0
- cdxml_toolkit/resolve/jre_manager.py +195 -0
- cdxml_toolkit/resolve/reagent_abbreviations.json +1046 -0
- cdxml_toolkit/resolve/reagent_db.py +285 -0
- cdxml_toolkit/resolve/superatom_data.json +2856 -0
- cdxml_toolkit/resolve/superatom_table.py +146 -0
- cdxml_toolkit/text_formatting.py +298 -0
- cdxml_toolkit-0.5.0.dist-info/METADATA +318 -0
- cdxml_toolkit-0.5.0.dist-info/RECORD +91 -0
- cdxml_toolkit-0.5.0.dist-info/WHEEL +5 -0
- cdxml_toolkit-0.5.0.dist-info/entry_points.txt +17 -0
- cdxml_toolkit-0.5.0.dist-info/licenses/LICENSE +21 -0
- cdxml_toolkit-0.5.0.dist-info/licenses/NOTICE.md +37 -0
- cdxml_toolkit-0.5.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1013 @@
|
|
|
1
|
+
"""
|
|
2
|
+
spatial_assignment.py — Geometry-first spatial assignment of scheme elements to arrows.
|
|
3
|
+
|
|
4
|
+
Replaces the naive x-band assignment in scheme_reader._parse_from_geometry()
|
|
5
|
+
with a rotation-invariant, distance-based approach that handles arbitrary arrow
|
|
6
|
+
orientations, multi-row layouts, branching, and cycles.
|
|
7
|
+
|
|
8
|
+
Algorithmic influences:
|
|
9
|
+
- ReactionDataExtractor (Cambridge, 2021/2023): arrow-centric equidistant scan
|
|
10
|
+
- RxnIM (HKUST, 2025): layout pattern taxonomy (single/multi-line/branch/cycle)
|
|
11
|
+
- CDXML advantage: exact coordinates from XML, no detection/OCR needed
|
|
12
|
+
|
|
13
|
+
Key design decisions:
|
|
14
|
+
- Arrow-relative projection: every point is transformed into (parallel, perp)
|
|
15
|
+
coordinates relative to each arrow. This makes role assignment
|
|
16
|
+
rotation-invariant.
|
|
17
|
+
- Distance-based assignment: fragments go to the nearest arrow by a combined
|
|
18
|
+
distance score, not by hard x-coordinate bands.
|
|
19
|
+
- Layout classifier: detects the scheme pattern first, then delegates to a
|
|
20
|
+
pattern-specific strategy that handles edge cases for that layout type.
|
|
21
|
+
- Confidence scoring: every assignment carries a 0-1 confidence based on
|
|
22
|
+
the ratio of nearest to second-nearest arrow distance.
|
|
23
|
+
|
|
24
|
+
API:
|
|
25
|
+
from cdxml_toolkit.perception.spatial_assignment import (
|
|
26
|
+
build_arrow_vectors, classify_layout, assign_elements,
|
|
27
|
+
)
|
|
28
|
+
arrows = build_arrow_vectors(page_element)
|
|
29
|
+
layout = classify_layout(arrows)
|
|
30
|
+
steps, results = assign_elements(arrows, page_element, layout)
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
from __future__ import annotations
|
|
34
|
+
|
|
35
|
+
import math
|
|
36
|
+
from dataclasses import dataclass, field
|
|
37
|
+
from enum import Enum
|
|
38
|
+
from typing import Dict, List, Optional, Set, Tuple
|
|
39
|
+
from xml.etree import ElementTree as ET
|
|
40
|
+
|
|
41
|
+
from ..constants import ACS_BOND_LENGTH
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# ---------------------------------------------------------------------------
|
|
45
|
+
# Data structures
|
|
46
|
+
# ---------------------------------------------------------------------------
|
|
47
|
+
|
|
48
|
+
@dataclass
class ArrowVector:
    """Fully characterised arrow with direction, type, and spatial metadata.

    All coordinates are CDXML page coordinates, in which the y-axis points
    downward.
    """
    element_id: str                  # CDXML ``id`` of the source element
    element: ET.Element              # the source <arrow> or <graphic> element
    tail: Tuple[float, float]        # (x, y) where the arrow starts
    head: Tuple[float, float]        # (x, y) of the arrowhead
    midpoint: Tuple[float, float]    # halfway point between tail and head
    direction: Tuple[float, float]   # unit vector tail -> head
    # Unit perpendicular (-direction_y, direction_x). It points toward the
    # "below" side: for a rightward arrow it is (0, 1), i.e. downward on
    # screen. Signed projections onto it are negative on the "above" side
    # and positive on the "below" side.
    normal: Tuple[float, float]
    length: float                    # Euclidean tail-to-head distance
    angle_deg: float                 # 0=right, 90=down, 180=left, 270=up
    arrow_type: str                  # "solid", "dashed", "failed", "equilibrium"
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class LayoutPattern(Enum):
    """Scheme layout categories produced by ``classify_layout``."""
    SINGLE_LINE = "single_line"  # one row of (mostly) horizontal arrows; also used for 0 or 1 arrow
    MULTI_LINE = "multi_line"  # (mostly) horizontal arrows clustered into several rows
    BRANCH = "branch"  # two arrows share a tail region (divergent) or head region (convergent)
    CYCLE = "cycle"  # arrow heads chain to other arrows' tails and close a loop
    SERPENTINE = "serpentine"  # horizontal rows joined by vertical connector arrows
    MIXED = "mixed"  # fallback when no other pattern matches
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@dataclass
class FragmentInfo:
    """Spatial metadata for a CDXML fragment.

    Built by ``collect_fragments`` from the ``<fragment>`` children of a page.
    """
    element_id: str  # CDXML ``id`` attribute ("" when absent)
    element: ET.Element  # the source <fragment> element
    centroid: Tuple[float, float]  # (x, y) as returned by cdxml_utils.fragment_centroid
    # As returned by cdxml_utils.fragment_bbox; presumably
    # (x_min, y_min, x_max, y_max) -- TODO confirm ordering against that helper.
    bbox: Tuple[float, float, float, float]
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@dataclass
class TextInfo:
    """Spatial metadata for a CDXML text element.

    Built by ``collect_texts`` from the free ``<t>`` children of a page.
    """
    element_id: str  # CDXML ``id`` attribute ("" when absent)
    element: ET.Element  # the source <t> element
    # Anchor point: the first two numbers of the ``p`` attribute when present,
    # otherwise the centre of the ``BoundingBox`` attribute.
    position: Tuple[float, float]
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@dataclass
class AssignmentResult:
    """Single element-to-arrow assignment with confidence."""
    element_id: str  # id of the assigned fragment or text element
    arrow_id: str  # element_id of the ArrowVector the element was assigned to
    role: str  # "reactant", "product", "above", "below"
    # Higher when the chosen arrow scores clearly better than the runner-up;
    # 1.0 when there is no competing arrow.
    confidence: float  # 0.0 – 1.0
    distance: float  # perpendicular distance to arrow axis
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@dataclass
class RawStep:
    """One reaction step derived from spatial assignment.

    Holds one arrow plus the ids of the elements grouped around it.
    NOTE(review): the code that fills these lists is outside this view; the
    per-field notes below assume they mirror the AssignmentResult roles
    "reactant", "product", "above" and "below" -- confirm against the caller.
    """
    arrow_id: str  # id of the step's arrow element
    arrow_element: Optional[ET.Element] = None  # the arrow's CDXML element, if available
    reactant_ids: List[str] = field(default_factory=list)  # presumably role "reactant"
    product_ids: List[str] = field(default_factory=list)  # presumably role "product"
    above_arrow_ids: List[str] = field(default_factory=list)  # presumably role "above"
    below_arrow_ids: List[str] = field(default_factory=list)  # presumably role "below"
    confidence: float = 1.0  # 0.0 - 1.0; aggregation is decided by the caller
    layout_row: int = 0  # row index for multi-line layouts
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
# ---------------------------------------------------------------------------
|
|
113
|
+
# Geometry primitives
|
|
114
|
+
# ---------------------------------------------------------------------------
|
|
115
|
+
|
|
116
|
+
def _unit_vector(dx: float, dy: float) -> Tuple[float, float]:
    """Scale (dx, dy) to length 1; a (near-)zero vector yields (0.0, 0.0)."""
    norm = math.hypot(dx, dy)
    return (0.0, 0.0) if norm < 1e-9 else (dx / norm, dy / norm)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def project_onto_arrow(
    point: Tuple[float, float],
    tail: Tuple[float, float],
    head: Tuple[float, float],
) -> Tuple[float, float]:
    """Express *point* in coordinates attached to the arrow tail -> head.

    Returns ``(parallel, perpendicular)``:

    - *parallel* is the signed distance along the arrow axis measured from
      the tail. It is negative behind the tail and exceeds the arrow length
      past the head.
    - *perpendicular* is the signed offset from the axis. It is negative on
      the "above" side and positive on the "below" side.

    CDXML's y-axis points down. For a left-to-right arrow, a point with
    smaller y (higher on screen) therefore gets ``perpendicular < 0``. For a
    downward arrow, points to the right (larger x) get ``perpendicular < 0``;
    the right-hand side plays the role of "above" when the arrow points down.

    A zero-length arrow yields ``(0.0, distance_from_tail)``.
    """
    tx, ty = tail[0], tail[1]
    axis_x = head[0] - tx
    axis_y = head[1] - ty
    span = math.hypot(axis_x, axis_y)
    if span < 1e-9:
        return (0.0, math.hypot(point[0] - tx, point[1] - ty))

    ux, uy = axis_x / span, axis_y / span
    rel_x = point[0] - tx
    rel_y = point[1] - ty
    along = rel_x * ux + rel_y * uy
    # Dot product with the normal (-uy, ux). For a rightward arrow that
    # normal is (0, 1), pointing down on screen toward the "below" side.
    across = rel_x * -uy + rel_y * ux
    return (along, across)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def point_to_segment_distance(
    point: Tuple[float, float],
    seg_start: Tuple[float, float],
    seg_end: Tuple[float, float],
) -> float:
    """Distance from *point* to the nearest point of the closed segment.

    A zero-length segment is treated as the single point *seg_start*.
    """
    x0, y0 = seg_start
    x1, y1 = seg_end
    seg_dx = x1 - x0
    seg_dy = y1 - y0
    seg_len_sq = seg_dx * seg_dx + seg_dy * seg_dy
    rel_x = point[0] - x0
    rel_y = point[1] - y0

    if seg_len_sq < 1e-18:
        return math.hypot(rel_x, rel_y)

    # Fraction along the segment of the perpendicular foot. It is clamped to
    # [0, 1] so the nearest point lies on the segment, not on its extension.
    frac = max(0.0, min(1.0, (rel_x * seg_dx + rel_y * seg_dy) / seg_len_sq))
    nearest_x = x0 + frac * seg_dx
    nearest_y = y0 + frac * seg_dy
    return math.hypot(point[0] - nearest_x, point[1] - nearest_y)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
# ---------------------------------------------------------------------------
|
|
192
|
+
# Arrow vector construction
|
|
193
|
+
# ---------------------------------------------------------------------------
|
|
194
|
+
|
|
195
|
+
def _classify_arrow_type(arrow: ET.Element) -> str:
    """Classify an arrow element as "failed", "dashed", "equilibrium" or "solid".

    Checks run in priority order:

    1. ``NoGo="Cross"`` -> ``"failed"``.
    2. A ``LineType`` containing a ``dash``/``dashed``/``dot`` token, or
       ``ArrowheadType="dashed"`` -> ``"dashed"``. ``LineType`` is split on
       whitespace, so combined styles such as ``"Dashed Bold"`` are
       recognised as well as a bare ``"Dashed"``.
    3. An ``ArrowType`` containing ``"equilibrium"`` -> ``"equilibrium"``.
    4. Anything else -> ``"solid"``.

    All comparisons except ``NoGo`` are case-insensitive.
    """
    if arrow.get("NoGo") == "Cross":
        return "failed"

    line_tokens = set((arrow.get("LineType") or "").lower().split())
    if line_tokens & {"dash", "dashed", "dot"}:
        return "dashed"
    if (arrow.get("ArrowheadType") or "").lower() == "dashed":
        return "dashed"

    # Equilibrium (double-headed) arrows are flagged through ArrowType.
    if "equilibrium" in (arrow.get("ArrowType") or "").lower():
        return "equilibrium"
    return "solid"
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def build_arrow_vector(arrow: ET.Element) -> ArrowVector:
    """Turn a CDXML ``<arrow>`` or ``<graphic>`` element into an :class:`ArrowVector`.

    Endpoints come from ``cdxml_utils.arrow_endpoints``. Direction, normal,
    length and angle are derived from them, and the type comes from
    ``_classify_arrow_type``.
    """
    from ..cdxml_utils import arrow_endpoints

    tail_x, tail_y, head_x, head_y = arrow_endpoints(arrow)
    delta_x = head_x - tail_x
    delta_y = head_y - tail_y

    unit = _unit_vector(delta_x, delta_y)
    # (-uy, ux): for a rightward arrow this is (0, 1), pointing down on screen
    # toward the "below" side (perpendicular > 0 there, < 0 on the "above" side).
    perp_unit = (-unit[1], unit[0])

    return ArrowVector(
        element_id=arrow.get("id", ""),
        element=arrow,
        tail=(tail_x, tail_y),
        head=(head_x, head_y),
        midpoint=((tail_x + head_x) / 2, (tail_y + head_y) / 2),
        direction=unit,
        normal=perp_unit,
        length=math.hypot(delta_x, delta_y),
        # With y pointing down: 0 = right, 90 = down, 180 = left, 270 = up.
        angle_deg=math.degrees(math.atan2(delta_y, delta_x)) % 360,
        arrow_type=_classify_arrow_type(arrow),
    )
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def build_arrow_vectors(page: ET.Element) -> List[ArrowVector]:
    """Find all arrows among the direct children of *page* as ArrowVectors.

    ``<arrow>`` elements are collected first. ``<graphic>`` elements with
    ``GraphicType="Line"`` and a non-empty ``ArrowType`` attribute follow.
    Document order is preserved within each group.

    De-duplication rules:

    - Elements are de-duplicated by ``id``. An element without an ``id``
      cannot be matched against anything, so every id-less element is kept.
    - A ``<graphic>`` whose ``SupersededBy`` attribute names an element
      already collected is skipped. It is presumably an older representation
      of that same arrow, and keeping both would double-count it.
    """
    seen: Set[str] = set()
    arrows: List[ArrowVector] = []

    def _take(el: ET.Element) -> None:
        # Append el unless an element with the same non-empty id is present.
        eid = el.get("id", "")
        if eid and eid in seen:
            return
        arrows.append(build_arrow_vector(el))
        if eid:
            seen.add(eid)

    for el in page:
        if el.tag == "arrow":
            _take(el)

    for el in page:
        if (el.tag == "graphic"
                and el.get("GraphicType") == "Line"
                and el.get("ArrowType")):
            if el.get("SupersededBy") in seen:
                continue
            _take(el)

    return arrows
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
# ---------------------------------------------------------------------------
|
|
271
|
+
# Fragment and text collection
|
|
272
|
+
# ---------------------------------------------------------------------------
|
|
273
|
+
|
|
274
|
+
def collect_fragments(page: ET.Element) -> List[FragmentInfo]:
    """Return a FragmentInfo for every direct ``<fragment>`` child of *page*.

    A fragment is skipped when ``fragment_centroid`` or ``fragment_bbox``
    returns ``None``. Document order is preserved.
    """
    from ..cdxml_utils import fragment_bbox, fragment_centroid

    result: List[FragmentInfo] = []
    for frag in (child for child in page if child.tag == "fragment"):
        center = fragment_centroid(frag)
        box = fragment_bbox(frag)
        if center is not None and box is not None:
            result.append(FragmentInfo(
                element_id=frag.get("id", ""),
                element=frag,
                centroid=center,
                bbox=box,
            ))
    return result
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def collect_texts(page: ET.Element) -> List[TextInfo]:
    """Collect the free ``<t>`` children of *page* with an anchor position.

    The position is the first two numbers of the ``p`` attribute when they
    are present and numeric. Otherwise it is the centre of the
    ``BoundingBox`` attribute, whose first four numbers are read as
    ``x1 y1 x2 y2``. Text elements with neither usable attribute are
    skipped.

    Malformed (non-numeric) values are treated as missing instead of
    raising ``ValueError``, so one bad element cannot abort the whole page.
    """
    def _leading_floats(raw: Optional[str], count: int) -> Optional[List[float]]:
        # Parse the first *count* whitespace-separated numbers of raw;
        # None when the attribute is absent, too short, or non-numeric.
        tokens = (raw or "").split()
        if len(tokens) < count:
            return None
        try:
            return [float(tok) for tok in tokens[:count]]
        except ValueError:
            return None

    texts: List[TextInfo] = []
    for el in page:
        if el.tag != "t":
            continue

        p_vals = _leading_floats(el.get("p"), 2)
        if p_vals is not None:
            pos = (p_vals[0], p_vals[1])
        else:
            # Fallback: BoundingBox center
            bb = _leading_floats(el.get("BoundingBox"), 4)
            if bb is None:
                continue
            pos = ((bb[0] + bb[2]) / 2, (bb[1] + bb[3]) / 2)

        texts.append(TextInfo(
            element_id=el.get("id", ""),
            element=el,
            position=pos,
        ))
    return texts
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
# ---------------------------------------------------------------------------
|
|
325
|
+
# Layout classification
|
|
326
|
+
# ---------------------------------------------------------------------------
|
|
327
|
+
|
|
328
|
+
# Maximum deviation, in degrees, from the horizontal axis (0/180) for an arrow
# to count as horizontal, and from the vertical axis (90/270) for it to count
# as vertical. Both bounds are strict, so arrows lying between the two bands
# are neither horizontal nor vertical.
_HORIZONTAL_ANGLE_TOLERANCE = 30.0
_VERTICAL_ANGLE_TOLERANCE = 30.0
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def _is_horizontal(arrow: ArrowVector) -> bool:
    """True when *arrow* points roughly left-to-right or right-to-left.

    Its angle must lie strictly within ``_HORIZONTAL_ANGLE_TOLERANCE``
    degrees of 0/360 or of 180.
    """
    angle = arrow.angle_deg
    tol = _HORIZONTAL_ANGLE_TOLERANCE
    near_zero = angle < tol or angle > 360 - tol
    near_half_turn = abs(angle - 180) < tol
    return near_zero or near_half_turn
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def _is_vertical(arrow: ArrowVector) -> bool:
    """True when *arrow* points roughly straight down or straight up.

    Its angle must lie strictly within ``_VERTICAL_ANGLE_TOLERANCE`` degrees
    of 90 (down, since CDXML y grows downward) or 270 (up).
    """
    return any(abs(arrow.angle_deg - axis) < _VERTICAL_ANGLE_TOLERANCE
               for axis in (90, 270))
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def cluster_arrows_into_rows(
    arrows: List[ArrowVector],
    gap_threshold: Optional[float] = None,
) -> List[List[ArrowVector]]:
    """Group arrows into horizontal rows by the y-coordinate of their midpoints.

    Arrows are swept top-to-bottom (increasing midpoint y). Each arrow joins
    the current row when its midpoint y lies within *gap_threshold* of that
    row's mean midpoint y. Otherwise it starts a new row. This is a greedy
    running-centroid sweep, not true single-linkage clustering.

    Args:
        arrows: Arrows to cluster. An empty list yields ``[]``.
        gap_threshold: Maximum ``|y - row_mean_y|`` for an arrow to join the
            current row. Defaults to 0.5x the median positive arrow length
            (the upper median for an even count). When no arrow has positive
            length, it defaults to 0.5x ``3 * ACS_BOND_LENGTH``.

    Returns:
        Rows ordered top-to-bottom, each sorted left-to-right by midpoint x.
    """
    if not arrows:
        return []

    if gap_threshold is None:
        lengths = sorted(a.length for a in arrows if a.length > 0)
        median_len = lengths[len(lengths) // 2] if lengths else ACS_BOND_LENGTH * 3
        # Half the median arrow length. For ACS-style schemes (arrow ~43pt)
        # this gives ~21pt, enough for slight vertical jitter but not enough
        # to merge genuinely separate rows.
        gap_threshold = 0.5 * median_len

    sorted_arrows = sorted(arrows, key=lambda a: a.midpoint[1])

    rows: List[List[ArrowVector]] = [[sorted_arrows[0]]]
    # Running sum of midpoint y for the current row. The row mean is then
    # O(1) per arrow instead of re-summing the whole row every time.
    row_y_sum = sorted_arrows[0].midpoint[1]
    for arrow in sorted_arrows[1:]:
        y = arrow.midpoint[1]
        if abs(y - row_y_sum / len(rows[-1])) <= gap_threshold:
            rows[-1].append(arrow)
            row_y_sum += y
        else:
            rows.append([arrow])
            row_y_sum = y

    for row in rows:
        row.sort(key=lambda a: a.midpoint[0])
    return rows
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
def _arrows_form_cycle(arrows: List[ArrowVector],
                       proximity_threshold: Optional[float] = None) -> bool:
    """Return True if some arrows chain head-to-tail into a closed loop.

    A directed edge i -> j exists when arrow i's head lies strictly closer
    than *proximity_threshold* to arrow j's tail (i != j). The arrows form a
    cycle when that graph contains any directed cycle.

    Args:
        arrows: Candidate arrows. Fewer than two never form a cycle.
        proximity_threshold: Head-to-tail matching distance. Defaults to
            0.8x the mean arrow length.
    """
    count = len(arrows)
    if count < 2:
        return False

    if proximity_threshold is None:
        proximity_threshold = 0.8 * (sum(a.length for a in arrows) / count)

    successors: List[List[int]] = [[] for _ in range(count)]
    in_degree = [0] * count
    for i, src in enumerate(arrows):
        for j, dst in enumerate(arrows):
            if i == j:
                continue
            gap = math.hypot(src.head[0] - dst.tail[0], src.head[1] - dst.tail[1])
            if gap < proximity_threshold:
                successors[i].append(j)
                in_degree[j] += 1

    # Kahn's algorithm: repeatedly remove nodes with no incoming edges. If
    # any node can never be removed, the graph contains a directed cycle.
    ready = [k for k in range(count) if in_degree[k] == 0]
    removed = 0
    while ready:
        node = ready.pop()
        removed += 1
        for nxt in successors[node]:
            in_degree[nxt] -= 1
            if in_degree[nxt] == 0:
                ready.append(nxt)
    return removed < count
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
def _arrows_share_endpoint(
    arrows: List[ArrowVector],
    proximity_threshold: Optional[float] = None,
) -> bool:
    """Return True if any two arrows start together or end together.

    Two tails closer than *proximity_threshold* indicate a divergent branch.
    Two heads that close indicate a convergent branch. The threshold
    defaults to half the mean arrow length.
    """
    if len(arrows) < 2:
        return False

    if proximity_threshold is None:
        proximity_threshold = 0.5 * (sum(a.length for a in arrows) / len(arrows))

    def _close(p: Tuple[float, float], q: Tuple[float, float]) -> bool:
        return math.hypot(p[0] - q[0], p[1] - q[1]) < proximity_threshold

    for idx, first in enumerate(arrows):
        for second in arrows[idx + 1:]:
            if _close(first.tail, second.tail) or _close(first.head, second.head):
                return True
    return False
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
def classify_layout(arrows: List[ArrowVector]) -> LayoutPattern:
    """Classify a scheme's layout from its arrows' geometry (RxnIM taxonomy).

    Checks run in this order, and the first match wins:

    1. No arrow or a single arrow -> SINGLE_LINE.
    2. Heads chaining to tails and closing a loop -> CYCLE.
    3. Two arrows sharing a tail or head region -> BRANCH.
    4. At least one vertical arrow, plus at least two horizontal arrows that
       cluster into two or more rows -> SERPENTINE.
    5. At least 70% of arrows horizontal -> SINGLE_LINE if *all* arrows
       cluster into one row, else MULTI_LINE.
    6. Otherwise -> MIXED.
    """
    if not arrows or len(arrows) == 1:
        return LayoutPattern.SINGLE_LINE

    horizontal = [a for a in arrows if _is_horizontal(a)]
    vertical_count = sum(1 for a in arrows if _is_vertical(a))

    if _arrows_form_cycle(arrows):
        return LayoutPattern.CYCLE
    if _arrows_share_endpoint(arrows):
        return LayoutPattern.BRANCH

    if vertical_count and len(horizontal) >= 2:
        if len(cluster_arrows_into_rows(horizontal)) >= 2:
            return LayoutPattern.SERPENTINE

    # "All horizontal" is a special case of "at least 70% horizontal". Both
    # are decided by the row count of the complete arrow set.
    if len(horizontal) >= len(arrows) * 0.7:
        row_count = len(cluster_arrows_into_rows(arrows))
        return LayoutPattern.SINGLE_LINE if row_count == 1 else LayoutPattern.MULTI_LINE

    return LayoutPattern.MIXED
|
|
517
|
+
|
|
518
|
+
|
|
519
|
+
# ---------------------------------------------------------------------------
|
|
520
|
+
# Distance scoring
|
|
521
|
+
# ---------------------------------------------------------------------------
|
|
522
|
+
|
|
523
|
+
# Weight applied in _distance_score to the "overshoot", i.e. how far a
# point's projection falls before an arrow's tail or beyond its head.
# Larger values pull fragments toward arrows whose span they lie alongside.
_PARALLEL_OVERSHOOT_PENALTY = 0.5
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
def _distance_score(
    point: Tuple[float, float],
    arrow: ArrowVector,
) -> float:
    """Score how strongly *point* associates with *arrow* (lower is stronger).

    ``score = |perpendicular| + _PARALLEL_OVERSHOOT_PENALTY * overshoot``,
    where *overshoot* is how far the point's projection falls before the
    tail or beyond the head. It is zero inside the arrow's span.
    """
    along, across = project_onto_arrow(point, arrow.tail, arrow.head)
    # For a non-negative arrow length at most one of these terms is positive;
    # inside the span both are <= 0 and the floor of 0.0 applies.
    overshoot = max(0.0, -along, along - arrow.length)
    return abs(across) + _PARALLEL_OVERSHOOT_PENALTY * overshoot
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
def _compute_confidence(dist_nearest: float, dist_second: float) -> float:
    """Confidence from the ratio of nearest to second-nearest distances.

    Computes ``1 - d1 / (d1 + d2)``. The result is 1.0 when the nearest
    candidate is at distance 0 and the runner-up is not, and it approaches
    0.5 as the two distances become equal. Exact ties give 0.5, including
    the case where both distances are zero.

    Args:
        dist_nearest: Score/distance of the best candidate (expected >= 0).
        dist_second: Score/distance of the runner-up (expected >= dist_nearest).
    """
    if dist_second <= 0:
        # When the nearest distance is also non-positive, both candidates
        # are equally close. That is a tie, not certainty.
        return 0.5 if dist_nearest <= 0 else 1.0
    total = dist_nearest + dist_second
    if total < 1e-9:
        return 0.5
    return 1.0 - (dist_nearest / total)
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
# ---------------------------------------------------------------------------
|
|
564
|
+
# Role assignment from projection
|
|
565
|
+
# ---------------------------------------------------------------------------
|
|
566
|
+
|
|
567
|
+
def _role_from_projection(
    parallel: float,
    perpendicular: float,
    arrow_length: float,
) -> str:
    """Map arrow-relative coordinates to a scheme role.

    A point behind the tail (``parallel < 0``) is a reactant, and a point
    past the head (``parallel > arrow_length``) is a product. Within the
    span, the negative-perpendicular side is "above". Everything else,
    including points exactly on the axis, is "below".
    """
    before_tail = parallel < 0
    past_head = parallel > arrow_length
    if not (before_tail or past_head):
        return "above" if perpendicular < 0 else "below"
    return "reactant" if before_tail else "product"
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
# ---------------------------------------------------------------------------
|
|
590
|
+
# Core assignment engine
|
|
591
|
+
# ---------------------------------------------------------------------------
|
|
592
|
+
|
|
593
|
+
def _assign_to_nearest_arrow(
    point: Tuple[float, float],
    arrows: List[ArrowVector],
    element_id: str,
) -> AssignmentResult:
    """Attach *point* (the anchor of element *element_id*) to its best arrow.

    With a single arrow, the assignment is unconditional and confidence is
    1.0. Otherwise every arrow is scored with ``_distance_score``, the lowest
    score wins (ties keep input order), and ``_compute_confidence`` compares
    the winner's score with the runner-up's. The role comes from the point's
    projection onto the chosen arrow. ``distance`` is the absolute
    perpendicular offset from that arrow's axis.

    An empty *arrows* list raises ``IndexError``.
    """
    if len(arrows) == 1:
        chosen = arrows[0]
        confidence = 1.0
    else:
        ranked = sorted(((_distance_score(point, a), a) for a in arrows),
                        key=lambda pair: pair[0])
        best_score, chosen = ranked[0]
        confidence = _compute_confidence(best_score, ranked[1][0])

    along, across = project_onto_arrow(point, chosen.tail, chosen.head)
    return AssignmentResult(
        element_id=element_id,
        arrow_id=chosen.element_id,
        role=_role_from_projection(along, across, chosen.length),
        confidence=confidence,
        distance=abs(across),
    )
|
|
632
|
+
|
|
633
|
+
|
|
634
|
+
def _assign_text_to_arrow(
    text: TextInfo,
    arrows: List[ArrowVector],
    margin: float = 30.0,
    max_perp: float = ACS_BOND_LENGTH * 2.5,
) -> Optional[AssignmentResult]:
    """Assign a text element to an arrow, or return None if none is close enough.

    An arrow is a candidate only if the text's projected position lies within
    the arrow's parallel span, extended by *margin* on each side, AND within
    *max_perp* of the arrow's axis. This prevents distant text from being
    mis-assigned. Among candidates, the lowest ``_distance_score`` wins, and
    the first arrow wins on ties.

    Role is ``"above"`` for a negative perpendicular offset and ``"below"``
    otherwise. The visual meaning of the sign depends on
    ``project_onto_arrow``'s convention.

    Confidence is 1.0 unless at least two arrows are candidates. In that case
    it is ``_compute_confidence`` of the two lowest candidate scores.

    Args:
        text: Text element; its ``position`` is projected onto each arrow.
        arrows: Candidate arrows (may be empty; the result is then None).
        margin: Allowed overshoot beyond either end of the arrow, along its axis.
        max_perp: Maximum absolute perpendicular distance from the arrow's axis.

    Returns:
        The best AssignmentResult, or None if no arrow passes both window checks.
    """
    best_arrow: Optional[ArrowVector] = None
    best_perp = 0.0
    best_score = float("inf")
    # Scores of every arrow whose window contains the text. They are
    # collected in the same pass that picks the winner, so the window test
    # exists in exactly one place.
    candidate_scores: List[float] = []

    for arrow in arrows:
        par, perp = project_onto_arrow(text.position, arrow.tail, arrow.head)
        if not (-margin <= par <= arrow.length + margin):
            continue
        if abs(perp) > max_perp:
            continue

        score = _distance_score(text.position, arrow)
        candidate_scores.append(score)
        if score < best_score:
            best_score = score
            best_arrow = arrow
            best_perp = perp

    if best_arrow is None:
        return None

    confidence = 1.0
    if len(candidate_scores) >= 2:
        candidate_scores.sort()
        confidence = _compute_confidence(candidate_scores[0], candidate_scores[1])

    return AssignmentResult(
        element_id=text.element_id,
        arrow_id=best_arrow.element_id,
        role="above" if best_perp < 0 else "below",
        confidence=confidence,
        distance=abs(best_perp),
    )
|
|
683
|
+
|
|
684
|
+
|
|
685
|
+
# ---------------------------------------------------------------------------
|
|
686
|
+
# Layout-specific strategies
|
|
687
|
+
# ---------------------------------------------------------------------------
|
|
688
|
+
|
|
689
|
+
def _assign_single_line(
    arrows: List[ArrowVector],
    fragments: List[FragmentInfo],
    texts: List[TextInfo],
) -> Tuple[List[RawStep], List[AssignmentResult]]:
    """Assignment for single-row layouts (horizontal or any orientation).

    Builds one RawStep per arrow, in *arrows* order, then:

    * assigns every fragment to its nearest arrow (``_assign_to_nearest_arrow``)
      and files its id under the step's reactant/product/above/below list
      according to the returned role;
    * assigns each text with ``_assign_text_to_arrow``; unassignable texts
      (None) are skipped, and the rest are filed under above/below;
    * sets each step's ``confidence`` to the mean confidence of the results
      pointing at its arrow. Steps with no results keep their default.

    Returns:
        ``(steps, results)``. Both are empty when *arrows* is empty; fragments
        cannot be assigned without an arrow.
    """
    steps: List[RawStep] = []
    results: List[AssignmentResult] = []

    if not arrows:
        # _assign_to_nearest_arrow would raise IndexError on an empty list.
        return steps, results

    for arrow in arrows:
        steps.append(RawStep(
            arrow_id=arrow.element_id,
            arrow_element=arrow.element,
        ))

    # arrow_id -> step. setdefault keeps the FIRST step for a duplicated id,
    # matching a front-to-back linear search.
    step_by_arrow: Dict[str, RawStep] = {}
    for step in steps:
        step_by_arrow.setdefault(step.arrow_id, step)

    # Assign fragments
    for frag in fragments:
        result = _assign_to_nearest_arrow(frag.centroid, arrows, frag.element_id)
        results.append(result)
        step = step_by_arrow.get(result.arrow_id)
        if step is None:
            continue
        if result.role == "reactant":
            step.reactant_ids.append(frag.element_id)
        elif result.role == "product":
            step.product_ids.append(frag.element_id)
        elif result.role == "above":
            step.above_arrow_ids.append(frag.element_id)
        elif result.role == "below":
            step.below_arrow_ids.append(frag.element_id)

    # Assign texts (roles are only ever "above" or "below")
    for text in texts:
        result = _assign_text_to_arrow(text, arrows)
        if result is None:
            continue
        results.append(result)
        step = step_by_arrow.get(result.arrow_id)
        if step is None:
            continue
        if result.role == "above":
            step.above_arrow_ids.append(text.element_id)
        else:
            step.below_arrow_ids.append(text.element_id)

    # Step-level confidence: mean over this arrow's results. Confidences are
    # grouped in one pass instead of rescanning all results for every step;
    # results order is kept, so the sums are identical.
    confidences: Dict[str, List[float]] = {}
    for result in results:
        confidences.setdefault(result.arrow_id, []).append(result.confidence)
    for step in steps:
        step_confs = confidences.get(step.arrow_id)
        if step_confs:
            step.confidence = sum(step_confs) / len(step_confs)

    return steps, results
|
|
742
|
+
|
|
743
|
+
|
|
744
|
+
def _assign_multi_line(
    arrows: List[ArrowVector],
    fragments: List[FragmentInfo],
    texts: List[TextInfo],
) -> Tuple[List[RawStep], List[AssignmentResult]]:
    """Assignment for multi-row horizontal layouts.

    1. Cluster arrows into rows (``cluster_arrows_into_rows``).
    2. Give each fragment/text to exactly ONE row. That row is the one whose
       mean arrow-midpoint y is nearest, chosen among rows whose band
       (|dy| < longest arrow length in the row) contains the element.
       Elements outside every band are not assigned. Earlier rows win ties.
    3. Run ``_assign_single_line`` per row and tag each step with ``layout_row``.
    4. Link cross-row intermediates (last product of row N -> first reactant
       of row N+1) via ``_link_cross_row_intermediates``.

    Previously each row filtered elements independently, so an element inside
    two rows' bands was assigned twice (duplicate results, and listed in two
    steps).
    """
    rows = cluster_arrows_into_rows(arrows)

    # Per-row (y_center, y_tolerance). None marks an empty row, which would
    # otherwise cause a division by zero and has no arrows to assign to.
    bands: List[Optional[Tuple[float, float]]] = []
    for row_arrows in rows:
        if not row_arrows:
            bands.append(None)
            continue
        y_center = sum(a.midpoint[1] for a in row_arrows) / len(row_arrows)
        # The tolerance is the longest arrow in the row. It bounds how far
        # above or below the row's centre line an element may sit.
        y_tolerance = max(a.length for a in row_arrows)
        bands.append((y_center, y_tolerance))

    def _row_for(y: float) -> Optional[int]:
        """Return the index of the nearest row whose band contains *y*, or None."""
        chosen: Optional[int] = None
        chosen_dist = float("inf")
        for idx, band in enumerate(bands):
            if band is None:
                continue
            dist = abs(y - band[0])
            if dist < band[1] and dist < chosen_dist:
                chosen, chosen_dist = idx, dist
        return chosen

    row_frags: List[List[FragmentInfo]] = [[] for _ in rows]
    row_texts: List[List[TextInfo]] = [[] for _ in rows]
    for frag in fragments:
        idx = _row_for(frag.centroid[1])
        if idx is not None:
            row_frags[idx].append(frag)
    for text in texts:
        idx = _row_for(text.position[1])
        if idx is not None:
            row_texts[idx].append(text)

    all_steps: List[RawStep] = []
    all_results: List[AssignmentResult] = []

    for row_idx, row_arrows in enumerate(rows):
        if not row_arrows:
            continue
        row_steps, row_results = _assign_single_line(
            row_arrows, row_frags[row_idx], row_texts[row_idx]
        )
        for step in row_steps:
            step.layout_row = row_idx
        all_steps.extend(row_steps)
        all_results.extend(row_results)

    # Link cross-row intermediates
    _link_cross_row_intermediates(all_steps, rows)

    return all_steps, all_results
|
|
784
|
+
|
|
785
|
+
|
|
786
|
+
def _assign_serpentine(
    arrows: List[ArrowVector],
    fragments: List[FragmentInfo],
    texts: List[TextInfo],
) -> Tuple[List[RawStep], List[AssignmentResult]]:
    """Assignment for serpentine layouts (horizontal arrows + vertical connectors).

    Only arrows accepted by ``_is_horizontal`` become reaction steps; they are
    handed to ``_assign_multi_line``. That function groups them into rows and
    already performs the row-to-row hand-off: the last step's products of row N
    become the first step's reactants of row N+1 when that step has none.

    Vertical (and any other non-horizontal) arrows are treated as connectors
    between rows. They produce no steps, and no elements are assigned to them.

    A previous version walked the vertical arrows to locate the row boundary
    each one bridges, but the loop had no effect, so it was removed.
    TODO(review): if connector metadata on RawStep is wanted, attach it here.
    """
    horizontal = [a for a in arrows if _is_horizontal(a)]
    return _assign_multi_line(horizontal, fragments, texts)
|
|
817
|
+
|
|
818
|
+
|
|
819
|
+
def _assign_branch(
    arrows: List[ArrowVector],
    fragments: List[FragmentInfo],
    texts: List[TextInfo],
) -> Tuple[List[RawStep], List[AssignmentResult]]:
    """Assignment for divergent/convergent branching layouts.

    Runs the general nearest-arrow assignment (``_assign_single_line``), then
    post-processes arrow pairs whose endpoints nearly coincide. "Nearly" means
    closer than half the mean arrow length.

    * Shared tails (divergent): both steps get the union of their reactant ids.
    * Shared heads (convergent): both steps get the union of their product ids.

    Unions keep first-seen order (the first step's ids, then the second's new
    ones). Plain ``set`` iteration order for string ids varies between
    processes under hash randomization, which made the output
    non-deterministic. Each step receives its own list object, so later
    mutation of one step cannot leak into another.

    Returns:
        ``(steps, results)``. Both are empty when *arrows* is empty.
    """
    if not arrows:
        # Avoids the division by zero in the mean-length computation below.
        return [], []

    steps, results = _assign_single_line(arrows, fragments, texts)

    avg_length = sum(a.length for a in arrows) / len(arrows)
    proximity = avg_length * 0.5

    # arrow_id -> step (first step wins for a duplicated id).
    step_by_arrow: Dict[str, RawStep] = {}
    for step in steps:
        step_by_arrow.setdefault(step.arrow_id, step)

    def _close(p: Tuple[float, float], q: Tuple[float, float]) -> bool:
        """Return True when two endpoints are within the sharing distance."""
        return math.hypot(p[0] - q[0], p[1] - q[1]) < proximity

    # Reactant merges (tails) and product merges (heads) touch disjoint
    # fields, so one pass over each pair handles both.
    for i, first in enumerate(arrows):
        for second in arrows[i + 1:]:
            step_a = step_by_arrow.get(first.element_id)
            step_b = step_by_arrow.get(second.element_id)
            if step_a is None or step_b is None:
                continue

            if _close(first.tail, second.tail):
                merged = list(dict.fromkeys([*step_a.reactant_ids, *step_b.reactant_ids]))
                step_a.reactant_ids = merged
                step_b.reactant_ids = list(merged)

            if _close(first.head, second.head):
                merged = list(dict.fromkeys([*step_a.product_ids, *step_b.product_ids]))
                step_a.product_ids = merged
                step_b.product_ids = list(merged)

    return steps, results
|
|
872
|
+
|
|
873
|
+
|
|
874
|
+
def _assign_cycle(
    arrows: List[ArrowVector],
    fragments: List[FragmentInfo],
    texts: List[TextInfo],
) -> Tuple[List[RawStep], List[AssignmentResult]]:
    """Assignment for cyclic reaction networks (catalytic cycles).

    Arrows are first put into cycle order by ``_order_arrows_cyclic``, which
    greedily chains each head to the nearest remaining tail, starting from the
    first arrow. The single-line assignment then runs over that order, so
    steps come out in cycle order.

    The closing edge (last step's product feeding the first step's reactant)
    is deliberately NOT linked here. Per the original design note, the
    topology detector in scheme_reader recognizes the cycle.
    """
    cycle_order = _order_arrows_cyclic(arrows)
    return _assign_single_line(cycle_order, fragments, texts)
|
|
894
|
+
|
|
895
|
+
|
|
896
|
+
def _order_arrows_cyclic(arrows: List[ArrowVector]) -> List[ArrowVector]:
    """Order arrows into a cycle by greedily chaining head_i -> nearest tail_j.

    The chain starts at ``arrows[0]``. At each step it appends the unplaced
    arrow whose tail is closest (Euclidean distance) to the head of the last
    placed arrow; the earliest arrow wins on ties. This is O(n^2) in the
    number of arrows.

    Returns a new list; *arrows* itself is not modified. Inputs of length 0 or
    1 are returned as a plain copy.
    """
    if len(arrows) <= 1:
        return list(arrows)

    unplaced = list(arrows)
    chain = [unplaced.pop(0)]

    while unplaced:
        anchor = chain[-1].head
        pick = 0
        pick_gap = float("inf")
        for position, candidate in enumerate(unplaced):
            gap = math.hypot(
                anchor[0] - candidate.tail[0],
                anchor[1] - candidate.tail[1],
            )
            if gap < pick_gap:
                pick, pick_gap = position, gap
        chain.append(unplaced.pop(pick))

    return chain
|
|
920
|
+
|
|
921
|
+
|
|
922
|
+
# ---------------------------------------------------------------------------
|
|
923
|
+
# Shared intermediate handling
|
|
924
|
+
# ---------------------------------------------------------------------------
|
|
925
|
+
|
|
926
|
+
def _link_shared_intermediates(steps: List[RawStep]) -> None:
    """Copy each step's products into the next step's empty reactant list.

    For a sequential scheme, the product of step i is the reactant of step
    i+1. Adjacent pairs are visited in order. When the later step has no
    reactants and the earlier one has products, the later step receives a
    fresh copy of those product ids. Steps that already have reactants are
    left untouched.

    Mutates the RawStep objects in place and returns None.
    """
    for previous, following in zip(steps, steps[1:]):
        if following.reactant_ids or not previous.product_ids:
            continue
        following.reactant_ids = list(previous.product_ids)
|
|
935
|
+
|
|
936
|
+
|
|
937
|
+
def _link_cross_row_intermediates(
    steps: List[RawStep],
    rows: List[List[ArrowVector]],
) -> None:
    """Link the last step of row N to the first step of row N+1.

    With fewer than two rows, this is just ``_link_shared_intermediates`` over
    all steps. Otherwise:

    1. Steps are grouped by ``layout_row``, keeping their order within each row.
    2. ``_link_shared_intermediates`` runs inside each row.
    3. For consecutive row indices (sorted), if the last step of the upper row
       has products and the first step of the lower row has no reactants, the
       lower step receives a copy of those products (the wrap-around
       intermediate of a multi-line scheme).

    Only ``len(rows)`` is consulted here; grouping relies on each step's
    ``layout_row``. Mutates the steps in place.
    """
    if len(rows) < 2:
        _link_shared_intermediates(steps)
        return

    by_row: Dict[int, List[RawStep]] = {}
    for item in steps:
        if item.layout_row not in by_row:
            by_row[item.layout_row] = []
        by_row[item.layout_row].append(item)

    ordered_rows = sorted(by_row)

    # Within-row chaining first, so each row's internal links are in place
    # before the wrap-around link is considered.
    for key in ordered_rows:
        _link_shared_intermediates(by_row[key])

    for upper_key, lower_key in zip(ordered_rows, ordered_rows[1:]):
        upper = by_row[upper_key]
        lower = by_row[lower_key]
        if not (upper and lower):
            continue
        tail_step = upper[-1]
        head_step = lower[0]
        if tail_step.product_ids and not head_step.reactant_ids:
            head_step.reactant_ids = list(tail_step.product_ids)
|
|
969
|
+
|
|
970
|
+
|
|
971
|
+
# ---------------------------------------------------------------------------
|
|
972
|
+
# Public API
|
|
973
|
+
# ---------------------------------------------------------------------------
|
|
974
|
+
|
|
975
|
+
def assign_elements(
    arrows: List[ArrowVector],
    page: ET.Element,
    layout: Optional[LayoutPattern] = None,
) -> Tuple[List[RawStep], List[AssignmentResult]]:
    """Assign all fragments and texts on the page to arrows.

    This is the main entry point. When *arrows* is empty, ``([], [])`` is
    returned immediately. If *layout* is None, it is auto-detected with
    ``classify_layout``. Fragments and texts are read from *page* with
    ``collect_fragments`` / ``collect_texts``, and the layout-specific
    strategy is applied. A layout missing from the strategy table falls back
    to single-line assignment. For SINGLE_LINE and MIXED layouts, a final
    ``_link_shared_intermediates`` pass chains products into the next step's
    empty reactant list.

    Returns:
        - steps: list of :class:`RawStep` produced by the strategy. This is
          normally one per arrow; the serpentine strategy only creates steps
          for horizontal arrows.
        - results: list of :class:`AssignmentResult`, one per assigned element
    """
    if not arrows:
        return [], []

    pattern = classify_layout(arrows) if layout is None else layout

    fragments = collect_fragments(page)
    texts = collect_texts(page)

    strategies = {
        LayoutPattern.SINGLE_LINE: _assign_single_line,
        LayoutPattern.MULTI_LINE: _assign_multi_line,
        LayoutPattern.SERPENTINE: _assign_serpentine,
        LayoutPattern.BRANCH: _assign_branch,
        LayoutPattern.CYCLE: _assign_cycle,
        LayoutPattern.MIXED: _assign_single_line,  # fallback
    }
    assign = strategies.get(pattern, _assign_single_line)
    steps, results = assign(arrows, fragments, texts)

    # Sequential layouts get a final product -> next-reactant chaining pass.
    is_sequential = pattern in (LayoutPattern.SINGLE_LINE, LayoutPattern.MIXED)
    if is_sequential:
        _link_shared_intermediates(steps)

    return steps, results
|