synkit 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. synkit/Chem/Fingerprint/__init__.py +0 -0
  2. synkit/Chem/Fingerprint/fp_calculator.py +122 -0
  3. synkit/Chem/Fingerprint/smiles_featurizer.py +185 -0
  4. synkit/Chem/Fingerprint/transformation_fp.py +79 -0
  5. synkit/Chem/Molecule/__init__.py +0 -0
  6. synkit/Chem/Molecule/standardize.py +137 -0
  7. synkit/Chem/Reaction/__init__.py +0 -0
  8. synkit/Chem/Reaction/balance_check.py +162 -0
  9. synkit/Chem/Reaction/cleanning.py +59 -0
  10. synkit/Chem/Reaction/deionize.py +289 -0
  11. synkit/Chem/Reaction/neutralize.py +256 -0
  12. synkit/Chem/Reaction/reagent.py +102 -0
  13. synkit/Chem/Reaction/standardize.py +157 -0
  14. synkit/Chem/Reaction/tautomerize.py +168 -0
  15. synkit/Graph/Cluster/__init__.py +0 -0
  16. synkit/Graph/Cluster/morphism.py +83 -0
  17. synkit/Graph/Feature/__init__.py +0 -0
  18. synkit/Graph/Feature/graph_descriptors.py +325 -0
  19. synkit/Graph/Feature/graph_fps.py +97 -0
  20. synkit/Graph/Feature/graph_signature.py +236 -0
  21. synkit/Graph/Feature/hash_fps.py +130 -0
  22. synkit/Graph/Feature/morgan_fps.py +87 -0
  23. synkit/Graph/Feature/path_fps.py +82 -0
  24. synkit/Graph/__init.py +0 -0
  25. synkit/IO/__init__.py +0 -0
  26. synkit/IO/chem_converter.py +231 -0
  27. synkit/IO/data_io.py +277 -0
  28. synkit/IO/data_process.py +49 -0
  29. synkit/IO/debug.py +78 -0
  30. synkit/IO/dg_to_gml.py +124 -0
  31. synkit/IO/gml_to_nx.py +119 -0
  32. synkit/IO/graph_to_mol.py +110 -0
  33. synkit/IO/mol_to_graph.py +282 -0
  34. synkit/IO/nx_to_gml.py +200 -0
  35. synkit/IO/parse_rule.py +172 -0
  36. synkit/IO/smiles_to_id.py +119 -0
  37. synkit/ITS/_misc.py +280 -0
  38. synkit/ITS/aam_validator.py +254 -0
  39. synkit/ITS/its_builder.py +94 -0
  40. synkit/ITS/its_construction.py +213 -0
  41. synkit/ITS/normalize_aam.py +183 -0
  42. synkit/ITS/partial_expand.py +170 -0
  43. synkit/Reactor/__init__.py +0 -0
  44. synkit/Reactor/core_engine.py +164 -0
  45. synkit/Reactor/inference.py +73 -0
  46. synkit/Reactor/multi_step.py +227 -0
  47. synkit/Reactor/multi_step_aam.py +82 -0
  48. synkit/Reactor/reagent.py +95 -0
  49. synkit/Reactor/rule_apply.py +81 -0
  50. synkit/Vis/__init__.py +0 -0
  51. synkit/Vis/chemical_graph_visualizer.py +378 -0
  52. synkit/Vis/chemical_reaction_visualizer.py +133 -0
  53. synkit/Vis/chemical_space.py +83 -0
  54. synkit/Vis/embedding.py +92 -0
  55. synkit/Vis/graph_visualizer.py +286 -0
  56. synkit/Vis/pdf_writer.py +143 -0
  57. synkit/Vis/rsmi_to_fig.py +169 -0
  58. synkit/__init__.py +0 -0
  59. synkit/_misc.py +181 -0
  60. synkit-0.0.1.dist-info/METADATA +148 -0
  61. synkit-0.0.1.dist-info/RECORD +63 -0
  62. synkit-0.0.1.dist-info/WHEEL +4 -0
  63. synkit-0.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,183 @@
