synkit 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synkit/Chem/Fingerprint/__init__.py +0 -0
- synkit/Chem/Fingerprint/fp_calculator.py +122 -0
- synkit/Chem/Fingerprint/smiles_featurizer.py +185 -0
- synkit/Chem/Fingerprint/transformation_fp.py +79 -0
- synkit/Chem/Molecule/__init__.py +0 -0
- synkit/Chem/Molecule/standardize.py +137 -0
- synkit/Chem/Reaction/__init__.py +0 -0
- synkit/Chem/Reaction/balance_check.py +162 -0
- synkit/Chem/Reaction/cleanning.py +59 -0
- synkit/Chem/Reaction/deionize.py +289 -0
- synkit/Chem/Reaction/neutralize.py +256 -0
- synkit/Chem/Reaction/reagent.py +102 -0
- synkit/Chem/Reaction/standardize.py +157 -0
- synkit/Chem/Reaction/tautomerize.py +168 -0
- synkit/Graph/Cluster/__init__.py +0 -0
- synkit/Graph/Cluster/morphism.py +83 -0
- synkit/Graph/Feature/__init__.py +0 -0
- synkit/Graph/Feature/graph_descriptors.py +325 -0
- synkit/Graph/Feature/graph_fps.py +97 -0
- synkit/Graph/Feature/graph_signature.py +236 -0
- synkit/Graph/Feature/hash_fps.py +130 -0
- synkit/Graph/Feature/morgan_fps.py +87 -0
- synkit/Graph/Feature/path_fps.py +82 -0
- synkit/Graph/__init.py +0 -0
- synkit/IO/__init__.py +0 -0
- synkit/IO/chem_converter.py +231 -0
- synkit/IO/data_io.py +277 -0
- synkit/IO/data_process.py +49 -0
- synkit/IO/debug.py +78 -0
- synkit/IO/dg_to_gml.py +124 -0
- synkit/IO/gml_to_nx.py +119 -0
- synkit/IO/graph_to_mol.py +110 -0
- synkit/IO/mol_to_graph.py +282 -0
- synkit/IO/nx_to_gml.py +200 -0
- synkit/IO/parse_rule.py +172 -0
- synkit/IO/smiles_to_id.py +119 -0
- synkit/ITS/_misc.py +280 -0
- synkit/ITS/aam_validator.py +254 -0
- synkit/ITS/its_builder.py +94 -0
- synkit/ITS/its_construction.py +213 -0
- synkit/ITS/normalize_aam.py +183 -0
- synkit/ITS/partial_expand.py +170 -0
- synkit/Reactor/__init__.py +0 -0
- synkit/Reactor/core_engine.py +164 -0
- synkit/Reactor/inference.py +73 -0
- synkit/Reactor/multi_step.py +227 -0
- synkit/Reactor/multi_step_aam.py +82 -0
- synkit/Reactor/reagent.py +95 -0
- synkit/Reactor/rule_apply.py +81 -0
- synkit/Vis/__init__.py +0 -0
- synkit/Vis/chemical_graph_visualizer.py +378 -0
- synkit/Vis/chemical_reaction_visualizer.py +133 -0
- synkit/Vis/chemical_space.py +83 -0
- synkit/Vis/embedding.py +92 -0
- synkit/Vis/graph_visualizer.py +286 -0
- synkit/Vis/pdf_writer.py +143 -0
- synkit/Vis/rsmi_to_fig.py +169 -0
- synkit/__init__.py +0 -0
- synkit/_misc.py +181 -0
- synkit-0.0.1.dist-info/METADATA +148 -0
- synkit-0.0.1.dist-info/RECORD +63 -0
- synkit-0.0.1.dist-info/WHEEL +4 -0
- synkit-0.0.1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,286 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module comprises several functions adapted from the work of Klaus Weinbauer.
|
|
3
|
+
The original code can be found at his GitHub repository: https://github.com/klausweinbauer/FGUtils.
|
|
4
|
+
Adaptations were made to enhance functionality and integrate with other system components.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import networkx as nx
|
|
8
|
+
from rdkit import Chem
|
|
9
|
+
from rdkit.Chem import rdDepictor
|
|
10
|
+
|
|
11
|
+
import matplotlib.pyplot as plt
|
|
12
|
+
from typing import Dict, Optional
|
|
13
|
+
|
|
14
|
+
from synkit.IO.graph_to_mol import GraphToMol
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class GraphVisualizer:
    """
    Draw molecular graphs and ITS (imaginary transition state) graphs with
    networkx on Matplotlib axes.

    Parameters:
    - node_attributes (Optional[Dict[str, str]]): Node attribute mapping forwarded
      to ``GraphToMol``. Defaults to
      ``{"element": "element", "charge": "charge", "atom_map": "atom_map"}``.
    - edge_attributes (Optional[Dict[str, str]]): Edge attribute mapping forwarded
      to ``GraphToMol``. Defaults to ``{"order": "order"}``.
    """

    def __init__(
        self,
        node_attributes: Optional[Dict[str, str]] = None,
        edge_attributes: Optional[Dict[str, str]] = None,
    ):
        # Build fresh dicts per instance instead of sharing mutable default
        # arguments between every GraphVisualizer ever created.
        if node_attributes is None:
            node_attributes = {
                "element": "element",
                "charge": "charge",
                "atom_map": "atom_map",
            }
        if edge_attributes is None:
            edge_attributes = {"order": "order"}
        self.node_attributes = node_attributes
        self.edge_attributes = edge_attributes

    def _get_its_as_mol(self, its: nx.Graph) -> Optional[Chem.Mol]:
        """
        Convert a graph representation of an intermediate transition state into an RDKit molecule.

        The input graph is copied. The copy's node ids are written as atom-map
        numbers and every edge is set to order 1, so the molecule can be used for
        2D layout.

        Parameters:
        - its (nx.Graph): The graph to convert. It is not modified.

        Returns:
        - Chem.Mol or None: The RDKit molecule if conversion is successful, None otherwise.
        """
        _its = its.copy()
        for n in _its.nodes():
            # The node id becomes the atom-map number so that RDKit coordinates
            # can be mapped back to graph nodes.
            _its.nodes[n]["atom_map"] = n
        for u, v in _its.edges():
            # ITS edge orders are (before, after) pairs (see _determine_edge_labels);
            # a plain single bond is enough to obtain a layout.
            _its[u][v]["order"] = 1
        return GraphToMol(self.node_attributes, self.edge_attributes).graph_to_mol(
            _its, False, False
        )

    @staticmethod
    def _edge_colors(
        its: nx.Graph, use_edge_color: bool, edge_color: str, standard_order_key: str
    ) -> list:
        """
        Return one color per edge, in ``its.edges()`` order.

        If ``use_edge_color`` is True, each edge is colored by the sign of its
        ``standard_order_key`` attribute (missing is treated as 0):
        "red" if > 0, "green" if < 0, otherwise "black".
        Otherwise every edge gets ``edge_color``.
        """
        if not use_edge_color:
            return [edge_color] * its.number_of_edges()
        colors = []
        for _, _, data in its.edges(data=True):
            order_change = data.get(standard_order_key, 0)
            if order_change > 0:
                colors.append("red")
            elif order_change < 0:
                colors.append("green")
            else:
                colors.append("black")
        return colors

    def plot_its(
        self,
        its: nx.Graph,
        ax: plt.Axes,
        use_mol_coords: bool = True,
        title: Optional[str] = None,
        node_color: str = "#FFFFFF",
        node_size: int = 500,
        edge_color: str = "#000000",
        edge_weight: float = 2.0,
        show_atom_map: bool = False,
        use_edge_color: bool = False,
        symbol_key: str = "element",
        bond_key: str = "order",
        aam_key: str = "atom_map",
        standard_order_key: str = "standard_order",
        font_size: int = 12,
        rule: bool = False,  # Remove edges with specific colors before plotting
    ):
        """
        Plot an intermediate transition state (ITS) graph on a given Matplotlib axes with various customizations.

        Parameters:
        - its (nx.Graph): The graph representing the intermediate transition state.
          It is never modified, not even when ``rule`` is True.
        - ax (plt.Axes): The matplotlib axes to draw the graph on.
        - use_mol_coords (bool): Use molecular coordinates for node positions if True, else use a spring layout.
        - title (Optional[str]): Title for the graph. If None, no title is set.
        - node_color (str): Color code for the graph nodes.
        - node_size (int): Size of the graph nodes.
        - edge_color (str): Default color code for the graph edges if not using conditional coloring.
        - edge_weight (float): Thickness of the graph edges.
        - show_atom_map (bool): If True, displays atom mapping numbers alongside symbols.
        - use_edge_color (bool): If True, colors edges based on their 'standard_order' attribute
          (red if > 0, green if < 0, black otherwise).
        - symbol_key (str): Key to access the symbol attribute in the node's data.
        - bond_key (str): Key to access the bond type attribute in the edge's data.
        - aam_key (str): Key to access the atom mapping number in the node's data.
        - standard_order_key (str): Key to determine the edge color conditionally.
        - font_size (int): Font size for labels and edge labels.
        - rule (bool): If True, edges whose color is red, green or black are removed
          from a copy of the graph before plotting. The color is per edge when
          ``use_edge_color`` is True, and ``edge_color`` otherwise.

        Returns:
        - None
        """
        bond_char = {None: "∅", 0: "∅", 1: "—", 2: "=", 3: "≡", 1.5: ":"}

        # The layout is computed on the full graph, before any rule filtering.
        positions = self._calculate_positions(its, use_mol_coords)

        ax.axis("equal")
        ax.axis("off")
        if title:
            ax.set_title(title)

        if rule:
            # Filter a copy so the caller's graph keeps all of its edges.
            its = its.copy()
            per_edge_colors = self._edge_colors(
                its, use_edge_color, edge_color, standard_order_key
            )
            edges_to_remove = [
                edge
                for edge, color in zip(its.edges(), per_edge_colors)
                if color in ("red", "green", "black")
            ]
            its.remove_edges_from(edges_to_remove)

        # Colors are computed after filtering so they line up with the drawn edges.
        edge_colors = (
            self._edge_colors(its, True, edge_color, standard_order_key)
            if use_edge_color
            else edge_color
        )

        nx.draw_networkx_edges(
            its, positions, edge_color=edge_colors, width=edge_weight, ax=ax
        )
        nx.draw_networkx_nodes(
            its, positions, node_color=node_color, node_size=node_size, ax=ax
        )

        # Node labels are the element symbol, optionally followed by the atom map.
        labels = {
            n: (
                f"{d.get(symbol_key, '')} ({d.get(aam_key, '')})"
                if show_atom_map
                else f"{d.get(symbol_key, '')}"
            )
            for n, d in its.nodes(data=True)
        }
        edge_labels = self._determine_edge_labels(its, bond_char, bond_key)

        nx.draw_networkx_labels(
            its, positions, labels=labels, font_size=font_size, ax=ax
        )
        nx.draw_networkx_edge_labels(
            its, positions, edge_labels=edge_labels, font_size=font_size, ax=ax
        )

    def _calculate_positions(self, its: nx.Graph, use_mol_coords: bool) -> dict:
        """
        Return a {node: [x, y]} layout for ``its``.

        If ``use_mol_coords`` is True, RDKit 2D coordinates are keyed by atom-map
        number, which ``_get_its_as_mol`` sets to the node id. Otherwise a
        networkx spring layout is used.

        Raises:
        - ValueError: If the graph cannot be converted to an RDKit molecule.
        """
        if not use_mol_coords:
            return nx.spring_layout(its)

        mol = self._get_its_as_mol(its)
        if mol is None:
            raise ValueError(
                "Could not convert the ITS graph to an RDKit molecule for layout."
            )
        rdDepictor.Compute2DCoords(mol)
        conformer = mol.GetConformer()
        positions = {}
        for i, atom in enumerate(mol.GetAtoms()):
            apos = conformer.GetAtomPosition(i)
            positions[atom.GetAtomMapNum()] = [apos.x, apos.y]
        return positions

    def _determine_edge_labels(
        self, its: nx.Graph, bond_char: dict, bond_key: str
    ) -> dict:
        """
        Build "(before,after)" labels for edges whose two bond orders differ.

        Each edge's ``bond_key`` value is indexed as a pair; a missing value is
        treated as (0, 0). Edges with unchanged bond symbols get no label.
        """
        edge_labels = {}
        for u, v, data in its.edges(data=True):
            bond_codes = data.get(bond_key, (0, 0))
            bc1 = bond_char.get(bond_codes[0], "∅")
            bc2 = bond_char.get(bond_codes[1], "∅")
            if bc1 != bc2:
                edge_labels[(u, v)] = f"({bc1},{bc2})"
        return edge_labels

    def plot_as_mol(
        self,
        g: nx.Graph,
        ax: plt.Axes,
        use_mol_coords: bool = True,
        node_color: str = "#FFFFFF",
        node_size: int = 500,
        edge_color: str = "#000000",
        edge_width: float = 2.0,
        label_color: str = "#000000",
        font_size: int = 12,
        show_atom_map: bool = False,
        bond_char: Optional[Dict[Optional[int], str]] = None,
        symbol_key: str = "element",
        bond_key: str = "order",
        aam_key: str = "atom_map",
    ) -> None:
        """
        Plots a molecular graph on a given Matplotlib axes using either molecular coordinates
        or a networkx layout.

        Parameters:
        - g (nx.Graph): The molecular graph to be plotted.
        - ax (plt.Axes): Matplotlib axes where the graph will be plotted.
        - use_mol_coords (bool, optional): Use molecular coordinates if True, else use networkx layout.
          RDKit coordinates are keyed by atom-map number, so node ids are assumed
          to equal atom-map numbers in this mode.
        - node_color (str, optional): Color code for the nodes.
        - node_size (int, optional): Size of the nodes.
        - edge_color (str, optional): Color code for the edges.
        - edge_width (float, optional): Thickness of the edges.
        - label_color (str, optional): Color for node and edge labels.
        - font_size (int, optional): Font size for node labels.
        - show_atom_map (bool, optional): If True, append the atom-map number to each node label.
        - bond_char (Dict[Optional[int], str], optional): Dictionary mapping bond types to characters.
        - symbol_key (str, optional): Node attribute key for element symbols.
        - bond_key (str, optional): Edge attribute key for bond types. Edges
          without it are labelled with the ``None`` entry of ``bond_char``.
        - aam_key (str, optional): Node attribute key for atom-map numbers.

        Returns:
        - None

        Raises:
        - ValueError: If ``use_mol_coords`` is True and the graph cannot be converted to a molecule.
        """
        if bond_char is None:
            bond_char = {None: "∅", 1: "—", 2: "=", 3: "≡", 1.5: ":"}

        if use_mol_coords:
            mol = GraphToMol(self.node_attributes, self.edge_attributes).graph_to_mol(
                g, False
            )
            if mol is None:
                raise ValueError(
                    "Could not convert the graph to an RDKit molecule for layout."
                )
            rdDepictor.Compute2DCoords(mol)
            conformer = mol.GetConformer()
            positions = {}
            for atom in mol.GetAtoms():
                apos = conformer.GetAtomPosition(atom.GetIdx())
                positions[atom.GetAtomMapNum()] = [apos.x, apos.y]
        else:
            positions = nx.spring_layout(g)

        ax.axis("equal")
        ax.axis("off")

        nx.draw_networkx_edges(
            g, positions, edge_color=edge_color, width=edge_width, ax=ax
        )
        nx.draw_networkx_nodes(
            g, positions, node_color=node_color, node_size=node_size, ax=ax
        )

        # Node labels: symbol, charge suffix ("+", "2+", "-", "2-"), optional atom map.
        labels = {}
        for n, d in g.nodes(data=True):
            charge = d.get("charge", 0)
            if charge == 0:
                charge = ""
            elif charge > 0:
                charge = f"{charge}+" if charge > 1 else "+"
            else:
                charge = f"{-charge}-" if charge < -1 else "-"
            label = f"{d.get(symbol_key, '')}{charge}"
            if show_atom_map:
                label += f" ({d.get(aam_key, '')})"
            labels[n] = label
        # A missing bond attribute maps to the None entry instead of raising KeyError.
        edge_labels = {
            (u, v): bond_char.get(d.get(bond_key), "∅")
            for u, v, d in g.edges(data=True)
        }

        nx.draw_networkx_labels(
            g,
            positions,
            labels=labels,
            font_color=label_color,
            font_size=font_size,
            ax=ax,
        )
        nx.draw_networkx_edge_labels(
            g, positions, edge_labels=edge_labels, font_color=label_color, ax=ax
        )
