retrocast 0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- retrocast/__init__.py +56 -0
- retrocast/adapters/__init__.py +169 -0
- retrocast/adapters/aizynth_adapter.py +88 -0
- retrocast/adapters/askcos_adapter.py +256 -0
- retrocast/adapters/base_adapter.py +43 -0
- retrocast/adapters/common.py +193 -0
- retrocast/adapters/dms_adapter.py +144 -0
- retrocast/adapters/dreamretro_adapter.py +112 -0
- retrocast/adapters/factory.py +37 -0
- retrocast/adapters/multistepttl_adapter.py +125 -0
- retrocast/adapters/paroutes_adapter.py +263 -0
- retrocast/adapters/retrochimera_adapter.py +186 -0
- retrocast/adapters/retrostar_adapter.py +115 -0
- retrocast/adapters/synllama_adapter.py +163 -0
- retrocast/adapters/synplanner_adapter.py +148 -0
- retrocast/adapters/syntheseus_adapter.py +89 -0
- retrocast/api.py +79 -0
- retrocast/chem.py +196 -0
- retrocast/cli/__init__.py +0 -0
- retrocast/cli/adhoc.py +338 -0
- retrocast/cli/handlers.py +462 -0
- retrocast/cli/main.py +189 -0
- retrocast/curation/__init__.py +15 -0
- retrocast/curation/filtering.py +205 -0
- retrocast/curation/generators.py +296 -0
- retrocast/curation/sampling.py +142 -0
- retrocast/exceptions.py +52 -0
- retrocast/io/__init__.py +39 -0
- retrocast/io/blob.py +43 -0
- retrocast/io/data.py +312 -0
- retrocast/io/provenance.py +233 -0
- retrocast/metrics/__init__.py +0 -0
- retrocast/metrics/bootstrap.py +200 -0
- retrocast/metrics/diversity.py +0 -0
- retrocast/metrics/ranking.py +109 -0
- retrocast/metrics/similarity.py +28 -0
- retrocast/metrics/solvability.py +24 -0
- retrocast/models/__init__.py +0 -0
- retrocast/models/benchmark.py +274 -0
- retrocast/models/chem.py +372 -0
- retrocast/models/evaluation.py +54 -0
- retrocast/models/provenance.py +69 -0
- retrocast/models/stats.py +59 -0
- retrocast/resources/__init__.py +0 -0
- retrocast/resources/default_config.yaml +54 -0
- retrocast/typing.py +11 -0
- retrocast/utils/__init__.py +0 -0
- retrocast/utils/logging.py +35 -0
- retrocast/utils/serializers.py +211 -0
- retrocast/visualization/adapters.py +289 -0
- retrocast/visualization/model_performance.py +464 -0
- retrocast/visualization/plots.py +386 -0
- retrocast/visualization/report.py +258 -0
- retrocast/visualization/routes.py +244 -0
- retrocast/visualization/theme.py +127 -0
- retrocast/workflow/__init__.py +0 -0
- retrocast/workflow/analyze.py +58 -0
- retrocast/workflow/ingest.py +122 -0
- retrocast/workflow/score.py +107 -0
- retrocast/workflow/verify.py +192 -0
- retrocast-0.dist-info/METADATA +221 -0
- retrocast-0.dist-info/RECORD +66 -0
- retrocast-0.dist-info/WHEEL +5 -0
- retrocast-0.dist-info/entry_points.txt +2 -0
- retrocast-0.dist-info/licenses/LICENSE +21 -0
- retrocast-0.dist-info/top_level.txt +1 -0
retrocast/__init__.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""
|
|
2
|
+
retrocast: A unified toolkit for retrosynthesis benchmark analysis.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
6
|
+
|
|
7
|
+
from packaging.version import Version
|
|
8
|
+
|
|
9
|
+
from retrocast.adapters import ADAPTER_MAP, adapt_routes, adapt_single_route, get_adapter
|
|
10
|
+
from retrocast.curation.filtering import deduplicate_routes
|
|
11
|
+
from retrocast.curation.sampling import sample_k_by_length, sample_random_k, sample_top_k
|
|
12
|
+
from retrocast.models.chem import Molecule, ReactionStep, Route, TargetInput
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _normalize_version_with_patch(version_str: str) -> str:
    """Return ``version_str`` with an explicit major.minor.micro release (e.g., 0.3.0.dev16 not 0.3.dev16).

    The version is parsed with ``packaging.version.Version`` and rebuilt from its
    components. Missing minor/micro segments are padded with ``0``; the epoch,
    any release segments beyond the third, and the pre/post/dev/local parts are
    all carried over so the result still identifies the same version.

    Args:
        version_str: A PEP 440 version string.

    Returns:
        The re-serialized version string.

    Raises:
        packaging.version.InvalidVersion: if ``version_str`` is not a valid
            PEP 440 version (raised by ``Version``).
    """
    v = Version(version_str)

    # major/minor/micro read as 0 when the segment is absent; keep any extra
    # release segments (e.g. the trailing 4 in 1.2.3.4) instead of truncating.
    release = (v.major, v.minor, v.micro, *v.release[3:])
    base = ".".join(str(segment) for segment in release)

    parts = []
    if v.epoch:
        # A non-zero epoch changes version ordering, so it must not be dropped.
        parts.append(f"{v.epoch}!")
    parts.append(base)

    # Add pre-release, post-release, dev, local parts if present
    if v.pre:
        parts.append(f"{v.pre[0]}{v.pre[1]}")
    if v.post is not None:
        parts.append(f".post{v.post}")
    if v.dev is not None:
        parts.append(f".dev{v.dev}")
    if v.local:
        parts.append(f"+{v.local}")

    return "".join(parts)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
