synkit 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. synkit/Chem/Fingerprint/__init__.py +0 -0
  2. synkit/Chem/Fingerprint/fp_calculator.py +122 -0
  3. synkit/Chem/Fingerprint/smiles_featurizer.py +185 -0
  4. synkit/Chem/Fingerprint/transformation_fp.py +79 -0
  5. synkit/Chem/Molecule/__init__.py +0 -0
  6. synkit/Chem/Molecule/standardize.py +137 -0
  7. synkit/Chem/Reaction/__init__.py +0 -0
  8. synkit/Chem/Reaction/balance_check.py +162 -0
  9. synkit/Chem/Reaction/cleanning.py +59 -0
  10. synkit/Chem/Reaction/deionize.py +289 -0
  11. synkit/Chem/Reaction/neutralize.py +256 -0
  12. synkit/Chem/Reaction/reagent.py +102 -0
  13. synkit/Chem/Reaction/standardize.py +157 -0
  14. synkit/Chem/Reaction/tautomerize.py +168 -0
  15. synkit/Graph/Cluster/__init__.py +0 -0
  16. synkit/Graph/Cluster/morphism.py +83 -0
  17. synkit/Graph/Feature/__init__.py +0 -0
  18. synkit/Graph/Feature/graph_descriptors.py +325 -0
  19. synkit/Graph/Feature/graph_fps.py +97 -0
  20. synkit/Graph/Feature/graph_signature.py +236 -0
  21. synkit/Graph/Feature/hash_fps.py +130 -0
  22. synkit/Graph/Feature/morgan_fps.py +87 -0
  23. synkit/Graph/Feature/path_fps.py +82 -0
  24. synkit/Graph/__init.py +0 -0
  25. synkit/IO/__init__.py +0 -0
  26. synkit/IO/chem_converter.py +231 -0
  27. synkit/IO/data_io.py +277 -0
  28. synkit/IO/data_process.py +49 -0
  29. synkit/IO/debug.py +78 -0
  30. synkit/IO/dg_to_gml.py +124 -0
  31. synkit/IO/gml_to_nx.py +119 -0
  32. synkit/IO/graph_to_mol.py +110 -0
  33. synkit/IO/mol_to_graph.py +282 -0
  34. synkit/IO/nx_to_gml.py +200 -0
  35. synkit/IO/parse_rule.py +172 -0
  36. synkit/IO/smiles_to_id.py +119 -0
  37. synkit/ITS/_misc.py +280 -0
  38. synkit/ITS/aam_validator.py +254 -0
  39. synkit/ITS/its_builder.py +94 -0
  40. synkit/ITS/its_construction.py +213 -0
  41. synkit/ITS/normalize_aam.py +183 -0
  42. synkit/ITS/partial_expand.py +170 -0
  43. synkit/Reactor/__init__.py +0 -0
  44. synkit/Reactor/core_engine.py +164 -0
  45. synkit/Reactor/inference.py +73 -0
  46. synkit/Reactor/multi_step.py +227 -0
  47. synkit/Reactor/multi_step_aam.py +82 -0
  48. synkit/Reactor/reagent.py +95 -0
  49. synkit/Reactor/rule_apply.py +81 -0
  50. synkit/Vis/__init__.py +0 -0
  51. synkit/Vis/chemical_graph_visualizer.py +378 -0
  52. synkit/Vis/chemical_reaction_visualizer.py +133 -0
  53. synkit/Vis/chemical_space.py +83 -0
  54. synkit/Vis/embedding.py +92 -0
  55. synkit/Vis/graph_visualizer.py +286 -0
  56. synkit/Vis/pdf_writer.py +143 -0
  57. synkit/Vis/rsmi_to_fig.py +169 -0
  58. synkit/__init__.py +0 -0
  59. synkit/_misc.py +181 -0
  60. synkit-0.0.1.dist-info/METADATA +148 -0
  61. synkit-0.0.1.dist-info/RECORD +63 -0
  62. synkit-0.0.1.dist-info/WHEEL +4 -0
  63. synkit-0.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,378 @@
