synkit 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synkit/Chem/Fingerprint/__init__.py +0 -0
- synkit/Chem/Fingerprint/fp_calculator.py +122 -0
- synkit/Chem/Fingerprint/smiles_featurizer.py +185 -0
- synkit/Chem/Fingerprint/transformation_fp.py +79 -0
- synkit/Chem/Molecule/__init__.py +0 -0
- synkit/Chem/Molecule/standardize.py +137 -0
- synkit/Chem/Reaction/__init__.py +0 -0
- synkit/Chem/Reaction/balance_check.py +162 -0
- synkit/Chem/Reaction/cleanning.py +59 -0
- synkit/Chem/Reaction/deionize.py +289 -0
- synkit/Chem/Reaction/neutralize.py +256 -0
- synkit/Chem/Reaction/reagent.py +102 -0
- synkit/Chem/Reaction/standardize.py +157 -0
- synkit/Chem/Reaction/tautomerize.py +168 -0
- synkit/Graph/Cluster/__init__.py +0 -0
- synkit/Graph/Cluster/morphism.py +83 -0
- synkit/Graph/Feature/__init__.py +0 -0
- synkit/Graph/Feature/graph_descriptors.py +325 -0
- synkit/Graph/Feature/graph_fps.py +97 -0
- synkit/Graph/Feature/graph_signature.py +236 -0
- synkit/Graph/Feature/hash_fps.py +130 -0
- synkit/Graph/Feature/morgan_fps.py +87 -0
- synkit/Graph/Feature/path_fps.py +82 -0
- synkit/Graph/__init.py +0 -0
- synkit/IO/__init__.py +0 -0
- synkit/IO/chem_converter.py +231 -0
- synkit/IO/data_io.py +277 -0
- synkit/IO/data_process.py +49 -0
- synkit/IO/debug.py +78 -0
- synkit/IO/dg_to_gml.py +124 -0
- synkit/IO/gml_to_nx.py +119 -0
- synkit/IO/graph_to_mol.py +110 -0
- synkit/IO/mol_to_graph.py +282 -0
- synkit/IO/nx_to_gml.py +200 -0
- synkit/IO/parse_rule.py +172 -0
- synkit/IO/smiles_to_id.py +119 -0
- synkit/ITS/_misc.py +280 -0
- synkit/ITS/aam_validator.py +254 -0
- synkit/ITS/its_builder.py +94 -0
- synkit/ITS/its_construction.py +213 -0
- synkit/ITS/normalize_aam.py +183 -0
- synkit/ITS/partial_expand.py +170 -0
- synkit/Reactor/__init__.py +0 -0
- synkit/Reactor/core_engine.py +164 -0
- synkit/Reactor/inference.py +73 -0
- synkit/Reactor/multi_step.py +227 -0
- synkit/Reactor/multi_step_aam.py +82 -0
- synkit/Reactor/reagent.py +95 -0
- synkit/Reactor/rule_apply.py +81 -0
- synkit/Vis/__init__.py +0 -0
- synkit/Vis/chemical_graph_visualizer.py +378 -0
- synkit/Vis/chemical_reaction_visualizer.py +133 -0
- synkit/Vis/chemical_space.py +83 -0
- synkit/Vis/embedding.py +92 -0
- synkit/Vis/graph_visualizer.py +286 -0
- synkit/Vis/pdf_writer.py +143 -0
- synkit/Vis/rsmi_to_fig.py +169 -0
- synkit/__init__.py +0 -0
- synkit/_misc.py +181 -0
- synkit-0.0.1.dist-info/METADATA +148 -0
- synkit-0.0.1.dist-info/RECORD +63 -0
- synkit-0.0.1.dist-info/WHEEL +4 -0
- synkit-0.0.1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,378 @@
|
|
|
1
|
+
import random
|
|
2
|
+
import networkx as nx
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
import seaborn as sns
|
|
5
|
+
from typing import Optional, Dict
|
|
6
|
+
import logging
|
|
7
|
+
|
|
8
|
+
# Element symbol -> hex fill colour. Used as the default ``element_colors``
# palette of ChemicalGraphVisualizer; elements missing from the mapping are
# drawn with the visualizer's fallback colour instead.
color_scheme = {
    "H": "#FFFFFF",  # White
    "C": "#909090",  # Gray
    "N": "#3050F8",  # Blue
    "O": "#FF0D0D",  # Red
    "F": "#90E050",  # Light Green
    "Cl": "#1FF01F",  # Bright Green
    "Br": "#A62929",  # Dark Red/Brown
    "I": "#940094",  # Purple
    "P": "#FF8000",  # Orange
    "S": "#FFFF30",  # Yellow
    "Na": "#E0E0E0",  # Light Gray
    "K": "#8F40D4",  # Purple
    "Ca": "#3DFF00",  # Bright Green
    "Mg": "#8AFF00",  # Yellow-Green
    "Fe": "#B7410E",  # Rust Red
    "Zn": "#7D80B0",  # Slate Blue
    "Cu": "#C88033",  # Copper Red/Orange
    "Ag": "#C0C0C0",  # Light Gray (Silver)
    "Au": "#FFD123",  # Gold Yellow
    "Hg": "#B8B8D0",  # Silver Gray
    "Pb": "#575961",  # Dark Gray
    "Al": "#BFA6A6",  # Light Grayish Pink
    "Si": "#F0C8A0",  # Light Brown
    "B": "#FFA1A1",  # Pink
    "As": "#BD80E3",  # Light Purple
    "Sb": "#9E63B5",  # Purple
    "Se": "#FFA100",  # Orange
    "Te": "#D47A00",  # Dark Orange/Brown
    "Cd": "#FFD98F",  # Light Orange
    "Ti": "#BFC2C7",  # Light Gray
    "V": "#A6A6AB",  # Gray
    "Cr": "#8A99C7",  # Steel Blue-Gray
    "Mn": "#9C7AC7",  # Purple
    "Co": "#FF7A00",  # Orange
    "Ni": "#4DFF4D",  # Light Green
}
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class ChemicalGraphVisualizer:
    """
    Draw molecular graphs and ITS (reaction) graphs with NetworkX/matplotlib.

    Nodes are expected to carry an ``"element"`` attribute, which selects the
    node colour and (optionally) the node label. Layouts are computed with
    ``nx.spring_layout`` using the configured seed.
    """

    def __init__(
        self,
        seed: Optional[int] = None,
        element_colors: Optional[Dict[str, str]] = color_scheme,
    ):
        """
        Initialize the visualizer with optional seed and color scheme.

        Parameters:
        seed (int, optional): Seed for the layout / random number generator,
            for reproducibility.
        element_colors (dict, optional): Dictionary mapping elements to their
            color codes. Defaults to the module-level ``color_scheme``; passing
            ``None`` selects a small built-in palette (H, C, N, O, F, Cl).
        """
        if element_colors is None:
            # Minimal fallback palette used only when None is passed explicitly.
            self.element_colors = {
                "H": "#FFFFFF",  # White
                "C": "#909090",  # Gray
                "N": "#3050F8",  # Blue
                "O": "#FF0D0D",  # Red
                "F": "#90E050",  # Light Green
                "Cl": "#1FF01F",  # Bright Green
            }
        else:
            # Copy so that edits to one instance's palette never leak into the
            # shared module-level default (or into the caller's dictionary).
            self.element_colors = dict(element_colors)
        self.seed = seed

    def _draw(
        self,
        G,
        node_size,
        show_node_labels,
        node_label_font_size,
        ax,
        **draw_kwargs,
    ):
        """Lay out and draw G on ax; return the (pos, ax) pair that was used."""
        # Kept from the original implementation: the global RNG is re-seeded
        # in addition to handing the seed to spring_layout.
        if self.seed is not None:
            random.seed(self.seed)

        pos = nx.spring_layout(G, seed=self.seed)

        if ax is None:
            ax = plt.gca()  # Fall back to the current axes

        draw_kwargs.update(
            ax=ax,
            with_labels=bool(show_node_labels),
            # Unknown elements are drawn in black.
            node_color=[
                self.element_colors.get(G.nodes[node]["element"], "#000000")
                for node in G.nodes()
            ],
            node_size=node_size,
        )
        if show_node_labels:
            draw_kwargs.update(
                labels={node: G.nodes[node]["element"] for node in G.nodes()},
                font_size=node_label_font_size,
            )

        nx.draw(G, pos, **draw_kwargs)
        return pos, ax

    @staticmethod
    def _edge_color(order):
        """Map a ``standard_order`` value to the colour used for that edge."""
        if order == 0:
            return "black"  # Normal bond
        if order < 0:
            return "blue"  # Increasing bond
        return "red"  # Breaking bond

    @staticmethod
    def _finalize_figure(fig, save_path, display_inline, log):
        """Tighten the layout, optionally save, then show or close the figure."""
        fig.tight_layout()

        if save_path is not None:
            fig.savefig(save_path, dpi=600)
            if log:
                logging.info(f"Figure saved to {save_path}")

        if display_inline:
            plt.show()
        else:
            # Closed figures are not rendered by the notebook backend but the
            # returned object can still be inspected or saved by the caller.
            plt.close(fig)

        return fig

    def graph_vis(
        self,
        G: nx.Graph,
        node_size: int = 100,
        visualize_edge_weight: bool = False,
        edge_font_size: int = 10,
        show_node_labels: bool = False,
        node_label_font_size: int = 12,
        ax: Optional[plt.Axes] = None,
    ) -> None:
        """
        Visualize a NetworkX graph with standard representation.

        Parameters:
        G (nx.Graph): The graph to visualize. Every node needs an "element"
            attribute; edges need an "order" attribute when
            ``visualize_edge_weight`` is True.
        node_size (int): The size of the nodes.
        visualize_edge_weight (bool): Whether to display edge "order" values
            as edge labels.
        edge_font_size (int): Font size for edge labels.
        show_node_labels (bool): Whether to show labels on the nodes.
        node_label_font_size (int): Font size for node labels.
        ax (plt.Axes, optional): Axes to draw on; the current axes are used
            when omitted.
        """
        pos, ax = self._draw(
            G, node_size, show_node_labels, node_label_font_size, ax
        )

        if visualize_edge_weight:
            edge_labels = {(u, v): G.edges[u, v]["order"] for u, v in G.edges()}
            nx.draw_networkx_edge_labels(
                G, pos, ax=ax, edge_labels=edge_labels, font_size=edge_font_size
            )

    def its_vis(
        self,
        G: nx.Graph,
        node_size: int = 100,
        show_node_labels: bool = False,
        node_label_font_size: int = 12,
        ax: Optional[plt.Axes] = None,
    ) -> None:
        """
        Visualize a NetworkX graph with edge colors indicating bond changes.

        Each edge is coloured from its "standard_order" attribute (treated as
        0 when absent): 0 -> black, negative -> blue, positive -> red.

        Parameters:
        G (nx.Graph): The graph to visualize. Every node needs an "element"
            attribute.
        node_size (int): The size of the nodes.
        show_node_labels (bool): Whether to show labels on the nodes.
        node_label_font_size (int): Font size for node labels.
        ax (plt.Axes, optional): Axes to draw on; the current axes are used
            when omitted.
        """
        edge_colors = [
            self._edge_color(data.get("standard_order", 0))
            for _, _, data in G.edges(data=True)
        ]
        self._draw(
            G,
            node_size,
            show_node_labels,
            node_label_font_size,
            ax,
            edge_color=edge_colors,
        )

    def vis_three_graph(
        self,
        graph_tuple,
        figsize=(15, 5),
        left_graph_title="Reactants",
        k_graph_title="ITS Graph",
        right_graph_title="Products",
        show_node_labels=True,
        title_fontsize=24,
        title_weight="bold",
        save_path=None,
        display_inline=False,
        log=False,
    ):
        """
        Visualize reactants, ITS graph, and products in one figure.

        Parameters:
        graph_tuple (tuple): Tuple of NetworkX graphs, unpacked in the order
            (reactants, products, ITS graph). The ITS graph is drawn in the
            middle panel.
        figsize (tuple): Figure size in inches (width, height).
        left_graph_title (str): Title for the left subplot.
        k_graph_title (str): Title for the middle subplot.
        right_graph_title (str): Title for the right subplot.
        show_node_labels (bool): If True, show node labels on the graphs.
        title_fontsize (int): Font size for subplot titles.
        title_weight (str): Font weight for subplot titles.
        save_path (str, optional): Path to save the figure to file.
        display_inline (bool): If True, display the figure; otherwise the
            figure is closed before it is returned.
        log (bool): If True, enable logging of function progress.

        Returns:
        The matplotlib figure that was created.

        Raises:
        RuntimeError: If any step fails; the original exception is chained.
        """
        if log:
            logging.basicConfig(level=logging.INFO)

        fig = None
        try:
            reactants_graph, products_graph, its_graph = graph_tuple

            fig, axs = plt.subplots(1, 3, figsize=figsize)

            self.graph_vis(
                reactants_graph, ax=axs[0], show_node_labels=show_node_labels
            )
            self.its_vis(its_graph, ax=axs[1], show_node_labels=show_node_labels)
            self.graph_vis(
                products_graph, ax=axs[2], show_node_labels=show_node_labels
            )

            panel_titles = (left_graph_title, k_graph_title, right_graph_title)
            for ax, title in zip(axs, panel_titles):
                ax.set_title(title, fontsize=title_fontsize, weight=title_weight)

            return self._finalize_figure(fig, save_path, display_inline, log)
        except Exception as e:
            # Do not leave a half-drawn figure registered with pyplot.
            if fig is not None:
                plt.close(fig)
            if log:
                logging.error("Failed to visualize graphs: ", exc_info=True)
            raise RuntimeError(f"Error in graph visualization: {e}") from e

    def visualize_all(
        self,
        graph_tuple_row1,
        graph_tuple_row2,
        figsize=(15, 10),
        titles_row1=("A. Reactant Graph", "B. ITS Graph", "C Products"),
        titles_row2=("D. L Graph", "E. K Graph", "D. R Graph"),
        show_node_labels=True,
        show_grid=True,
        grid_style="--",
        title_fontsize=24,
        title_weight="bold",
        save_path=None,
        display_inline=False,
        log=False,
    ):
        """
        Visualize two rows of graphs, each with three graphs, optionally
        displaying a grid.

        Parameters:
        graph_tuple_row1 (tuple): Tuple of NetworkX graphs for the first row,
            unpacked in the order (reactants, products, ITS graph). The ITS
            graph is drawn in the middle panel.
        graph_tuple_row2 (tuple): Tuple of NetworkX graphs for the second row,
            unpacked in the order (L, R, K). The K graph is drawn in the
            middle panel.
        figsize (tuple): Figure size in inches (width, height).
        titles_row1 (tuple): Titles for the first row subplots, left to right.
        titles_row2 (tuple): Titles for the second row subplots, left to right.
        show_node_labels (bool): If True, show node labels on the graphs.
        show_grid (bool): If True, enable grid lines on the plots.
        grid_style (str): Style of the grid lines.
        title_fontsize (int): Font size for subplot titles.
        title_weight (str): Font weight for subplot titles.
        save_path (str, optional): Path to save the figure to file.
        display_inline (bool): If True, display the figure; otherwise the
            figure is closed before it is returned.
        log (bool): If True, enable logging of function progress.

        Returns:
        The matplotlib figure that was created.

        Raises:
        RuntimeError: If any step fails; the original exception is chained.
        """
        if log:
            logging.basicConfig(level=logging.INFO)

        fig = None
        try:
            # NOTE: this changes the global seaborn/matplotlib theme.
            sns.set_theme(style="darkgrid")
            reactants_graph, products_graph, its_graph = graph_tuple_row1
            l_graph, r_graph, k_graph = graph_tuple_row2

            fig, axs = plt.subplots(2, 3, figsize=figsize)

            # Each row is drawn as (plain graph, ITS-style graph, plain graph).
            rows = (
                (reactants_graph, its_graph, products_graph),
                (l_graph, k_graph, r_graph),
            )
            for row_axes, (left, middle, right) in zip(axs, rows):
                self.graph_vis(
                    left, ax=row_axes[0], show_node_labels=show_node_labels
                )
                self.its_vis(
                    middle, ax=row_axes[1], show_node_labels=show_node_labels
                )
                self.graph_vis(
                    right, ax=row_axes[2], show_node_labels=show_node_labels
                )

            for row_axes, row_titles in zip(axs, (titles_row1, titles_row2)):
                for ax, title in zip(row_axes, row_titles):
                    ax.set_title(
                        title, fontsize=title_fontsize, weight=title_weight
                    )
                    if show_grid:
                        ax.grid(True, linestyle=grid_style, which="both")

            return self._finalize_figure(fig, save_path, display_inline, log)
        except Exception as e:
            # Do not leave a half-drawn figure registered with pyplot.
            if fig is not None:
                plt.close(fig)
            if log:
                logging.error("Failed to visualize graphs: ", exc_info=True)
            raise RuntimeError(f"Error in graph visualization: {e}") from e
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
from typing import List, Dict
|
|
2
|
+
from IPython.display import display, HTML, SVG
|
|
3
|
+
from rdkit.Chem.Draw import rdMolDraw2D
|
|
4
|
+
from rdkit.Chem import rdChemReactions
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ChemicalReactionVisualizer:
    """Static helpers that render reaction SMILES as SVG and lay them out in HTML."""

    @staticmethod
    def create_html_table_with_svgs(
        svg_list: List[str],
        titles: List[str],
        num_cols: int = 2,
        orientation: str = "vertical",
        title_size: int = 16,
    ) -> HTML:
        """
        Build an HTML table whose cells each hold a bold title and an SVG image.

        Parameters:
        - svg_list (List[str]): SVG markup strings, one per cell.
        - titles (List[str]): Title for each SVG, indexed in parallel with
        svg_list.
        - num_cols (int): Cells per row when orientation is 'vertical';
        otherwise the number of rows, with item ``i`` placed in row
        ``i % num_cols``.
        - orientation (str): 'vertical' selects the row-by-row layout; any
        other value selects the alternative layout.
        - title_size (int): Title font size in pixels.

        Returns:
        - HTML: HTML object to be displayed within an IPython notebook
        environment.
        """
        title_style = f"font-size:{title_size}px;"  # CSS to control title size
        total = len(svg_list)

        def render_cell(idx):
            return f"<td style='border:1px solid black; padding:10px'><b style='{title_style}'>{titles[idx]}</b><br>{svg_list[idx]}</td>"

        if orientation == "vertical":
            # Consecutive chunks of num_cols items, one chunk per table row.
            row_indices = [
                range(start, min(start + num_cols, total))
                for start in range(0, total, num_cols)
            ]
        else:
            # Row j collects items j, j + num_cols, j + 2 * num_cols, ...
            row_indices = [
                range(offset, total, num_cols) for offset in range(num_cols)
            ]

        parts = ["<table>"]
        for indices in row_indices:
            parts.append("<tr>")
            parts.extend(render_cell(idx) for idx in indices)
            parts.append("</tr>")
        parts.append("</table>")
        return HTML("".join(parts))

    @staticmethod
    def visualize_reaction(
        reaction_smiles: str,
        img_size: tuple = (600, 200),
        highlight_by_reactant: bool = True,
        mol_scale: float = 0.9,
        bond_line_width: float = 2.0,
        atom_label_font_size: int = 12,
        padding: float = 0.01,
        show_atom_map: bool = False,
    ) -> SVG:
        """
        Render a reaction SMILES string to an SVG image using RDKit.

        Parameters:
        - reaction_smiles (str): SMILES string representing the chemical reaction.
        - img_size (tuple): Dimensions of the output image (width, height).
        - highlight_by_reactant (bool): Forwarded to RDKit's DrawReaction as
        ``highlightByReactant``.
        - mol_scale (float): Value assigned to the draw options' ``scale``.
        - bond_line_width (float): Value assigned to ``bondLineWidth``.
        - atom_label_font_size (int): Value assigned to ``atomLabelFontSize``.
        - padding (float): Value assigned to ``padding``.
        - show_atom_map (bool): If True, atoms carrying a "molAtomMapNumber"
        property get that number as their "atomLabel".

        Returns:
        - SVG: An SVG object containing the rendered reaction.
        """
        reaction = rdChemReactions.ReactionFromSmarts(reaction_smiles, useSmiles=True)
        rdChemReactions.PreprocessReaction(reaction)

        if show_atom_map:
            participants = (*reaction.GetReactants(), *reaction.GetProducts())
            for mol in participants:
                for atom in mol.GetAtoms():
                    if not atom.HasProp("molAtomMapNumber"):
                        continue
                    atom.SetProp("atomLabel", atom.GetProp("molAtomMapNumber"))

        drawer = rdMolDraw2D.MolDraw2DSVG(img_size[0], img_size[1])
        opts = drawer.drawOptions()
        option_values = (
            ("scale", mol_scale),
            ("bondLineWidth", bond_line_width),
            ("atomLabelFontSize", atom_label_font_size),
            ("padding", padding),
        )
        for option_name, option_value in option_values:
            setattr(opts, option_name, option_value)

        drawer.DrawReaction(reaction, highlightByReactant=highlight_by_reactant)
        drawer.FinishDrawing()
        return SVG(drawer.GetDrawingText())

    @staticmethod
    def visualize_and_compare_reactions(
        input_dict: Dict[str, str],
        id_col: str = "R-id",
        img_size: tuple = (1000, 600),
        num_cols: int = 2,
        orientation: str = "vertical",
        show_atom_map: bool = False,
    ) -> None:
        """
        Render several reactions and display them together in one HTML table.

        Parameters:
        - input_dict (Dict[str, str]): Dictionary with reaction
        identifiers as keys and SMILES strings as values.
        - id_col (str): A dictionary key to exclude
        from visualization (typically metadata).
        - img_size (tuple): Size of each rendered image (width, height).
        - num_cols (int): Forwarded to create_html_table_with_svgs.
        - orientation (str): Forwarded to create_html_table_with_svgs
        ('vertical' or 'horizontal').
        - show_atom_map (bool): Whether to label atoms with their atom-map
        numbers.
        """
        # Every entry except the id column is treated as a reaction to draw.
        reactions = {
            key: reaction_str
            for key, reaction_str in input_dict.items()
            if key != id_col
        }
        svg_list = [
            ChemicalReactionVisualizer.visualize_reaction(
                reaction_str,
                img_size=img_size,
                highlight_by_reactant=True,
                show_atom_map=show_atom_map,
            ).data
            for reaction_str in reactions.values()
        ]
        titles = list(reactions)

        table = ChemicalReactionVisualizer.create_html_table_with_svgs(
            svg_list, titles, num_cols, orientation
        )
        display(table)
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
|
+
import matplotlib.patches as mpatches
|
|
4
|
+
|
|
5
|
+
# NOTE: these calls run at import time and change matplotlib's global rc
# settings, so they affect every figure drawn after this module is imported.
plt.rc("text", usetex=True)  # Enable LaTeX rendering (presumably needs a working LaTeX install - TODO confirm)
plt.rc("font", family="serif")  # Optional: use serif font
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def scatter_plot(
    data_train,
    data_test,
    size_train=10,
    size_test=10,
    title=None,
    ax=None,
    xlabel="Coordinate 1",
    ylabel="Coordinate 2",
):
    """
    Scatter-plot train and test points on shared axes.

    The second and third columns of the combined data are used as the x and y
    coordinates. The input frames are not modified.

    Parameters:
    data_train (pd.DataFrame): Training data with at least three columns.
    data_test (pd.DataFrame): Test data with at least three columns.
    size_train (int): Marker size for training points.
    size_test (int): Marker size for test points.
    title (str, optional): Axes title; omitted when falsy.
    ax (matplotlib Axes, optional): Axes to draw on; a new 12x8 inch figure
        is created when omitted.
    xlabel (str): Label for the x axis.
    ylabel (str): Label for the y axis.

    Returns:
    tuple: (ax, handles, labels), where handles/labels come from
    ``ax.get_legend_handles_labels()`` so the caller can build a legend.

    Raises:
    ValueError: If either frame is empty or has fewer than three columns.
    """
    if data_train.empty or data_test.empty:
        raise ValueError("Input data frames cannot be empty.")

    if data_train.columns.size < 3 or data_test.columns.size < 3:
        raise ValueError("Data frames must have at least three columns.")

    # Tag each row with its origin on *copies*: assign() returns a new frame,
    # so the caller's data never grows a surprise "Type" column.
    data_combined = pd.concat(
        [data_train.assign(Type="Train"), data_test.assign(Type="Test")]
    )

    if ax is None:
        _, ax = plt.subplots(figsize=(12, 8))

    # Deepskyblue and magenta are used for clear train/test distinction.
    palette = {"Train": "deepskyblue", "Test": "magenta"}
    marker_sizes = {"Train": size_train, "Test": size_test}

    for dtype, color in palette.items():
        subset = data_combined[data_combined["Type"] == dtype]
        ax.scatter(
            subset[subset.columns[1]],
            subset[subset.columns[2]],
            color=color,
            label=dtype,
            s=marker_sizes[dtype],
            alpha=0.1,
            edgecolor="none",
        )

    if title:
        ax.set_title(rf"{title}", fontsize=24, fontweight="bold")

    ax.set_xlabel(xlabel, fontsize=18)
    ax.set_ylabel(ylabel, fontsize=18)

    # Dashed grid drawn beneath the markers.
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax.set_axisbelow(True)

    handles, labels = ax.get_legend_handles_labels()
    return ax, handles, labels
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
# Define a function that modifies the legend handles to full opacity for better visibility in the legend
|
|
77
|
+
def adjust_legend_handles(handles, colors):
    """
    Build legend patches that reuse the labels of existing handles.

    Each handle is paired with the colour at the same position; pairing stops
    at the shorter of the two sequences. The patches are created without an
    alpha setting, so they stay readable in a legend even when the plotted
    markers are mostly transparent.

    Parameters:
    handles: Legend handles exposing ``get_label()``.
    colors: Colours for the new patches, in the same order as ``handles``.

    Returns:
    list: One ``mpatches.Patch`` per (handle, colour) pair.
    """
    return [
        mpatches.Patch(color=shade, label=source.get_label())
        for source, shade in zip(handles, colors)
    ]