try:
    # Look up the installed distribution's version and pad it to major.minor.micro.
    __version__ = _normalize_version_with_patch(version("retrocast"))
except PackageNotFoundError:
    # Package not installed (running from source without editable install)
    __version__ = "0.0.0.dev0+unknown"
# Names re-exported at package level; every entry is imported at the top of this module.
__all__ = [
    # Core schemas
    "Route",
    "Molecule",
    "ReactionStep",
    "TargetInput",
    # Adapter functions
    "adapt_single_route",
    "adapt_routes",
    "get_adapter",
    "ADAPTER_MAP",
    # Route processing utilities
    "deduplicate_routes",
    "sample_top_k",
    "sample_random_k",
    "sample_k_by_length",
]
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from retrocast.adapters.aizynth_adapter import AizynthAdapter
|
|
4
|
+
from retrocast.adapters.askcos_adapter import AskcosAdapter
|
|
5
|
+
from retrocast.adapters.base_adapter import BaseAdapter
|
|
6
|
+
from retrocast.adapters.dms_adapter import DMSAdapter
|
|
7
|
+
from retrocast.adapters.dreamretro_adapter import DreamRetroAdapter
|
|
8
|
+
from retrocast.adapters.multistepttl_adapter import TtlRetroAdapter
|
|
9
|
+
from retrocast.adapters.paroutes_adapter import PaRoutesAdapter
|
|
10
|
+
from retrocast.adapters.retrochimera_adapter import RetrochimeraAdapter
|
|
11
|
+
from retrocast.adapters.retrostar_adapter import RetroStarAdapter
|
|
12
|
+
from retrocast.adapters.synllama_adapter import SynLlaMaAdapter
|
|
13
|
+
from retrocast.adapters.synplanner_adapter import SynPlannerAdapter
|
|
14
|
+
from retrocast.adapters.syntheseus_adapter import SyntheseusAdapter
|
|
15
|
+
from retrocast.exceptions import RetroCastException
|
|
16
|
+
from retrocast.models.chem import Route, TargetIdentity
|
|
17
|
+
|
|
18
|
+
# Registry mapping an adapter name to an adapter instance.
# Each adapter is instantiated once, here, when this module is imported, so
# every lookup of the same name returns the same object.
ADAPTER_MAP: dict[str, BaseAdapter] = {
    "aizynth": AizynthAdapter(),
    "askcos": AskcosAdapter(),
    "dms": DMSAdapter(),
    "dreamretro": DreamRetroAdapter(),
    "multistepttl": TtlRetroAdapter(),
    "paroutes": PaRoutesAdapter(),
    "retrochimera": RetrochimeraAdapter(),
    "retrostar": RetroStarAdapter(),
    "synplanner": SynPlannerAdapter(),
    "syntheseus": SyntheseusAdapter(),
    "synllama": SynLlaMaAdapter(),
}