1
+ import random
2
+ import networkx as nx
3
+ import matplotlib.pyplot as plt
4
+ import seaborn as sns
5
+ from typing import Optional, Dict
6
+ import logging
7
+
8
# Default element -> hex colour palette used to colour graph nodes.
# Keys are element symbols; values are "#RRGGBB" strings. The trailing
# comments describe the colour each hex code actually encodes.
color_scheme = {
    "H": "#FFFFFF",  # White
    "C": "#909090",  # Gray
    "N": "#3050F8",  # Blue
    "O": "#FF0D0D",  # Red
    "F": "#90E050",  # Light green
    "Cl": "#1FF01F",  # Bright green
    "Br": "#A62929",  # Dark red/brown
    "I": "#940094",  # Purple
    "P": "#FF8000",  # Orange
    "S": "#FFFF30",  # Yellow
    "Na": "#E0E0E0",  # Light gray
    "K": "#8F40D4",  # Purple
    "Ca": "#3DFF00",  # Bright green
    "Mg": "#8AFF00",  # Yellow-green
    "Fe": "#B7410E",  # Rust red
    "Zn": "#7D80B0",  # Slate blue
    "Cu": "#C88033",  # Copper
    "Ag": "#C0C0C0",  # Silver
    "Au": "#FFD123",  # Gold yellow
    "Hg": "#B8B8D0",  # Bluish silver gray
    "Pb": "#575961",  # Dark gray
    "Al": "#BFA6A6",  # Pinkish gray
    "Si": "#F0C8A0",  # Light tan
    "B": "#FFA1A1",  # Pink
    "As": "#BD80E3",  # Light violet
    "Sb": "#9E63B5",  # Purple
    "Se": "#FFA100",  # Orange
    "Te": "#D47A00",  # Dark orange
    "Cd": "#FFD98F",  # Pale orange
    "Ti": "#BFC2C7",  # Light gray
    "V": "#A6A6AB",  # Gray
    "Cr": "#8A99C7",  # Steel blue
    "Mn": "#9C7AC7",  # Violet
    "Co": "#FF7A00",  # Orange
    "Ni": "#4DFF4D",  # Light green
}
45
+
46
+
47
class ChemicalGraphVisualizer:
    """
    Draw molecular and ITS graphs stored as NetworkX graphs with Matplotlib.

    Every node is expected to carry an "element" attribute; it selects the
    node colour and, when labels are enabled, the node label.
    """

    def __init__(
        self,
        seed: Optional[int] = None,
        element_colors: Optional[Dict[str, str]] = color_scheme,
    ):
        """
        Initialize the visualizer with optional seed and color scheme.

        Parameters:
        seed (int, optional): Seed for the layout and the global ``random``
        module, for reproducibility.
        element_colors (dict, optional): Dictionary mapping element symbols to
        color codes. Passing None explicitly selects a small built-in palette
        (H, C, N, O, F, Cl).
        """
        if element_colors is None:
            self.element_colors = {
                "H": "#FFFFFF",  # White
                "C": "#909090",  # Gray
                "N": "#3050F8",  # Blue
                "O": "#FF0D0D",  # Red
                "F": "#90E050",  # Green
                "Cl": "#1FF01F",  # Green
            }
        else:
            # Copy the mapping: the default argument is the shared module-level
            # palette, and editing one instance's colours must not leak into
            # the global dict or into other instances.
            self.element_colors = dict(element_colors)
        self.seed = seed

    def _layout(self, G: nx.Graph):
        """Seed the RNGs (when a seed is set) and return a spring layout."""
        if self.seed is not None:
            random.seed(self.seed)
        return nx.spring_layout(G, seed=self.seed)

    def _node_colors(self, G: nx.Graph):
        """Return one colour per node; unknown elements fall back to black."""
        return [
            self.element_colors.get(G.nodes[node]["element"], "#000000")
            for node in G.nodes()
        ]

    def _draw(
        self,
        G: nx.Graph,
        pos,
        ax,
        node_size: int,
        show_node_labels: bool,
        node_label_font_size: int,
        **extra,
    ) -> None:
        """Draw nodes/edges on ``ax``, labelling nodes by element if asked."""
        draw_kwargs = dict(
            ax=ax,
            with_labels=show_node_labels,
            node_color=self._node_colors(G),
            node_size=node_size,
            **extra,
        )
        if show_node_labels:
            draw_kwargs["labels"] = {
                node: G.nodes[node]["element"] for node in G.nodes()
            }
            draw_kwargs["font_size"] = node_label_font_size
        nx.draw(G, pos, **draw_kwargs)

    @staticmethod
    def _finish_figure(fig, save_path, display_inline: bool, log: bool) -> None:
        """Tighten the layout, optionally save, then show or close ``fig``."""
        fig.tight_layout()
        if save_path is not None:
            # Save through the figure object rather than the implicit
            # "current figure" so the right figure is always written.
            fig.savefig(save_path, dpi=600)
            if log:
                logging.info(f"Figure saved to {save_path}")
        if display_inline:
            plt.show()
        else:
            plt.close(fig)

    def graph_vis(
        self,
        G: nx.Graph,
        node_size: int = 100,
        visualize_edge_weight: bool = False,
        edge_font_size: int = 10,
        show_node_labels: bool = False,
        node_label_font_size: int = 12,
        ax: Optional[plt.Axes] = None,
    ) -> None:
        """
        Visualize a NetworkX graph with standard representation.

        Parameters:
        G (nx.Graph): The graph to visualize; nodes need an "element" attribute.
        node_size (int): The size of the nodes.
        visualize_edge_weight (bool): Whether to label each edge with its
        "order" attribute.
        edge_font_size (int): Font size for edge labels.
        show_node_labels (bool): Whether to show element labels on the nodes.
        node_label_font_size (int): Font size for node labels.
        ax (plt.Axes, optional): Axes to draw on; defaults to the current axes.
        """
        pos = self._layout(G)
        if ax is None:
            ax = plt.gca()

        self._draw(G, pos, ax, node_size, show_node_labels, node_label_font_size)

        if visualize_edge_weight:
            edge_labels = {(u, v): G.edges[u, v]["order"] for u, v in G.edges()}
            nx.draw_networkx_edge_labels(
                G, pos, ax=ax, edge_labels=edge_labels, font_size=edge_font_size
            )

    def its_vis(
        self,
        G: nx.Graph,
        node_size: int = 100,
        show_node_labels: bool = False,
        node_label_font_size: int = 12,
        ax: Optional[plt.Axes] = None,
    ) -> None:
        """
        Visualize a NetworkX graph with edge colors indicating bond changes.

        Each edge is coloured from its "standard_order" attribute (0 when
        absent): 0 -> black (unchanged bond), negative -> blue (increasing
        bond), positive -> red (breaking bond).

        Parameters:
        G (nx.Graph): The graph to visualize; nodes need an "element" attribute.
        node_size (int): The size of the nodes.
        show_node_labels (bool): Whether to show element labels on the nodes.
        node_label_font_size (int): Font size for node labels.
        ax (plt.Axes, optional): Axes to draw on; defaults to the current axes.
        """
        pos = self._layout(G)

        edge_colors = []
        for _, _, data in G.edges(data=True):
            order = data.get("standard_order", 0)
            if order == 0:
                edge_colors.append("black")
            elif order < 0:
                edge_colors.append("blue")
            else:
                edge_colors.append("red")

        if ax is None:
            ax = plt.gca()

        self._draw(
            G,
            pos,
            ax,
            node_size,
            show_node_labels,
            node_label_font_size,
            edge_color=edge_colors,
        )

    def vis_three_graph(
        self,
        graph_tuple,
        figsize=(15, 5),
        left_graph_title="Reactants",
        k_graph_title="ITS Graph",
        right_graph_title="Products",
        show_node_labels=True,
        title_fontsize=24,
        title_weight="bold",
        save_path=None,
        display_inline=False,
        log=False,
    ):
        """
        Visualize reactants, ITS graph, and products in one figure.

        Parameters:
        graph_tuple (tuple): Tuple of NetworkX graphs in the order
        (reactants, products, ITS graph). The ITS graph is drawn in the middle
        panel even though it is the last tuple element.
        figsize (tuple): Figure size in inches (width, height).
        left_graph_title (str): Title for the left subplot.
        k_graph_title (str): Title for the middle subplot.
        right_graph_title (str): Title for the right subplot.
        show_node_labels (bool): If True, show node labels on the graphs.
        title_fontsize (int): Font size for subplot titles.
        title_weight (str): Font weight for subplot titles.
        save_path (str, optional): Path to save the figure to file.
        display_inline (bool): If True, show the figure; otherwise close it.
        log (bool): If True, enable logging of function progress.

        Returns:
        matplotlib.figure.Figure: The created figure.

        Raises:
        RuntimeError: If unpacking or drawing fails; the original exception is
        chained as the cause.
        """
        if log:
            logging.basicConfig(level=logging.INFO)

        fig = None
        try:
            reactants_graph, products_graph, its_graph = graph_tuple

            fig, axs = plt.subplots(1, 3, figsize=figsize)

            self.graph_vis(
                reactants_graph, ax=axs[0], show_node_labels=show_node_labels
            )
            self.its_vis(its_graph, ax=axs[1], show_node_labels=show_node_labels)
            self.graph_vis(products_graph, ax=axs[2], show_node_labels=show_node_labels)

            panel_titles = (left_graph_title, k_graph_title, right_graph_title)
            for ax, title in zip(axs, panel_titles):
                ax.set_title(title, fontsize=title_fontsize, weight=title_weight)

            self._finish_figure(fig, save_path, display_inline, log)
            return fig
        except Exception as e:
            # Do not leave a half-drawn figure registered with pyplot.
            if fig is not None:
                plt.close(fig)
            if log:
                logging.error("Failed to visualize graphs: ", exc_info=True)
            raise RuntimeError(f"Error in graph visualization: {e}") from e

    def visualize_all(
        self,
        graph_tuple_row1,
        graph_tuple_row2,
        figsize=(15, 10),
        titles_row1=("A. Reactant Graph", "B. ITS Graph", "C Products"),
        titles_row2=("D. L Graph", "E. K Graph", "D. R Graph"),
        show_node_labels=True,
        show_grid=True,
        grid_style="--",
        title_fontsize=24,
        title_weight="bold",
        save_path=None,
        display_inline=False,
        log=False,
    ):
        """
        Visualize two rows of graphs, each with three graphs, optionally displaying a
        grid.

        Parameters:
        graph_tuple_row1 (tuple): NetworkX graphs for the first row, in the
        order (reactants, products, ITS graph); the ITS graph is drawn in the
        middle panel.
        graph_tuple_row2 (tuple): NetworkX graphs for the second row, in the
        order (L, R, K); the K graph is drawn in the middle panel.
        figsize (tuple): Figure size in inches (width, height).
        titles_row1 (tuple): Titles for the first row subplots.
        titles_row2 (tuple): Titles for the second row subplots.
        show_node_labels (bool): If True, show node labels on the graphs.
        show_grid (bool): If True, request grid lines on the plots.
        grid_style (str): Style of the grid lines.
        title_fontsize (int): Font size for subplot titles.
        title_weight (str): Font weight for subplot titles.
        save_path (str, optional): Path to save the figure to file.
        display_inline (bool): If True, show the figure; otherwise close it.
        log (bool): If True, enable logging of function progress.

        Returns:
        matplotlib.figure.Figure: The created figure.

        Raises:
        RuntimeError: If unpacking or drawing fails; the original exception is
        chained as the cause.
        """
        # NOTE(review): the default "D. R Graph" repeats the "D." prefix and
        # "C Products" lacks a dot; defaults are kept as-is for compatibility.
        if log:
            logging.basicConfig(level=logging.INFO)

        fig = None
        try:
            # This changes the global Seaborn/Matplotlib theme, not just this
            # figure.
            sns.set_theme(style="darkgrid")
            reactants_graph, products_graph, its_graph = graph_tuple_row1
            l_graph, r_graph, k_graph = graph_tuple_row2

            fig, axs = plt.subplots(2, 3, figsize=figsize)

            # First row: reactants | ITS | products
            self.graph_vis(
                reactants_graph, ax=axs[0, 0], show_node_labels=show_node_labels
            )
            self.its_vis(its_graph, ax=axs[0, 1], show_node_labels=show_node_labels)
            self.graph_vis(
                products_graph, ax=axs[0, 2], show_node_labels=show_node_labels
            )

            # Second row: L | K | R
            self.graph_vis(l_graph, ax=axs[1, 0], show_node_labels=show_node_labels)
            self.its_vis(k_graph, ax=axs[1, 1], show_node_labels=show_node_labels)
            self.graph_vis(r_graph, ax=axs[1, 2], show_node_labels=show_node_labels)

            # NOTE(review): nx.draw appears to switch the axes off, so the grid
            # requested here may not be visible - confirm the intended look.
            for row_axes, row_titles in ((axs[0], titles_row1), (axs[1], titles_row2)):
                for ax, title in zip(row_axes, row_titles):
                    ax.set_title(title, fontsize=title_fontsize, weight=title_weight)
                    if show_grid:
                        ax.grid(True, linestyle=grid_style, which="both")

            self._finish_figure(fig, save_path, display_inline, log)
            return fig
        except Exception as e:
            # Do not leave a half-drawn figure registered with pyplot.
            if fig is not None:
                plt.close(fig)
            if log:
                logging.error("Failed to visualize graphs: ", exc_info=True)
            raise RuntimeError(f"Error in graph visualization: {e}") from e