1
+ import re
2
+ import networkx as nx
3
+ from rdkit import Chem
4
+ from typing import List
5
+
6
+ from synkit.IO.chem_converter import rsmi_to_graph
7
+ from synkit.IO.graph_to_mol import GraphToMol
8
+ from synkit.ITS.its_construction import ITSConstruction
9
+ from synkit.ITS._misc import its_decompose, get_rc
10
+
11
+
12
class NormalizeAAM:
    """
    Normalize atom-atom mappings (AAM) of reaction SMILES and rebuild a
    reaction from its Imaginary Transition State (ITS) graph.

    Provides helpers to shift atom-map numbers, rewrite SMILES with explicit
    hydrogens, extract and renumber subgraphs, and (via :meth:`fit`) regenerate
    a reaction SMILES that keeps the reaction centre plus every non-hydrogen
    atom outside it.
    """

    # Atom-map numbers are the digits between ':' and the closing ']' of a
    # bracket atom (e.g. the "12" in "[CH3:12]"). Requiring the trailing ']'
    # prevents ring-closure digits written after an explicit ':' aromatic bond
    # (e.g. the final "1" in "[cH:6]:1") from being rewritten as well.
    _AAM_PATTERN = re.compile(r"(?<=:)\d+(?=\])")

    def __init__(self) -> None:
        """
        Initializes the NormalizeAAM class. The instance holds no state.
        """

    @staticmethod
    def increment(match: re.Match) -> str:
        """
        Return the number matched by ``match`` incremented by one.

        Parameters:
        match (re.Match): A regex match whose full text is a decimal integer.

        Returns:
        str: The incremented number as a string.
        """
        return str(int(match.group()) + 1)

    @staticmethod
    def fix_atom_mapping(smiles: str) -> str:
        """
        Increment every atom-map number in a SMILES string by 1.

        Only map numbers inside bracket atoms (``[...:n]``) are changed;
        ring-closure digits are left untouched.

        Parameters:
        smiles (str): A SMILES string with atom-map numbers.

        Returns:
        str: The SMILES string with every atom-map number increased by one.
        """
        return NormalizeAAM._AAM_PATTERN.sub(NormalizeAAM.increment, smiles)

    @staticmethod
    def fix_aam_rsmi(rsmi: str) -> str:
        """
        Increment the atom-map numbers on both sides of a reaction SMILES.

        Parameters:
        rsmi (str): A reaction SMILES of the form "reactants>>products".

        Returns:
        str: The reaction SMILES with every atom-map number increased by one.

        Raises:
        ValueError: If ``rsmi`` does not contain exactly one ">>".
        """
        r, p = rsmi.split(">>")
        return f"{NormalizeAAM.fix_atom_mapping(r)}>>{NormalizeAAM.fix_atom_mapping(p)}"

    @staticmethod
    def fix_rsmi_kekulize(rsmi: str) -> str:
        """
        Apply :meth:`fix_kekulize` to the reactant and product sides of a
        reaction SMILES.

        Parameters:
        - rsmi (str): A reaction SMILES of the form "reactants>>products".

        Returns:
        - str: The rewritten reaction SMILES. Fragments RDKit cannot parse are
          dropped.

        Raises:
        - ValueError: If ``rsmi`` does not contain exactly one ">>".
        """
        reactants, products = rsmi.split(">>")
        filtered_reactants = NormalizeAAM.fix_kekulize(reactants)
        filtered_products = NormalizeAAM.fix_kekulize(products)
        return f"{filtered_reactants}>>{filtered_products}"

    @staticmethod
    def fix_kekulize(smiles: str) -> str:
        """
        Rewrite each '.'-separated fragment of ``smiles`` with RDKit.

        Each fragment is parsed without sanitization. It is then written back
        with ``canonical=True``, ``kekuleSmiles=True`` and ``allHsExplicit=True``.
        Fragments RDKit cannot parse are silently dropped.

        Parameters:
        - smiles (str): SMILES fragments separated by '.' (e.g. "CCO.CC=O").

        Returns:
        - str: The rewritten fragments joined by '.'.
        """
        valid_smiles = []
        for smile in smiles.split("."):
            mol = Chem.MolFromSmiles(smile, sanitize=False)
            if mol:  # MolFromSmiles returns None for unparsable input
                valid_smiles.append(
                    Chem.MolToSmiles(
                        mol, canonical=True, kekuleSmiles=True, allHsExplicit=True
                    )
                )
        return ".".join(valid_smiles)

    @staticmethod
    def extract_subgraph(graph: "nx.Graph", indices: List[int]) -> "nx.Graph":
        """
        Return an independent copy of the subgraph induced by ``indices``.

        Parameters:
        graph (nx.Graph): The graph to extract from.
        indices (List[int]): Node ids that define the subgraph.

        Returns:
        nx.Graph: A copy (not a view) of the induced subgraph.
        """
        return graph.subgraph(indices).copy()

    def reset_indices_and_atom_map(
        self, subgraph: "nx.Graph", aam_key: str = "atom_map"
    ) -> "nx.Graph":
        """
        Renumber the nodes of ``subgraph`` as 1..n and set each node's
        ``aam_key`` attribute to its new id.

        Parameters:
        subgraph (nx.Graph): A graph with possibly non-contiguous node ids.
        aam_key (str): Node attribute holding the atom-map number. Defaults to
            'atom_map'.

        Returns:
        nx.Graph: A new graph with contiguous ids; node and edge attributes are
        copied, and node order follows ``subgraph.nodes()``.
        """
        new_graph = nx.Graph()
        node_id_mapping = {
            old_id: new_id for new_id, old_id in enumerate(subgraph.nodes(), 1)
        }
        for old_id, new_id in node_id_mapping.items():
            node_data = subgraph.nodes[old_id].copy()
            node_data[aam_key] = new_id
            new_graph.add_node(new_id, **node_data)
        for u, v, data in subgraph.edges(data=True):
            new_graph.add_edge(node_id_mapping[u], node_id_mapping[v], **data)
        return new_graph

    def fit(self, rsmi: str, fix_aam_indice: bool = True) -> str:
        """
        Normalize a mapped reaction SMILES through its ITS graph.

        Steps:
        1. Rewrite both sides with :meth:`fix_kekulize`.
        2. Optionally increment every atom-map number by one.
        3. Build the ITS graph and its reaction centre (RC).
        4. Keep the RC plus every non-hydrogen atom outside it.
        5. Renumber the kept atoms 1..n.
        6. Decompose the result back into a reaction SMILES.

        Parameters:
        rsmi (str): The mapped reaction SMILES ("reactants>>products").
        fix_aam_indice (bool): If True, increment atom-map numbers by one
            before building graphs (presumably for inputs whose mapping starts
            at 0 — confirm with callers). Defaults to True.

        Returns:
        str: The reaction SMILES rebuilt from the reduced, renumbered ITS.
        """
        rsmi = self.fix_rsmi_kekulize(rsmi)
        if fix_aam_indice:
            rsmi = self.fix_aam_rsmi(rsmi)

        r_graph, p_graph = rsmi_to_graph(rsmi, light_weight=True, sanitize=False)
        its = ITSConstruction().ITSGraph(r_graph, p_graph)
        rc = get_rc(its)
        rc_nodes = set(rc.nodes())

        # Hydrogens outside the reaction centre are dropped; every atom of the
        # reaction centre (hydrogens included) is kept.
        keep_indice = [
            idx
            for idx, data in its.nodes(data=True)
            if idx not in rc_nodes and data["element"] != "H"
        ]
        keep_indice.extend(rc.nodes())

        subgraph = self.extract_subgraph(its, keep_indice)
        subgraph = self.reset_indices_and_atom_map(subgraph)
        r_graph, p_graph = its_decompose(subgraph)

        r_mol = GraphToMol().graph_to_mol(r_graph, sanitize=False)
        p_mol = GraphToMol().graph_to_mol(p_graph, sanitize=False)
        return f"{Chem.MolToSmiles(r_mol)}>>{Chem.MolToSmiles(p_mol)}"