# Adapters that expect target-centric data format (dict with metadata + nested routes)
# vs route-centric format (list of route objects)
# `adapt_single_route` uses this set to decide whether to wrap its input in a list.
TARGET_CENTRIC_ADAPTERS = {"askcos", "retrochimera", "paroutes"}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def get_adapter(adapter_name: str) -> BaseAdapter:
    """
    Look up the adapter instance registered under ``adapter_name`` in `ADAPTER_MAP`.

    Args:
        adapter_name: Key of the adapter in `ADAPTER_MAP` (e.g. "dms", "aizynth").

    Returns:
        The registered adapter instance.

    Raises:
        RetroCastException: if no adapter is registered under ``adapter_name``.
    """
    if (found := ADAPTER_MAP.get(adapter_name)) is not None:
        return found

    raise RetroCastException(
        f"unknown adapter '{adapter_name}'. Check `retrocast.adapters.ADAPTER_MAP` for available adapters."
    )
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def adapt_single_route(
    raw_route: Any,
    target: TargetIdentity,
    adapter_name: str,
) -> Route | None:
    """
    Convert one raw route into the unified Route format.

    Convenience entry point for adapting routes one at a time instead of going
    through the batch pipeline. The input is shaped for the chosen adapter and
    the first route the adapter yields is returned.

    Args:
        raw_route: A single route or target data in the model's native format.
            - Route-centric adapters (DMS, AiZynth, SynPlanner, ...): pass one
              route object/dict from the model's output list. It is wrapped in
              a one-element list before being handed to the adapter; a value
              that is already a list is passed through unchanged.
            - Target-centric adapters (those named in TARGET_CENTRIC_ADAPTERS:
              ASKCOS, RetroChimera, PaRoutes): pass the complete target data
              dict (target metadata plus nested routes); it is passed through
              unchanged.
        target: Target molecule information (id and canonical SMILES).
        adapter_name: Name of the adapter to use (e.g., "dms", "aizynth", "retrostar").
            See ADAPTER_MAP.keys() for available adapters.

    Returns:
        The first Route yielded by the adapter, or None if it yields nothing.

    Raises:
        RetroCastException: if ``adapter_name`` is not a registered adapter.

    Examples:
        Route-centric adapter (DMS):
        >>> from retrocast.adapters import adapt_single_route
        >>> from retrocast.models.chem import TargetIdentity
        >>>
        >>> target = TargetIdentity(id="aspirin", smiles="CC(=O)Oc1ccccc1C(=O)O")
        >>> raw_dms_route = {"smiles": "CC(=O)Oc1ccccc1C(=O)O", "children": [...]}
        >>>
        >>> route = adapt_single_route(raw_dms_route, target, "dms")
        >>> if route:
        ...     print(f"Route depth: {route.length}")
        ...     print(f"Starting materials: {len(route.leaves)}")

        Target-centric adapter (RetroChimera):
        >>> target = TargetIdentity(id="mol1", smiles="CCO")
        >>> retrochimera_data = {
        ...     "smiles": "CCO",
        ...     "result": {"outputs": [{"routes": [...]}]}
        ... }
        >>> route = adapt_single_route(retrochimera_data, target, "retrochimera")
    """
    adapter = get_adapter(adapter_name)

    # Only a bare route handed to a route-centric adapter needs wrapping;
    # target-centric payloads and ready-made lists go through untouched.
    if adapter_name in TARGET_CENTRIC_ADAPTERS or isinstance(raw_route, list):
        payload = raw_route
    else:
        payload = [raw_route]

    # Pull at most one route out of the adapter; None when it produces none.
    return next(iter(adapter.cast(payload, target)), None)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def adapt_routes(
    raw_routes: Any,
    target: TargetIdentity,
    adapter_name: str,
    max_routes: int | None = None,
) -> list[Route]:
    """
    Adapt multiple raw routes to the unified Route format.

    Args:
        raw_routes: Routes in the model's native format (typically a list or dict).
        target: Target molecule information (id and canonical SMILES).
        adapter_name: Name of the adapter to use.
        max_routes: Maximum number of routes to return (None for all successful routes).
            A value of zero or less yields an empty list.

    Returns:
        List of successfully adapted Route objects, in the order the adapter
        yields them, holding at most ``max_routes`` entries.

    Raises:
        RetroCastException: if ``adapter_name`` is not a registered adapter.

    Example:
        >>> from retrocast.adapters import adapt_routes
        >>> from retrocast.models.chem import TargetIdentity
        >>>
        >>> target = TargetIdentity(id="ibuprofen", smiles="CC(C)Cc1ccc(cc1)C(C)C(=O)O")
        >>> raw_routes = [route1, route2, route3, ...]  # Your model's output
        >>>
        >>> routes = adapt_routes(raw_routes, target, "aizynth", max_routes=10)
        >>> print(f"Adapted {len(routes)} routes successfully")
    """
    # Resolve the adapter first so an unknown name is reported even when the
    # limit below short-circuits.
    adapter = get_adapter(adapter_name)

    # Compare against None explicitly: a truthiness test would treat a limit
    # of 0 as "no limit" and return every route.
    if max_routes is not None and max_routes <= 0:
        return []

    routes: list[Route] = []
    for route in adapter.cast(raw_routes, target):
        routes.append(route)
        # Stop pulling from the generator as soon as the cap is reached so no
        # further routes are transformed needlessly.
        if max_routes is not None and len(routes) >= max_routes:
            break

    return routes
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
# Public names of the adapters package: the helper functions, the two
# registries, the abstract base class, and every concrete adapter class.
__all__ = [
    "adapt_single_route",
    "adapt_routes",
    "get_adapter",
    "ADAPTER_MAP",
    "TARGET_CENTRIC_ADAPTERS",
    "BaseAdapter",
    "AizynthAdapter",
    "AskcosAdapter",
    "DMSAdapter",
    "DreamRetroAdapter",
    "TtlRetroAdapter",
    "PaRoutesAdapter",
    "RetrochimeraAdapter",
    "RetroStarAdapter",
    "SynPlannerAdapter",
    "SyntheseusAdapter",
    "SynLlaMaAdapter",
]
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from collections.abc import Generator
|
|
5
|
+
from typing import Annotated, Any, Literal
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, Field, RootModel, ValidationError
|
|
8
|
+
|
|
9
|
+
from retrocast.adapters.base_adapter import BaseAdapter
|
|
10
|
+
from retrocast.adapters.common import build_molecule_from_bipartite_node
|
|
11
|
+
from retrocast.exceptions import AdapterLogicError, RetroCastException
|
|
12
|
+
from retrocast.models.chem import Route, TargetIdentity
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
# --- pydantic models for input validation ---
|
|
17
|
+
# these models validate the raw aizynthfinder output format before any transformation.
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class AizynthBaseNode(BaseModel):
    """a base model for shared fields between node types."""

    # smiles string carried by the node (required).
    smiles: str
    # nested child nodes; an absent key validates as an empty list.
    # `AizynthNode` is a forward reference to the discriminated union defined
    # later in this module.
    children: list[AizynthNode] = Field(default_factory=list)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class AizynthMoleculeInput(AizynthBaseNode):
    """represents a 'mol' node in the raw aizynth tree."""

    # literal tag; only the value "mol" validates as this node type.
    type: Literal["mol"]
    # defaults to False when the key is missing from the raw node.
    in_stock: bool = False
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class AizynthReactionInput(AizynthBaseNode):
    """represents a 'reaction' node in the raw aizynth tree."""

    # literal tag; only the value "reaction" validates as this node type.
    type: Literal["reaction"]
    # free-form reaction metadata; defaults to an empty dict when absent.
    metadata: dict[str, Any] = Field(default_factory=dict)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# a discriminated union to handle the bipartite graph structure.