@@ -0,0 +1,133 @@
1
+ from typing import List, Dict
2
+ from IPython.display import display, HTML, SVG
3
+ from rdkit.Chem.Draw import rdMolDraw2D
4
+ from rdkit.Chem import rdChemReactions
5
+
6
+
7
class ChemicalReactionVisualizer:
    """Render reaction SMILES as SVG and lay several of them out in an HTML table."""

    @staticmethod
    def create_html_table_with_svgs(
        svg_list: List[str],
        titles: List[str],
        num_cols: int = 2,
        orientation: str = "vertical",
        title_size: int = 16,
    ) -> HTML:
        """
        Build an HTML table that shows each SVG under its title, arranged
        like a grid of subplots.

        Parameters:
        - svg_list (List[str]): SVG markup strings, one per cell.
        - titles (List[str]): Title for each SVG, matched by position.
        - num_cols (int): Number of columns when orientation is 'vertical';
        number of rows for any other orientation value.
        - orientation (str): 'vertical' fills the table row by row; any other
        value places item i in row i % num_cols.
        - title_size (int): Title font size in pixels.

        Returns:
        - HTML: IPython HTML object wrapping the table markup.
        """
        total = len(svg_list)
        title_style = f"font-size:{title_size}px;"

        def render_cell(k: int) -> str:
            return f"<td style='border:1px solid black; padding:10px'><b style='{title_style}'>{titles[k]}</b><br>{svg_list[k]}</td>"

        # Work out which item indices land in each table row.
        if orientation == "vertical":
            row_indices = [
                [start + offset for offset in range(num_cols) if start + offset < total]
                for start in range(0, total, num_cols)
            ]
        else:
            row_indices = [
                list(range(first, total, num_cols)) for first in range(num_cols)
            ]

        pieces = ["<table>"]
        for indices in row_indices:
            pieces.append("<tr>")
            pieces.extend(render_cell(k) for k in indices)
            pieces.append("</tr>")
        pieces.append("</table>")
        return HTML("".join(pieces))

    @staticmethod
    def visualize_reaction(
        reaction_smiles: str,
        img_size: tuple = (600, 200),
        highlight_by_reactant: bool = True,
        mol_scale: float = 0.9,
        bond_line_width: float = 2.0,
        atom_label_font_size: int = 12,
        padding: float = 0.01,
        show_atom_map: bool = False,
    ) -> SVG:
        """
        Draw one reaction with RDKit and return it as an SVG object.

        Parameters:
        - reaction_smiles (str): Reaction SMILES to parse.
        - img_size (tuple): Canvas size as (width, height).
        - highlight_by_reactant (bool): Forwarded to RDKit's DrawReaction as
        highlightByReactant.
        - mol_scale (float): Value assigned to the draw option `scale`.
        - bond_line_width (float): Value assigned to `bondLineWidth`.
        - atom_label_font_size (int): Value assigned to `atomLabelFontSize`.
        - padding (float): Value assigned to `padding`.
        - show_atom_map (bool): If True, atoms that have a "molAtomMapNumber"
        property get that value as their "atomLabel".

        Returns:
        - SVG: IPython SVG object holding the drawing text.
        """
        rxn = rdChemReactions.ReactionFromSmarts(reaction_smiles, useSmiles=True)
        rdChemReactions.PreprocessReaction(rxn)

        if show_atom_map:
            molecules = list(rxn.GetReactants()) + list(rxn.GetProducts())
            for molecule in molecules:
                for atom in molecule.GetAtoms():
                    if not atom.HasProp("molAtomMapNumber"):
                        continue
                    atom.SetProp("atomLabel", atom.GetProp("molAtomMapNumber"))

        canvas = rdMolDraw2D.MolDraw2DSVG(img_size[0], img_size[1])
        options = canvas.drawOptions()
        options.scale = mol_scale
        options.bondLineWidth = bond_line_width
        options.atomLabelFontSize = atom_label_font_size
        options.padding = padding

        canvas.DrawReaction(rxn, highlightByReactant=highlight_by_reactant)
        canvas.FinishDrawing()
        return SVG(canvas.GetDrawingText())

    @staticmethod
    def visualize_and_compare_reactions(
        input_dict: Dict[str, str],
        id_col: str = "R-id",
        img_size: tuple = (1000, 600),
        num_cols: int = 2,
        orientation: str = "vertical",
        show_atom_map: bool = False,
    ) -> None:
        """
        Draw every reaction in a dictionary and display them together in an
        HTML table.

        Parameters:
        - input_dict (Dict[str, str]): Maps a title to a reaction SMILES.
        - id_col (str): Key to skip (it is not drawn).
        - img_size (tuple): Size of each reaction image as (width, height).
        - num_cols (int): Forwarded to create_html_table_with_svgs.
        - orientation (str): Forwarded to create_html_table_with_svgs.
        - show_atom_map (bool): Forwarded to visualize_reaction.
        """
        drawings = []
        headings = []
        for name, rsmi in input_dict.items():
            if name == id_col:
                continue
            rendered = ChemicalReactionVisualizer.visualize_reaction(
                rsmi,
                img_size=img_size,
                highlight_by_reactant=True,
                show_atom_map=show_atom_map,
            )
            drawings.append(rendered.data)
            headings.append(name)

        table = ChemicalReactionVisualizer.create_html_table_with_svgs(
            drawings, headings, num_cols, orientation
        )
        display(table)