@@ -0,0 +1,170 @@
1
+ import networkx as nx
2
+ from synutility.SynIO.Format.nx_to_gml import NXToGML
3
+ from synutility.SynIO.Format.chemical_conversion import (
4
+ rsmi_to_graph,
5
+ graph_to_rsmi,
6
+ smiles_to_graph,
7
+ )
8
+
9
+ from synutility.SynAAM.misc import its_decompose, get_rc
10
+ from synutility.SynAAM.its_construction import ITSConstruction
11
+ from synutility.SynAAM.its_builder import ITSBuilder
12
+ from synutility.SynChem.Reaction.standardize import Standardize
13
+ from synutility.SynAAM.inference import aam_infer
14
+
15
# One shared Standardize instance, reused by every PartialExpand call in this module.
std: Standardize = Standardize()
16
+
17
+
18
class PartialExpand:
    """
    Expand a partially atom-mapped reaction SMILES (RSMI) into a fully mapped one.

    Two strategies are provided:

    - :meth:`expand_aam_with_transform` builds the Imaginary Transition State
      (ITS) graph of the input and extracts its reaction centre (RC). It turns
      the RC into a GML rule and re-applies that rule to the standardized
      reaction via ``aam_infer``.
    - :meth:`expand_aam_with_its` builds the ITS of the input and rebuilds it
      on top of a graph of one full side of the reaction (via ``ITSBuilder``).
      It then decomposes the result back into a reaction SMILES.
    """

    def __init__(self) -> None:
        """
        Initializes the PartialExpand class. No instance attributes are set.
        """

    @staticmethod
    def graph_expand(partial_its: nx.Graph, rsmi: str) -> str:
        """
        Apply the reaction-centre rule of ``partial_its`` to ``rsmi``.

        The reaction centre (RC) is extracted from the ITS graph and decomposed
        into reactant/product graphs. It is converted to a GML rule with
        ``NXToGML``, and the rule is applied to ``rsmi`` with ``aam_infer``.

        Parameters:
        - partial_its (nx.Graph): The Imaginary Transition State (ITS) graph.
        - rsmi (str): The reaction SMILES the rule is applied to.

        Returns:
        - str: The first reaction SMILES returned by ``aam_infer``.

        Raises:
        - IndexError: If ``aam_infer`` returns no candidates.
        """
        rc = get_rc(partial_its)
        r_graph, p_graph = its_decompose(rc)
        rule = NXToGML().transform((r_graph, p_graph, rc))
        return aam_infer(rsmi, rule)[0]

    @staticmethod
    def expand_aam_with_transform(rsmi: str) -> str:
        """
        Expand a partially mapped RSMI by re-applying its reaction-centre rule.

        The ITS graph is built from the (mapped) input. The input is
        standardized with the module-level ``Standardize`` instance, and
        :meth:`graph_expand` applies the RC rule to the standardized reaction.

        Parameters:
        - rsmi (str): The partially mapped reaction SMILES.

        Returns:
        - str: The expanded reaction SMILES, or ``None`` if any step raises.
          In that case the error message is printed; nothing is re-raised.
        """
        try:
            r_graph, p_graph = rsmi_to_graph(rsmi, light_weight=True, sanitize=False)
            its = ITSConstruction().ITSGraph(r_graph, p_graph)
            # The RC comes from the original mapped input, while the rule is
            # applied to the standardized form of the same reaction.
            standardized_rsmi = std.fit(rsmi)
            return PartialExpand.graph_expand(its, standardized_rsmi)
        except Exception as e:
            # Deliberate best-effort: report and signal failure with None.
            print(f"An error occurred during RSMI expansion: {e}")
            return None

    @staticmethod
    def expand_aam_with_its(rsmi: str, use_G: bool = True, light_weight=True) -> str:
        """
        Expand a partially mapped RSMI by rebuilding its ITS on a full side of
        the reaction.

        Parameters:
        - rsmi (str): A reaction SMILES ("reactants>>products") with a partial
          atom mapping.
        - use_G (bool, optional): If True, rebuild on the reactant side;
          otherwise on the product side. Defaults to True.
        - light_weight (bool, optional): Forwarded to ``smiles_to_graph`` when
          building the graph of the chosen side. Defaults to True.

        Returns:
        - str: The reaction SMILES produced by ``graph_to_rsmi`` from the
          decomposed, rebuilt ITS graph.

        Note:
        - ``rsmi`` must contain ">>"; with ``use_G=False`` a missing ">>" raises
          IndexError.
        """
        sides = rsmi.split(">>")
        smi = sides[0] if use_G else sides[1]

        # ITS of the partially mapped input (the full ITS, not only its RC).
        r, p = rsmi_to_graph(rsmi)
        partial_its = ITSConstruction().ITSGraph(r, p)

        G = smiles_to_graph(
            smi,
            light_weight=light_weight,
            sanitize=True,
            drop_non_aam=False,
            use_index_as_atom_map=False,
        )

        its = ITSBuilder().ITSGraph(G, partial_its)
        r, p = its_decompose(its)
        # Positional flags are passed through unchanged; see graph_to_rsmi for
        # their meaning.
        return graph_to_rsmi(r, p, its, True, False, True)