|
|
42
|
+
AizynthNode = Annotated[AizynthMoleculeInput | AizynthReactionInput, Field(discriminator="type")]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class AizynthRouteList(RootModel[list[AizynthMoleculeInput]]):
    """the top-level object for a single target is a list of potential routes."""

    # root model over a list of 'mol' nodes: each list entry is the root
    # molecule node of one route tree; the validated list is exposed as `.root`.
    pass
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class AizynthAdapter(BaseAdapter):
    """adapter that converts aizynthfinder-style route trees into `Route` objects."""

    def cast(self, raw_target_data: Any, target: TargetIdentity) -> Generator[Route, None, None]:
        """
        validates the raw aizynthfinder payload for one target and yields a route per tree.

        a payload that fails schema validation is logged as a warning and yields
        nothing. a tree whose transformation raises RetroCastException is logged
        as a warning and skipped; the remaining trees are still processed.
        ranks are 1-based positions in the validated list.
        """
        try:
            parsed = AizynthRouteList.model_validate(raw_target_data)
        except ValidationError as e:
            logger.warning(f" - raw data for target '{target.id}' failed aizynth schema validation. error: {e}")
            return

        for position, tree_root in enumerate(parsed.root):
            try:
                yield self._transform(tree_root, target, position + 1)
            except RetroCastException as e:
                logger.warning(f" - route for '{target.id}' failed transformation: {e}")

    def _transform(self, aizynth_root: AizynthMoleculeInput, target: TargetIdentity, rank: int) -> Route:
        """
        builds a route from a single validated aizynthfinder tree.

        the molecule tree is produced by the shared bipartite-node builder; its
        root smiles must equal `target.smiles`, otherwise the mismatch is logged
        at error level and AdapterLogicError is raised.
        """
        root_molecule = build_molecule_from_bipartite_node(raw_mol_node=aizynth_root)

        # happy path first: the rebuilt root matches the requested target.
        if root_molecule.smiles == target.smiles:
            return Route(target=root_molecule, rank=rank, metadata={})

        msg = (
            f"mismatched smiles for target {target.id}. "
            f"expected canonical: {target.smiles}, but adapter produced: {root_molecule.smiles}"
        )
        logger.error(msg)
        raise AdapterLogicError(msg)