@@ -0,0 +1,83 @@
1
+ import pandas as pd
2
+ import matplotlib.pyplot as plt
3
+ import matplotlib.patches as mpatches
4
+
5
+ plt.rc("text", usetex=True) # Enable LaTeX rendering
6
+ plt.rc("font", family="serif") # Optional: use serif font
7
+
8
+
9
def scatter_plot(
    data_train,
    data_test,
    size_train=10,
    size_test=10,
    title=None,
    ax=None,
    xlabel="Coordinate 1",
    ylabel="Coordinate 2",
):
    """
    Scatter train and test points on the same axes in two colours.

    The second and third columns (by position) of the combined data are used
    as the x and y coordinates. The input frames are not modified.

    Parameters:
    data_train (pd.DataFrame): Training points; at least three columns.
    data_test (pd.DataFrame): Test points; at least three columns.
    size_train (int): Marker size for training points.
    size_test (int): Marker size for test points.
    title (str, optional): Axes title; omitted when falsy.
    ax (matplotlib Axes, optional): Axes to draw on; a new 12x8 figure is
    created when None.
    xlabel (str): Label for the x axis.
    ylabel (str): Label for the y axis.

    Returns:
    tuple: (ax, handles, labels) where handles and labels come from
    ``ax.get_legend_handles_labels()``.

    Raises:
    ValueError: If either frame is empty or has fewer than three columns.
    """
    if data_train.empty or data_test.empty:
        raise ValueError("Input data frames cannot be empty.")

    if data_train.columns.size < 3 or data_test.columns.size < 3:
        raise ValueError("Data frames must have at least three columns.")

    # Tag copies with their origin; assign() returns new frames, so the
    # caller's DataFrames do not silently gain a "Type" column.
    data_combined = pd.concat(
        [data_train.assign(Type="Train"), data_test.assign(Type="Test")]
    )

    if ax is None:
        fig, ax = plt.subplots(figsize=(12, 8))

    palette = {
        "Train": "deepskyblue",
        "Test": "magenta",
    }
    marker_sizes = {"Train": size_train, "Test": size_test}

    for dtype, color in palette.items():
        subset = data_combined[data_combined["Type"] == dtype]
        ax.scatter(
            subset[subset.columns[1]],
            subset[subset.columns[2]],
            color=color,
            label=dtype,
            s=marker_sizes[dtype],
            alpha=0.1,
            edgecolor="none",
        )

    if title:
        ax.set_title(rf"{title}", fontsize=24, fontweight="bold")

    ax.set_xlabel(xlabel, fontsize=18)
    ax.set_ylabel(ylabel, fontsize=18)

    ax.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax.set_axisbelow(True)

    # Hand the legend entries back so the caller can build a shared legend.
    handles, labels = ax.get_legend_handles_labels()
    return ax, handles, labels