163
+
164
+
165
if __name__ == "__main__":
    # Demo inputs: hydrogenation of propene, and substitution of a chloride
    # by ammonia. Both are partially atom-mapped.
    examples = [
        "[CH3][CH:1]=[CH2:2].[H:3][H:4]>>[CH3][CH:1]([H:3])[CH2:2][H:4]",
        "CC[CH2:3][Cl:1].[NH2:2][H:4]>>CC[CH2:3][NH2:2].[Cl:1][H:4]",
    ]
    for example in examples:
        # PartialExpand has no `expand` method; the transform-based expander
        # is the entry point for this demo.
        print(PartialExpand.expand_aam_with_transform(example))
File without changes
@@ -0,0 +1,164 @@
1
+ from rdkit import Chem
2
+ from pathlib import Path
3
+ from typing import List, Union
4
+ from collections import Counter
5
+ from synkit.IO.data_io import load_gml_as_text
6
+
7
+ import torch
8
+ from mod import *
9
+
10
+
11
class CoreEngine:
    """
    Reaction prediction on top of the MØD toolkit.

    A single GML rule is applied to a set of molecules in one of two modes.
    Forward: the molecules are reactants. Backward: the rule is inverted and
    the molecules are products. The derivations found are returned as
    reaction SMILES.
    """

    @staticmethod
    def generate_reaction_smiles(
        temp_results: List[Union[List[str], str]],
        base_smiles: str,
        is_forward: bool = True,
    ) -> List[str]:
        """
        Build reaction SMILES by pairing ``base_smiles`` with each result.

        Parameters:
        - temp_results (List[Union[List[str], str]]): One entry per derivation.
          Each entry is either a list of molecule SMILES (joined with '.') or
          an already '.'-joined SMILES string.
        - base_smiles (str): The fixed side of the reaction. These are the
          reactants when ``is_forward`` is True, otherwise the products.
        - is_forward (bool, optional): If True emit ``base>>result``, otherwise
          ``result>>base``. Defaults to True.

        Returns:
        - List[str]: One reaction SMILES per entry of ``temp_results``.
        """
        results = []
        for comb in temp_results:
            # A plain string must not go through ".".join, which would insert
            # a '.' between every character.
            joined_smiles = comb if isinstance(comb, str) else ".".join(comb)
            if is_forward:
                results.append(f"{base_smiles}>>{joined_smiles}")
            else:
                results.append(f"{joined_smiles}>>{base_smiles}")
        return results

    @staticmethod
    def _deduplicate_graphs(graphs: list) -> list:
        """
        Map isomorphic MØD graphs onto one shared object.

        The result has the same length and order as ``graphs``. Each entry is
        either the input graph itself or an earlier, isomorphic graph already
        in the result. Duplicates are therefore kept, but share one object.
        """
        res = []
        for cand in graphs:
            for seen in res:
                if cand.isomorphism(seen) != 0:
                    res.append(seen)  # reuse the graph we already have
                    break
            else:
                # No isomorphic graph seen yet: keep the new one.
                res.append(cand)
        return res

    @staticmethod
    def _load_gml_content(rule_file_path: Union[str, Path]) -> Union[str, Path]:
        """
        Return GML text for ``rule_file_path``.

        If the argument names an existing file, its content is loaded with
        ``load_gml_as_text``; otherwise the argument itself is returned and
        treated as GML text.
        """
        rule_path = Path(rule_file_path)
        try:
            if rule_path.is_file():
                return load_gml_as_text(rule_file_path)
        except Exception:
            # Deliberate fallback: is_file() can raise (e.g. OSError when an
            # inline GML string is too long to be a file name), and so can
            # loading. In both cases the argument is used as the GML text.
            pass
        return rule_file_path

    @staticmethod
    def _append_unreacted(
        temp_results: List[List[str]],
        educts: List[List[str]],
        initial_smiles: List[str],
    ) -> None:
        """
        Extend each product list in place with the input molecules its
        derivation did not consume.

        Molecules are compared by ``Chem.CanonSmiles``.
        """
        if not temp_results:
            return
        initial_counts = Counter(Chem.CanonSmiles(s) for s in initial_smiles)
        for solution, educt in zip(temp_results, educts):
            remaining = initial_counts.copy()
            remaining.subtract(
                Chem.CanonSmiles(smile) for smile in educt if smile is not None
            )
            # elements() yields only positive counts, in first-seen order.
            solution.extend(remaining.elements())

    @staticmethod
    def perform_reaction(
        rule_file_path: Union[str, Path],
        initial_smiles: List[str],
        prediction_type: str = "forward",
        print_results: bool = False,
        verbosity: int = 0,
    ) -> List[str]:
        """
        Apply one GML rule to a set of molecules with MØD.

        The rule is first applied normally. If that yields no derivation, it
        is re-applied with ``onlyProper=False``. The input molecules not
        consumed by each derivation are then appended to its products.

        Parameters:
        - rule_file_path (Union[str, Path]): Path to a GML rule file, or the
          GML text itself.
        - initial_smiles (List[str]): Input molecules as SMILES.
        - prediction_type (str, optional): 'forward' applies the rule as given.
          'backward' applies the inverted rule. Defaults to 'forward'.
        - print_results (bool): If True, call ``dg.print()`` on the derivation
          graph after the first application. Defaults to False.
        - verbosity (int): Forwarded to MØD's ``apply``. Defaults to 0.

        Returns:
        - List[str]: One reaction SMILES per derivation, with ``initial_smiles``
          (joined by '.') on the reactant side for 'forward' and on the product
          side for 'backward'. An unrecognized ``prediction_type`` returns ""
          (after the rule has been applied as in 'forward').
        """
        invert_rule = prediction_type == "backward"

        # Parse, merge isomorphic inputs, and order by size (smallest first).
        initial_molecules = [smiles(smile, add=False) for smile in initial_smiles]
        initial_molecules = CoreEngine._deduplicate_graphs(initial_molecules)
        initial_molecules = sorted(
            initial_molecules, key=lambda molecule: molecule.numVertices
        )

        gml_content = CoreEngine._load_gml_content(rule_file_path)
        reaction_rule = ruleGMLString(gml_content, invert=invert_rule, add=False)

        dg = DG(graphDatabase=initial_molecules)
        config.dg.doRuleIsomorphismDuringBinding = False
        dg.build().apply(initial_molecules, reaction_rule, verbosity=verbosity)
        if print_results:
            dg.print()

        temp_results = [[v.graph.smiles for v in e.targets] for e in dg.edges]

        if not temp_results:
            # No proper derivation: retry allowing the rule to use only a
            # subset of the inputs, then add back what was not consumed.
            dg = DG(graphDatabase=initial_molecules)
            config.dg.doRuleIsomorphismDuringBinding = False
            dg.build().apply(
                initial_molecules, reaction_rule, verbosity=verbosity, onlyProper=False
            )
            temp_results, educts = [], []
            for edge in dg.edges:
                temp_results.append([vertex.graph.smiles for vertex in edge.targets])
                educts.append([vertex.graph.smiles for vertex in edge.sources])
            CoreEngine._append_unreacted(temp_results, educts, initial_smiles)

        if prediction_type not in ("forward", "backward"):
            return ""
        return CoreEngine.generate_reaction_smiles(
            temp_results,
            ".".join(initial_smiles),
            is_forward=prediction_type == "forward",
        )