|
synkit/Vis/pdf_writer.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module comprises several functions adapted from the work of Klaus Weinbauer.
|
|
3
|
+
The original code can be found at his GitHub repository: https://github.com/klausweinbauer/FGUtils.
|
|
4
|
+
Adaptations were made to enhance functionality and integrate with other system components.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import matplotlib.pyplot as plt
|
|
8
|
+
from matplotlib.backends.backend_pdf import PdfPages
|
|
9
|
+
from typing import List, Callable, Union, Tuple, Optional
|
|
10
|
+
import tqdm
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class PdfWriter:
    """
    A utility class to create PDF reports with plots from a list of figures or dynamically generated plots.

    Can be used as a context manager; the PDF is closed when the ``with`` block exits.

    Parameters:
    - file (str): The file name of the output PDF.
    - plot_fn (Optional[Callable], optional): Function to create a plot for a single data entry or row.
        Called as `plot_fn(data_entry, axis, index=<position in data>, **kwargs)`. With
        `plot_per_row=True`, `axis` is the full row of axes (`ax[r, :]`). Default is None.
    - plot_per_row (bool, optional): If True, calls `plot_fn` for an entire row instead of individual subplots.
        Default is False.
    - max_pages (int, optional): Maximum number of pages to create. If the data does not
        fit, entries are sampled evenly. Default is 999.
    - rows (int, optional): Number of plot rows per page. Default is 7.
    - cols (int, optional): Number of plot columns per page. Default is 2.
    - pagesize (Tuple[float, float], optional): Size of a single page (in inches). Default is (21, 29.7).
    - width_ratios (Optional[List[float]], optional): Column width ratios. Default is None.
    - show_progress (bool, optional): If True, displays a progress bar using `tqdm`. Default is True.
    """

    def __init__(
        self,
        file: str,
        plot_fn: Optional[Callable] = None,
        plot_per_row: bool = False,
        max_pages: int = 999,
        rows: int = 7,
        cols: int = 2,
        pagesize: Tuple[float, float] = (21, 29.7),
        width_ratios: Optional[List[float]] = None,
        show_progress: bool = True,
    ):
        self.pdf_pages = PdfPages(file)
        self.plot_fn = plot_fn
        self.plot_per_row = plot_per_row
        self.max_pages = max_pages
        self.rows = rows
        self.cols = cols
        self.pagesize = pagesize
        self.width_ratios = width_ratios
        self.show_progress = show_progress

    def __enter__(self) -> "PdfWriter":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Always flush/close the PDF, even when the body raised.
        self.close()

    def plot(self, data: Union[List[plt.Figure], List], **kwargs):
        """
        Generate plots from data or save pre-generated figures to the PDF.

        Parameters:
        - data (Union[List[matplotlib.figure.Figure], List]): Input data or list of figures.
            If every item is a figure, they are saved directly. Otherwise, `plot_fn` is called
            for each selected data entry (or row of entries).
        - **kwargs: Additional keyword arguments passed to `plot_fn`.

        Returns:
        - None

        Raises:
        - ValueError: If `plot_fn` is missing or `data` is not a list when plotting dynamically.
        """
        # Case 1: Pre-generated figures
        if all(isinstance(item, plt.Figure) for item in data):
            for fig in tqdm.tqdm(
                data, disable=not self.show_progress, desc="Saving Figures"
            ):
                self.save_figure(fig)
            return

        # Case 2: Generate plots dynamically using `plot_fn`
        if self.plot_fn is None:
            raise ValueError(
                "plot_fn must be provided when input is not a list of figures."
            )

        if not isinstance(data, list):
            raise ValueError(
                "Data must be a list or a list of matplotlib.figure.Figure."
            )

        plots_per_page = self.rows if self.plot_per_row else self.rows * self.cols
        max_plots = self.max_pages * plots_per_page
        # step > 1 means entries are sampled evenly so that at most max_pages are written.
        step = max(len(data) / max_plots, 1)
        pages = int((len(data) / step + plots_per_page - 1) // plots_per_page)

        for p in tqdm.tqdm(
            range(pages), disable=not self.show_progress, desc="Generating Pages"
        ):
            fig, ax = plt.subplots(
                self.rows,
                self.cols,
                figsize=self.pagesize,
                squeeze=False,
                gridspec_kw=(
                    {"width_ratios": self.width_ratios} if self.width_ratios else None
                ),
            )
            done = False
            try:
                for r in range(self.rows):
                    if self.plot_per_row:
                        _idx = int((p * self.rows + r) * step)
                        if _idx >= len(data):
                            done = True
                            break
                        self.plot_fn(data[_idx], ax[r, :], index=_idx, **kwargs)
                    else:
                        for c in range(self.cols):
                            _idx = int((p * plots_per_page + r * self.cols + c) * step)
                            if _idx >= len(data):
                                done = True
                                break
                            self.plot_fn(data[_idx], ax[r, c], index=_idx, **kwargs)
                        if done:
                            # Indices only grow, so later rows would be empty too.
                            break
                # Lay out this page's figure explicitly; plot_fn may have changed
                # pyplot's "current" figure.
                fig.tight_layout()
                self.pdf_pages.savefig(fig, bbox_inches="tight", pad_inches=1)
            finally:
                # Release the page figure even if plot_fn raised.
                plt.close(fig)
            if done:
                break

    def save_figure(self, figure: plt.Figure):
        """
        Save a pre-generated matplotlib figure directly to the PDF.

        Parameters:
        - figure (matplotlib.figure.Figure): The figure to save.

        Returns:
        - None

        Raises:
        - ValueError: If `figure` is not a matplotlib Figure.
        """
        if not isinstance(figure, plt.Figure):
            raise ValueError("Input must be a matplotlib.figure.Figure.")
        self.pdf_pages.savefig(figure, bbox_inches="tight", pad_inches=1)

    def close(self):
        """
        Close the PDF file, ensuring all pages are written.

        Returns:
        - None
        """
        self.pdf_pages.close()
|
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
import networkx as nx
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
|
+
from typing import Union, Tuple
|
|
4
|
+
|
|
5
|
+
from synkit.Vis.graph_visualizer import GraphVisualizer
|
|
6
|
+
from synkit.IO.chem_converter import rsmi_to_graph
|
|
7
|
+
from synkit.ITS.its_construction import ITSConstruction
|
|
8
|
+
from synkit.IO.gml_to_nx import GMLToNX
|
|
9
|
+
|
|
10
|
+
# Shared visualizer instance (default attribute mappings) used by three_graph_vis.
vis_graph: GraphVisualizer = GraphVisualizer()
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def three_graph_vis(
    input: Union[str, Tuple[nx.Graph, nx.Graph, nx.Graph]],
    sanitize: bool = False,
    figsize: Tuple[int, int] = (18, 5),
    orientation: str = "horizontal",
    show_titles: bool = True,
    show_atom_map: bool = False,
    titles: Tuple[str, str, str] = (
        "Reactants",
        "Imaginary Transition State",
        "Products",
    ),
    add_gridbox: bool = False,
    rule: bool = False,
) -> plt.Figure:
    """
    Visualize three related graphs (reactants, imaginary transition state, and products)
    side by side or vertically in a single figure.

    Parameters:
    - input (Union[str, Tuple[nx.Graph, nx.Graph, nx.Graph]]): Either
        a reaction SMILES string or a tuple of three NetworkX graphs
        (reactants, products, ITS).
    - sanitize (bool, optional): If True, sanitizes the input molecule when parsing
        a reaction SMILES. Default is False.
    - figsize (Tuple[int, int], optional): The size of the Matplotlib figure.
        Default is (18, 5).
    - orientation (str, optional): Layout of the subplots; 'horizontal' or 'vertical'.
        Default is 'horizontal'.
    - show_titles (bool, optional): If True, adds titles to each subplot.
        Default is True.
    - show_atom_map (bool, optional): If True, shows atom-map numbers in node labels.
        Default is False.
    - titles (Tuple[str, str, str], optional): Custom titles for each subplot.
        Default is ('Reactants', 'Imaginary Transition State', 'Products').
    - add_gridbox (bool, optional): If True, draws a rectangular frame and a light
        dashed grid around each subplot. Default is False.
    - rule (bool, optional): Forwarded to ``GraphVisualizer.plot_its`` for the ITS panel
        (removes edges colored red/green/black before plotting). Default is False.

    Returns:
    - plt.Figure: The Matplotlib figure containing the three subplots.

    Raises:
    - RuntimeError: If parsing or plotting fails; the original exception is chained.
    """
    fig = None
    try:
        # Parse input to determine graphs
        if isinstance(input, str):
            r, p = rsmi_to_graph(input, light_weight=True, sanitize=sanitize)
            its = ITSConstruction().ITSGraph(r, p)
        elif isinstance(input, tuple) and len(input) == 3:
            r, p, its = input
        else:
            raise ValueError(
                "Input must be a reaction SMILES string or a tuple of three graphs (r, p, its)."
            )

        # Set up subplots
        if orientation == "horizontal":
            fig, ax = plt.subplots(1, 3, figsize=figsize)
        elif orientation == "vertical":
            fig, ax = plt.subplots(3, 1, figsize=figsize)
        else:
            raise ValueError("Orientation must be 'horizontal' or 'vertical'.")

        # Plot the graphs
        vis_graph.plot_as_mol(
            r,
            ax[0],
            show_atom_map=show_atom_map,
            font_size=12,
            node_size=800,
            edge_width=2.0,
        )
        if show_titles:
            ax[0].set_title(titles[0])

        vis_graph.plot_its(
            its, ax[1], use_edge_color=True, show_atom_map=show_atom_map, rule=rule
        )
        if show_titles:
            ax[1].set_title(titles[1])

        vis_graph.plot_as_mol(
            p,
            ax[2],
            show_atom_map=show_atom_map,
            font_size=12,
            node_size=800,
            edge_width=2.0,
        )
        if show_titles:
            ax[2].set_title(titles[2])

        if add_gridbox:
            for a in ax:
                # The plot helpers call ax.axis("off"). While the axis is off,
                # Matplotlib draws neither spines nor grid lines, so turn it back
                # on and hide only the tick marks and tick labels.
                a.set_axis_on()
                a.tick_params(
                    axis="both",
                    which="both",
                    bottom=False,
                    left=False,
                    labelbottom=False,
                    labelleft=False,
                )
                # Keep grid lines underneath the drawn molecule.
                a.set_axisbelow(True)

                # Rectangular frame (gridbox) with thicker borders
                for spine in a.spines.values():
                    spine.set_visible(True)
                    spine.set_linewidth(2)
                    spine.set_color("black")

                # Light dashed grid lines
                a.grid(
                    True,
                    which="both",
                    axis="both",
                    linestyle="--",
                    color="gray",
                    alpha=0.5,
                )

        return fig

    except Exception as e:
        # Don't leave a half-drawn figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
        raise RuntimeError(f"An error occurred during visualization: {str(e)}") from e
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def rule_visualize(gml, rule=True, titles=None):
    """
    Visualizes a reaction rule from GML data as three panels, with optional edge
    filtering (rule) and custom titles.

    The GML is converted with ``GMLToNX(gml).transform()`` into (r, p, its) graphs.
    These are plotted by ``three_graph_vis`` with a gridbox around each panel.

    Parameters:
    - gml (str): GML format string representing the reaction data.
    - rule (bool): Forwarded to ``three_graph_vis``. If True, edges colored
        red/green/black are removed from the middle (ITS) panel before plotting.
    - titles (Sequence[str], optional): Exactly three subplot titles.
        Defaults to ['L', 'K', 'R'].

    Returns:
    - plt.Figure: Matplotlib figure containing the visualized rule.

    Raises:
    - RuntimeError: If conversion or plotting fails, or ``titles`` does not have
        three entries; the original exception is chained as ``__cause__``.
    """
    try:
        r, p, its = GMLToNX(gml).transform()

        if titles is None:
            titles = ["L", "K", "R"]

        # One title per panel: L (reactant side), K (ITS panel), R (product side).
        if len(titles) != 3:
            raise ValueError(
                "The titles list must contain exactly three titles for the three graphs."
            )

        return three_graph_vis(
            (r, p, its),
            add_gridbox=True,
            titles=titles,
            rule=rule,
        )

    except Exception as e:
        raise RuntimeError(
            f"An error occurred during the visualization process: {str(e)}"
        ) from e
|
synkit/__init__.py
ADDED
|
File without changes
|