|
synkit/Vis/embedding.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
from typing import Any, Dict, Optional
|
|
2
|
+
import numpy as np
|
|
3
|
+
from sklearn.manifold import TSNE
|
|
4
|
+
from joblib import Memory
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Embedding:
    """t-SNE embedding helper with adjustable parameters and joblib-backed caching."""

    def __init__(
        self,
        cache_dir: str = "./cachedir",
        verbose: int = 0,
        custom_tsne_params: Optional[Dict] = None,
    ) -> None:
        """
        Set up the joblib cache and the default t-SNE parameters.

        Parameters:
        cache_dir (str): Directory where cached results are stored.
        verbose (int): Verbosity level for the memory object.
        custom_tsne_params (Dict, optional): Overrides merged into the built-in
            defaults; the merged result becomes this instance's defaults.
        """
        self.memory = Memory(cache_dir, verbose=verbose)

        defaults = {
            "n_components": 2,
            "perplexity": 30,
            "learning_rate": 200,
            "max_iter": 1000,
            "random_state": 42,
        }
        if custom_tsne_params:
            defaults.update(custom_tsne_params)

        self.default_tsne_params = defaults
        # Working copy, so later tweaks never alter the stored defaults.
        self.tsne_params = dict(defaults)

    def set_tsne_params(self, **params) -> None:
        """
        Override individual t-SNE parameters for subsequent computations.

        Parameters:
        **params: Keyword arguments merged into the current t-SNE parameters.
        """
        for name, value in params.items():
            self.tsne_params[name] = value

    def reset_tsne_params(self) -> None:
        """
        Restore the t-SNE parameters to this instance's defaults.
        """
        self.tsne_params = dict(self.default_tsne_params)

    def _compute_tsne(self, X: np.ndarray) -> np.ndarray:
        """
        Run t-SNE on X with the current parameters, bypassing the cache.

        Parameters:
        X (np.ndarray): High-dimensional data points.

        Returns:
        np.ndarray: The t-SNE embedding of the data (2 components with the
        built-in defaults).
        """
        model = TSNE(**self.tsne_params)
        embedding = model.fit_transform(X)
        return embedding

    def compute_tsne(self, X: np.ndarray, cache: bool = True) -> np.ndarray:
        """
        Compute the t-SNE embedding, going through the cache when requested.

        Parameters:
        X (np.ndarray): High-dimensional data points.
        cache (bool): If True, route the call through the cached function;
            otherwise compute directly.

        Returns:
        np.ndarray: The t-SNE embedding of the data (2 components with the
        built-in defaults).
        """
        # The conditional expression keeps the cache wrapper from being built
        # unless caching was actually requested.
        runner = self.cache if cache else self._compute_tsne
        return runner(X)

    @property
    def cache(self) -> Any:
        """
        Cached variant of the direct t-SNE computation.

        Returns:
        Callable: ``_compute_tsne`` wrapped by this instance's joblib memory.
        """
        cached_fn = self.memory.cache(self._compute_tsne)
        return cached_fn

    def clear_cache(self) -> None:
        """
        Clear everything stored by this instance's joblib memory.
        """
        self.memory.clear()
|