@@ -0,0 +1,73 @@
1
+ import torch
2
+ from typing import List, Any
3
+ from synkit.IO.dg_to_gml import DGToGML
4
+ from synkit.ITS.normalize_aam import NormalizeAAM
5
+ from synkit.Chem.Reaction.standardize import Standardize
6
+ from synkit.Reactor.rule_apply import rule_apply
7
+
8
# One shared Standardize instance, reused by every aam_infer call.
std: Standardize = Standardize()
9
+
10
+
11
def aam_infer(rsmi: str, gml: Any) -> List[str]:
    """
    Infer atom-mapped reaction SMILES by applying a GML rule to the reactants
    of ``rsmi``.

    Steps:
    1. Split the reactant side of ``rsmi`` into molecule SMILES.
    2. Apply ``gml`` to them with ``rule_apply``. Take the first reaction
       SMILES of each entry returned by ``DGToGML.getReactionSmiles``.
    3. Normalize each candidate with ``NormalizeAAM.fit``. Candidates that
       raise are printed and skipped.
    4. Standardize each normalized candidate and ``rsmi`` itself with the
       module-level ``Standardize`` instance.
    5. Keep the normalized candidates whose standardized form equals that of
       ``rsmi``.

    Parameters:
    - rsmi (str): The reaction SMILES, "reactants>>products".
    - gml (Any): The rule passed to ``rule_apply``.

    Returns:
    - List[str]: Normalized (not standardized) candidates matching ``rsmi``,
      in the order produced by the derivation graph. May be empty.
    """
    reactants = rsmi.split(">>")[0].split(".")
    dg = rule_apply(reactants, gml)

    # Each value is indexable; only its first reaction SMILES is used.
    candidates = [value[0] for value in DGToGML.getReactionSmiles(dg).values()]

    # NormalizeAAM is stateless, so one instance serves every candidate.
    normalizer = NormalizeAAM()
    normalized_rsmi = []
    for candidate in candidates:
        try:
            normalized_rsmi.append(normalizer.fit(candidate))
        except Exception as e:
            print(e)

    # Keep this list aligned with normalized_rsmi: failures become None.
    curated_smiles = []
    for value in normalized_rsmi:
        try:
            curated_smiles.append(std.fit(value))
        except Exception as e:
            print(e)
            curated_smiles.append(None)

    org_smiles = std.fit(rsmi)

    # A failed standardization (None) must never count as a match, even if
    # standardizing the original reaction itself yields None.
    return [
        normalized
        for normalized, curated in zip(normalized_rsmi, curated_smiles)
        if curated is not None and curated == org_smiles
    ]