74
+
75
+
76
+ # Define a function that modifies the legend handles to full opacity for better visibility in the legend
77
def adjust_legend_handles(handles, colors):
    """
    Build opaque legend patches that mirror the given handles.

    Handles and colors are paired by position (extra items in the longer
    sequence are ignored). Each result is a new ``mpatches.Patch`` with the
    paired color and the original handle's label.
    """
    return [
        mpatches.Patch(color=patch_color, label=original.get_label())
        for original, patch_color in zip(handles, colors)
    ]
@@ -0,0 +1,92 @@
1
+ from typing import Any, Dict, Optional
2
+ import numpy as np
3
+ from sklearn.manifold import TSNE
4
+ from joblib import Memory
5
+
6
+
7
class Embedding:
    """t-SNE embedding helper with adjustable parameters and joblib caching."""

    def __init__(
        self,
        cache_dir: str = "./cachedir",
        verbose: int = 0,
        custom_tsne_params: Optional[Dict] = None,
    ) -> None:
        """
        Set up the joblib cache and the default/current t-SNE parameters.

        Parameters:
        cache_dir (str): Directory passed to joblib's Memory.
        verbose (int): Verbosity level passed to joblib's Memory.
        custom_tsne_params (Dict, optional): Overrides merged into the
        built-in defaults; the merged result becomes this instance's defaults.
        """
        self.memory = Memory(cache_dir, verbose=verbose)

        defaults = {
            "n_components": 2,
            "perplexity": 30,
            "learning_rate": 200,
            "max_iter": 1000,
            "random_state": 42,
        }
        if custom_tsne_params:
            defaults.update(custom_tsne_params)

        self.default_tsne_params = defaults
        # Keep the working parameters in a separate dict so later changes do
        # not alter the defaults used by reset_tsne_params.
        self.tsne_params = dict(defaults)

    def set_tsne_params(self, **params) -> None:
        """
        Override or add current t-SNE parameters.

        Parameters:
        **params: Keyword arguments merged into the current parameters.
        """
        self.tsne_params.update(params)

    def reset_tsne_params(self) -> None:
        """
        Restore the current t-SNE parameters from this instance's defaults.
        """
        self.tsne_params = dict(self.default_tsne_params)

    def _compute_tsne(self, X: np.ndarray) -> np.ndarray:
        """
        Fit t-SNE with the current parameters and return the embedding.

        Parameters:
        X (np.ndarray): High-dimensional data points.

        Returns:
        np.ndarray: The result of TSNE.fit_transform on X.
        """
        model = TSNE(**self.tsne_params)
        embedding = model.fit_transform(X)
        return embedding

    def compute_tsne(self, X: np.ndarray, cache: bool = True) -> np.ndarray:
        """
        Return the t-SNE embedding of X, going through the cache if requested.

        Parameters:
        X (np.ndarray): High-dimensional data points.
        cache (bool): If True, call the joblib-cached function; otherwise
        compute directly.

        Returns:
        np.ndarray: The t-SNE embedding of the data.
        """
        if not cache:
            return self._compute_tsne(X)
        return self.cache(X)

    @property
    def cache(self) -> Any:
        """
        The joblib-cached wrapper around _compute_tsne.

        Returns:
        Callable: Result of ``self.memory.cache(self._compute_tsne)``.
        """
        cached_fn = self.memory.cache(self._compute_tsne)
        return cached_fn

    def clear_cache(self) -> None:
        """
        Clear the joblib Memory store.
        """
        self.memory.clear()