|
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from collections.abc import Generator
|
|
6
|
+
from typing import Annotated, Any, Literal
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel, Field, ValidationError
|
|
9
|
+
|
|
10
|
+
from retrocast.adapters.base_adapter import BaseAdapter
|
|
11
|
+
from retrocast.chem import canonicalize_smiles, get_inchi_key
|
|
12
|
+
from retrocast.exceptions import AdapterLogicError, RetroCastException
|
|
13
|
+
from retrocast.models.chem import Molecule, ReactionStep, Route, TargetIdentity
|
|
14
|
+
from retrocast.typing import ReactionSmilesStr, SmilesStr
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
# --- pydantic models for input validation ---
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# fields shared by both askcos node types (chemical and reaction).
class AskcosBaseNode(BaseModel):
    # smiles string stored on the node.
    smiles: str
    # node identifier as given in the raw data.
    id: str
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# a node tagged "chemical" in the askcos node dictionary.
class AskcosChemicalNode(AskcosBaseNode):
    # literal tag; only the value "chemical" validates as this node type.
    type: Literal["chemical"]
    # required flag; the adapter treats a terminal chemical as a leaf and does
    # not expand it further.
    terminal: bool
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class AskcosTemplateSource(BaseModel):
    """Nested structure for template information."""

    # optional reaction SMARTS; None when absent.
    # NOTE(review): this model is not referenced anywhere in the visible part
    # of the module - confirm whether it is still needed.
    reaction_smarts: str | None = None
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class AskcosModelMetadata(BaseModel):
    """Model metadata containing template information."""

    # raw "source" mapping; defaults to an empty dict when the key is absent.
    source: dict[str, Any] = Field(default_factory=dict)

    def get_template(self) -> str | None:
        """Return ``source["template"]["reaction_smarts"]`` when available.

        Yields None when ``source`` has no "template" entry, when that entry is
        not a dict, or when it lacks a "reaction_smarts" key.
        """
        template_info = self.source.get("template", {})
        if not isinstance(template_info, dict):
            return None
        return template_info.get("reaction_smarts")
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# optional properties attached to a reaction node.
class AskcosReactionProperties(BaseModel):
    # mapped reaction smiles; None when absent.
    mapped_smiles: str | None = None
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# a node tagged "reaction" in the askcos node dictionary.
class AskcosReactionNode(AskcosBaseNode):
    # literal tag; only the value "reaction" validates as this node type.
    type: Literal["reaction"]
    # optional; the adapter reads `mapped_smiles` from it when present.
    reaction_properties: AskcosReactionProperties | None = None
    # defaults to an empty list; the adapter only consults the first entry for a template.
    model_metadata: list[AskcosModelMetadata] = Field(default_factory=list)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
AskcosNode = Annotated[AskcosChemicalNode | AskcosReactionNode, Field(discriminator="type")]
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
# one directed edge of a pathway; the adapter builds its adjacency list as
# source uuid -> list of target uuids.
class AskcosPathwayEdge(BaseModel):
    source: str
    target: str
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# the "uds" section of the askcos results.
class AskcosUDS(BaseModel):
    # node lookup; the adapter indexes it with the smiles string obtained from `uuid2smiles`.
    node_dict: dict[str, AskcosNode]
    # maps a node uuid to the smiles string used as the key into `node_dict`.
    uuid2smiles: dict[str, str]
    # each pathway is a list of edges; the adapter emits one route per pathway.
    pathways: list[list[AskcosPathwayEdge]]
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# the "results" section of the askcos output; only `uds` is validated here
# (the adapter reads the optional `stats` entry straight from the raw data).
class AskcosResults(BaseModel):
    uds: AskcosUDS
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
# top-level schema the adapter validates raw askcos data against.
class AskcosOutput(BaseModel):
    results: AskcosResults
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class AskcosAdapter(BaseAdapter):
    """adapter for converting askcos outputs to `Route` objects."""

    def __init__(self, use_full_graph: bool = False):
        """
        initializes the adapter.

        args:
            use_full_graph: if true, attempts to extract all possible routes
                from the full search graph instead of using the pre-computed
                pathways. defaults to false.
        """
        self.use_full_graph = use_full_graph

    def cast(self, raw_target_data: Any, target: TargetIdentity) -> Generator[Route, None, None]:
        """
        validates raw askcos data, transforms its pathways, and yields route objects.

        data that fails schema validation is logged as a warning and yields
        nothing. a pathway whose transformation raises RetroCastException is
        logged as a warning and skipped. routes are ranked by their 1-based
        position in `results.uds.pathways`.

        raises:
            NotImplementedError: when the adapter was created with
                `use_full_graph=True`. because this method is a generator, the
                error surfaces on first iteration, not at call time.
        """
        if self.use_full_graph:
            raise NotImplementedError("extracting routes from the full askcos search graph is not yet implemented.")

        try:
            validated_output = AskcosOutput.model_validate(raw_target_data)
        except ValidationError as e:
            logger.warning(f" - raw data for target '{target.id}' failed askcos schema validation. error: {e}")
            return

        uds = validated_output.results.uds
        metadata = self._extract_stats_metadata(raw_target_data)

        for i, pathway_edges in enumerate(uds.pathways):
            try:
                route = self._transform_pathway(
                    pathway_edges=pathway_edges,
                    uuid2smiles=uds.uuid2smiles,
                    node_dict=uds.node_dict,
                    target_input=target,
                    rank=i + 1,
                    # hand each route its own copy so no two routes can end up
                    # sharing one mutable dict.
                    metadata=dict(metadata),
                )
                yield route
            except RetroCastException as e:
                logger.warning(f" - pathway {i} for target '{target.id}' failed transformation: {e}")
                continue

    @staticmethod
    def _extract_stats_metadata(raw_target_data: Any) -> dict[str, Any]:
        """collects the search statistics from `results.stats`, tolerating a missing or null section."""
        # `stats` is not part of the validated schema, so it may be absent,
        # null, or of an unexpected type; the raw data may also not be a dict
        # at all (e.g. an already-validated model instance). fall back to an
        # empty mapping instead of raising AttributeError.
        results = raw_target_data.get("results") if isinstance(raw_target_data, dict) else None
        stats = results.get("stats") if isinstance(results, dict) else None
        if not isinstance(stats, dict):
            stats = {}

        return {
            "total_iterations": stats.get("total_iterations"),
            "total_chemicals": stats.get("total_chemicals"),
            "total_reactions": stats.get("total_reactions"),
            "total_templates": stats.get("total_templates"),
            "total_paths": stats.get("total_paths"),
        }

    def _transform_pathway(
        self,
        pathway_edges: list[AskcosPathwayEdge],
        uuid2smiles: dict[str, str],
        node_dict: dict[str, AskcosNode],
        target_input: TargetIdentity,
        rank: int,
        metadata: dict[str, Any],
    ) -> Route:
        """
        transforms a single askcos pathway (represented by its edges) into a route.

        raises:
            AdapterLogicError: if the fixed root uuid is missing from
                `uuid2smiles`, or if the rebuilt root molecule's smiles differs
                from `target_input.smiles`.
        """
        # adjacency list: source uuid -> target uuids, in edge order.
        adj_list = defaultdict(list)
        for edge in pathway_edges:
            adj_list[edge.source].append(edge.target)

        # the adapter expects the pathway root under this all-zero uuid.
        root_uuid = "00000000-0000-0000-0000-000000000000"
        if root_uuid not in uuid2smiles:
            raise AdapterLogicError("root uuid not found in pathway data.")

        target_molecule = self._build_molecule(
            chem_uuid=root_uuid,
            path_prefix="retrocast-mol-root",
            adj_list=adj_list,
            uuid2smiles=uuid2smiles,
            node_dict=node_dict,
        )

        if target_molecule.smiles != target_input.smiles:
            msg = (
                f"mismatched smiles for target {target_input.id}. "
                f"expected canonical: {target_input.smiles}, but adapter produced: {target_molecule.smiles}"
            )
            raise AdapterLogicError(msg)

        return Route(target=target_molecule, rank=rank, metadata=metadata)

    def _build_molecule(
        self,
        chem_uuid: str,
        path_prefix: str,
        adj_list: dict[str, list[str]],
        uuid2smiles: dict[str, str],
        node_dict: dict[str, AskcosNode],
    ) -> Molecule:
        """
        recursively builds a canonical molecule from a chemical uuid.

        `path_prefix` is only forwarded to `_build_reaction_step`; it does not
        influence the molecule that is built.

        raises:
            AdapterLogicError: if the uuid has no (non-empty) entry in
                `uuid2smiles`, or the resolved node is missing or is not a
                chemical node.
        """
        raw_smiles = uuid2smiles.get(chem_uuid)
        if not raw_smiles:
            raise AdapterLogicError(f"uuid '{chem_uuid}' not found in uuid2smiles map.")

        node_data = node_dict.get(raw_smiles)
        if not node_data or not isinstance(node_data, AskcosChemicalNode):
            raise AdapterLogicError(f"node data for smiles '{raw_smiles}' not found or not a chemical node.")

        canon_smiles = canonicalize_smiles(node_data.smiles)
        is_leaf = node_data.terminal
        synthesis_step = None

        # only expand non-terminal chemicals that have an outgoing edge in this pathway.
        if not is_leaf and chem_uuid in adj_list:
            child_reaction_uuids = adj_list[chem_uuid]
            if len(child_reaction_uuids) > 1:
                logger.warning(f"molecule {canon_smiles} has multiple child reactions in pathway; using first one.")

            rxn_uuid = child_reaction_uuids[0]
            synthesis_step = self._build_reaction_step(
                rxn_uuid=rxn_uuid,
                product_smiles=canon_smiles,
                path_prefix=path_prefix,
                adj_list=adj_list,
                uuid2smiles=uuid2smiles,
                node_dict=node_dict,
            )

        return Molecule(
            smiles=canon_smiles,
            inchikey=get_inchi_key(canon_smiles),
            synthesis_step=synthesis_step,
        )

    def _build_reaction_step(
        self,
        rxn_uuid: str,
        product_smiles: SmilesStr,
        path_prefix: str,
        adj_list: dict[str, list[str]],
        uuid2smiles: dict[str, str],
        node_dict: dict[str, AskcosNode],
    ) -> ReactionStep:
        """
        builds a canonical reaction step from a reaction uuid.

        every uuid the reaction points to in `adj_list` is built as a reactant
        molecule. `product_smiles` is accepted for signature stability but is
        not used in the step.

        raises:
            AdapterLogicError: if the uuid has no (non-empty) entry in
                `uuid2smiles`, or the resolved node is missing or is not a
                reaction node.
        """
        raw_smiles = uuid2smiles.get(rxn_uuid)
        if not raw_smiles:
            raise AdapterLogicError(f"uuid '{rxn_uuid}' not found in uuid2smiles map.")

        node_data = node_dict.get(raw_smiles)
        if not node_data or not isinstance(node_data, AskcosReactionNode):
            raise AdapterLogicError(f"node data for reaction '{raw_smiles}' not found or not a reaction node.")

        reactants: list[Molecule] = [
            self._build_molecule(
                chem_uuid=reactant_uuid,
                path_prefix=f"{path_prefix}-{i}",
                adj_list=adj_list,
                uuid2smiles=uuid2smiles,
                node_dict=node_dict,
            )
            for i, reactant_uuid in enumerate(adj_list.get(rxn_uuid, []))
        ]

        # Extract mapped_smiles from reaction_properties if available
        mapped_smiles = None
        if node_data.reaction_properties and node_data.reaction_properties.mapped_smiles:
            mapped_smiles = ReactionSmilesStr(node_data.reaction_properties.mapped_smiles)

        # Extract template from the first model_metadata entry if available
        template = None
        if node_data.model_metadata:
            template = node_data.model_metadata[0].get_template()

        return ReactionStep(
            reactants=reactants,
            mapped_smiles=mapped_smiles,
            template=template,
        )
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from collections.abc import Generator
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from retrocast.models.chem import Route, TargetIdentity
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BaseAdapter(ABC):
    """
    Abstract base class for all model output adapters.

    An adapter's role is to transform a model's raw output format into the
    canonical `Route` schema.

    Subclasses must override `cast`; because it is declared with
    `@abstractmethod`, a subclass that does not override it cannot be
    instantiated.
    """

    @abstractmethod
    def cast(self, raw_target_data: Any, target: TargetIdentity) -> Generator[Route, None, None]:
        """
        Validates, transforms, and yields Routes from raw model data.

        This is the primary method for an adapter. It encapsulates all model-specific
        logic. It should be a generator that yields successful routes and handles its
        own exceptions internally by logging and continuing.

        Args:
            raw_target_data: The raw data blob from a file for a single target.
                This blob can follow one of two common patterns:

                1.  **Route-Centric**: The data is a list of route objects, where the
                    root of each route object contains the target SMILES (e.g.,
                    AiZynthFinder, DMS). `raw_target_data` is typically a `list`.

                2.  **Target-Centric**: The data is a single JSON object that contains
                    metadata (like a top-level `smiles` key) and a nested list of
                    routes (e.g., RetroChimera). `raw_target_data` is typically a `dict`.

                The adapter is responsible for handling the specific structure of its model.
            target: The identity of the target molecule (id and canonical SMILES).

        Yields:
            Successfully transformed Route objects.

        Raises:
            NotImplementedError: always, in this base implementation. The body
                contains no `yield`, so the error is raised as soon as the base
                method is called rather than on first iteration.
        """
        raise NotImplementedError
|