synkit 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. synkit/Chem/Fingerprint/__init__.py +0 -0
  2. synkit/Chem/Fingerprint/fp_calculator.py +122 -0
  3. synkit/Chem/Fingerprint/smiles_featurizer.py +185 -0
  4. synkit/Chem/Fingerprint/transformation_fp.py +79 -0
  5. synkit/Chem/Molecule/__init__.py +0 -0
  6. synkit/Chem/Molecule/standardize.py +137 -0
  7. synkit/Chem/Reaction/__init__.py +0 -0
  8. synkit/Chem/Reaction/balance_check.py +162 -0
  9. synkit/Chem/Reaction/cleanning.py +59 -0
  10. synkit/Chem/Reaction/deionize.py +289 -0
  11. synkit/Chem/Reaction/neutralize.py +256 -0
  12. synkit/Chem/Reaction/reagent.py +102 -0
  13. synkit/Chem/Reaction/standardize.py +157 -0
  14. synkit/Chem/Reaction/tautomerize.py +168 -0
  15. synkit/Graph/Cluster/__init__.py +0 -0
  16. synkit/Graph/Cluster/morphism.py +83 -0
  17. synkit/Graph/Feature/__init__.py +0 -0
  18. synkit/Graph/Feature/graph_descriptors.py +325 -0
  19. synkit/Graph/Feature/graph_fps.py +97 -0
  20. synkit/Graph/Feature/graph_signature.py +236 -0
  21. synkit/Graph/Feature/hash_fps.py +130 -0
  22. synkit/Graph/Feature/morgan_fps.py +87 -0
  23. synkit/Graph/Feature/path_fps.py +82 -0
  24. synkit/Graph/__init.py +0 -0
  25. synkit/IO/__init__.py +0 -0
  26. synkit/IO/chem_converter.py +231 -0
  27. synkit/IO/data_io.py +277 -0
  28. synkit/IO/data_process.py +49 -0
  29. synkit/IO/debug.py +78 -0
  30. synkit/IO/dg_to_gml.py +124 -0
  31. synkit/IO/gml_to_nx.py +119 -0
  32. synkit/IO/graph_to_mol.py +110 -0
  33. synkit/IO/mol_to_graph.py +282 -0
  34. synkit/IO/nx_to_gml.py +200 -0
  35. synkit/IO/parse_rule.py +172 -0
  36. synkit/IO/smiles_to_id.py +119 -0
  37. synkit/ITS/_misc.py +280 -0
  38. synkit/ITS/aam_validator.py +254 -0
  39. synkit/ITS/its_builder.py +94 -0
  40. synkit/ITS/its_construction.py +213 -0
  41. synkit/ITS/normalize_aam.py +183 -0
  42. synkit/ITS/partial_expand.py +170 -0
  43. synkit/Reactor/__init__.py +0 -0
  44. synkit/Reactor/core_engine.py +164 -0
  45. synkit/Reactor/inference.py +73 -0
  46. synkit/Reactor/multi_step.py +227 -0
  47. synkit/Reactor/multi_step_aam.py +82 -0
  48. synkit/Reactor/reagent.py +95 -0
  49. synkit/Reactor/rule_apply.py +81 -0
  50. synkit/Vis/__init__.py +0 -0
  51. synkit/Vis/chemical_graph_visualizer.py +378 -0
  52. synkit/Vis/chemical_reaction_visualizer.py +133 -0
  53. synkit/Vis/chemical_space.py +83 -0
  54. synkit/Vis/embedding.py +92 -0
  55. synkit/Vis/graph_visualizer.py +286 -0
  56. synkit/Vis/pdf_writer.py +143 -0
  57. synkit/Vis/rsmi_to_fig.py +169 -0
  58. synkit/__init__.py +0 -0
  59. synkit/_misc.py +181 -0
  60. synkit-0.0.1.dist-info/METADATA +148 -0
  61. synkit-0.0.1.dist-info/RECORD +63 -0
  62. synkit-0.0.1.dist-info/WHEEL +4 -0
  63. synkit-0.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,286 @@
1
+ """
2
+ This module comprises several functions adapted from the work of Klaus Weinbauer.
3
+ The original code can be found at his GitHub repository: https://github.com/klausweinbauer/FGUtils.
4
+ Adaptations were made to enhance functionality and integrate with other system components.
5
+ """
6
+
7
+ import networkx as nx
8
+ from rdkit import Chem
9
+ from rdkit.Chem import rdDepictor
10
+
11
+ import matplotlib.pyplot as plt
12
+ from typing import Dict, Optional
13
+
14
+ from synkit.IO.graph_to_mol import GraphToMol
15
+
16
+
17
class GraphVisualizer:
    """
    Draw molecular graphs and imaginary transition state (ITS) graphs onto
    Matplotlib axes using NetworkX drawing functions.

    Node coordinates come either from an RDKit 2D depiction of the graph
    (built through ``GraphToMol``) or from a NetworkX spring layout.
    """

    # Colors produced by per-edge coloring in ``plot_its``; with ``rule=True``
    # edges carrying any of these colors are dropped before drawing.
    _RULE_COLORS = ("red", "green", "black")

    def __init__(
        self,
        node_attributes: Optional[Dict[str, str]] = None,
        edge_attributes: Optional[Dict[str, str]] = None,
    ):
        """
        Parameters:
        - node_attributes (Optional[Dict[str, str]]): Node attribute mapping passed
          to ``GraphToMol``. Defaults to
          ``{"element": "element", "charge": "charge", "atom_map": "atom_map"}``.
        - edge_attributes (Optional[Dict[str, str]]): Edge attribute mapping passed
          to ``GraphToMol``. Defaults to ``{"order": "order"}``.
        """
        # Build fresh dicts per instance: dict literals used as default
        # arguments would be one object shared (and mutable) by all instances.
        if node_attributes is None:
            node_attributes = {
                "element": "element",
                "charge": "charge",
                "atom_map": "atom_map",
            }
        if edge_attributes is None:
            edge_attributes = {"order": "order"}
        self.node_attributes = node_attributes
        self.edge_attributes = edge_attributes

    def _get_its_as_mol(self, its: nx.Graph) -> Optional[Chem.Mol]:
        """
        Convert an imaginary transition state (ITS) graph into an RDKit molecule
        that is used only to compute a 2D layout.

        Parameters:
        - its (nx.Graph): The graph to convert. It is copied, not modified.

        Returns:
        - Chem.Mol or None: The RDKit molecule if conversion is successful, None otherwise.
        """
        _its = its.copy()
        for n in _its.nodes():
            # Use the node id as atom map number so depiction coordinates can be
            # mapped back onto the graph nodes.
            _its.nodes[n]["atom_map"] = n
        for u, v in _its.edges():
            # ITS edges hold a pair of bond orders, which a molecule cannot
            # represent; a single bond is sufficient for layout purposes.
            _its[u][v]["order"] = 1
        return GraphToMol(self.node_attributes, self.edge_attributes).graph_to_mol(
            _its, False, False
        )

    @staticmethod
    def _positions_from_mol(mol: Chem.Mol) -> dict:
        """
        Compute 2D coordinates for ``mol`` and return them keyed by atom-map number.

        Parameters:
        - mol (Chem.Mol): Molecule to depict (a 2D conformer is computed in place).

        Returns:
        - dict: ``{atom_map_number: [x, y]}``.
        """
        rdDepictor.Compute2DCoords(mol)
        conformer = mol.GetConformer()
        positions = {}
        for atom in mol.GetAtoms():
            apos = conformer.GetAtomPosition(atom.GetIdx())
            positions[atom.GetAtomMapNum()] = [apos.x, apos.y]
        return positions

    @staticmethod
    def _edge_colors(its: nx.Graph, standard_order_key: str) -> list:
        """
        Return one color per edge (in ``its.edges()`` order): "red" when the
        ``standard_order_key`` value is positive, "green" when negative,
        "black" otherwise (including when the attribute is missing).
        """
        colors = []
        for _, _, data in its.edges(data=True):
            order = data.get(standard_order_key, 0)
            if order > 0:
                colors.append("red")
            elif order < 0:
                colors.append("green")
            else:
                colors.append("black")
        return colors

    def plot_its(
        self,
        its: nx.Graph,
        ax: plt.Axes,
        use_mol_coords: bool = True,
        title: Optional[str] = None,
        node_color: str = "#FFFFFF",
        node_size: int = 500,
        edge_color: str = "#000000",
        edge_weight: float = 2.0,
        show_atom_map: bool = False,
        use_edge_color: bool = False,
        symbol_key: str = "element",
        bond_key: str = "order",
        aam_key: str = "atom_map",
        standard_order_key: str = "standard_order",
        font_size: int = 12,
        rule: bool = False,
    ):
        """
        Plot an imaginary transition state (ITS) graph on a given Matplotlib axes.

        Parameters:
        - its (nx.Graph): The ITS graph. It is never modified.
        - ax (plt.Axes): The matplotlib axes to draw the graph on.
        - use_mol_coords (bool): Use molecular coordinates for node positions if True,
          else use a spring layout. The spring layout is also used if the graph
          cannot be converted to a molecule.
        - title (Optional[str]): Title for the graph. If None, no title is set.
        - node_color (str): Color code for the graph nodes.
        - node_size (int): Size of the graph nodes.
        - edge_color (str): Color for all edges when ``use_edge_color`` is False.
        - edge_weight (float): Thickness of the graph edges.
        - show_atom_map (bool): If True, displays atom mapping numbers alongside symbols.
        - use_edge_color (bool): If True, colors edges by the sign of their
          ``standard_order_key`` value (positive: red, negative: green, otherwise black).
        - symbol_key (str): Node attribute holding the element symbol.
        - bond_key (str): Edge attribute holding the (before, after) bond-order pair.
        - aam_key (str): Node attribute holding the atom mapping number.
        - standard_order_key (str): Edge attribute used for conditional coloring.
        - font_size (int): Font size for node and edge labels.
        - rule (bool): If True and ``use_edge_color`` is True, edges colored red,
          green or black are removed before drawing (on a copy of the graph).

        Returns:
        - None
        """
        bond_char = {None: "∅", 0: "∅", 1: "—", 2: "=", 3: "≡", 1.5: ":"}

        # Layout is computed on the full graph, before any edges are dropped.
        positions = self._calculate_positions(its, use_mol_coords)

        ax.axis("equal")
        ax.axis("off")
        if title:
            ax.set_title(title)

        if rule and use_edge_color:
            # Work on a copy: removing edges from `its` itself would silently
            # alter the caller's graph.
            its = its.copy()
            edges = list(its.edges())
            colors = self._edge_colors(its, standard_order_key)
            its.remove_edges_from(
                [
                    edge
                    for edge, color in zip(edges, colors)
                    if color in self._RULE_COLORS
                ]
            )
            # NOTE(review): every color produced by `_edge_colors` is red, green
            # or black, so this removes all edges; confirm whether only changed
            # (red/green) bonds were meant to be dropped.

        if use_edge_color:
            edge_colors = self._edge_colors(its, standard_order_key)
        else:
            edge_colors = edge_color

        nx.draw_networkx_edges(
            its, positions, edge_color=edge_colors, width=edge_weight, ax=ax
        )
        nx.draw_networkx_nodes(
            its, positions, node_color=node_color, node_size=node_size, ax=ax
        )

        labels = {
            n: (
                f"{d[symbol_key]} ({d.get(aam_key, '')})"
                if show_atom_map
                else f"{d[symbol_key]}"
            )
            for n, d in its.nodes(data=True)
        }
        edge_labels = self._determine_edge_labels(its, bond_char, bond_key)

        nx.draw_networkx_labels(
            its, positions, labels=labels, font_size=font_size, ax=ax
        )
        nx.draw_networkx_edge_labels(
            its, positions, edge_labels=edge_labels, font_size=font_size, ax=ax
        )

    def _calculate_positions(self, its: nx.Graph, use_mol_coords: bool) -> dict:
        """
        Return node positions for an ITS graph.

        Uses an RDKit depiction when ``use_mol_coords`` is True and the graph can
        be converted to a molecule; otherwise falls back to ``nx.spring_layout``.
        """
        if use_mol_coords:
            mol = self._get_its_as_mol(its)
            # Conversion may fail (None); fall back to a spring layout instead of
            # crashing inside RDKit.
            if mol is not None:
                return self._positions_from_mol(mol)
        return nx.spring_layout(its)

    def _determine_edge_labels(
        self, its: nx.Graph, bond_char: dict, bond_key: str
    ) -> dict:
        """
        Build edge labels for edges whose bond order changes.

        Each edge's ``bond_key`` value is read as a (before, after) pair and
        mapped through ``bond_char``; edges whose two characters are equal get
        no label. Missing values default to ``(0, 0)``, unknown orders to "∅".
        """
        edge_labels = {}
        for u, v, data in its.edges(data=True):
            bond_codes = data.get(bond_key, (0, 0))
            bc1 = bond_char.get(bond_codes[0], "∅")
            bc2 = bond_char.get(bond_codes[1], "∅")
            if bc1 != bc2:
                edge_labels[(u, v)] = f"({bc1},{bc2})"
        return edge_labels

    def plot_as_mol(
        self,
        g: nx.Graph,
        ax: plt.Axes,
        use_mol_coords: bool = True,
        node_color: str = "#FFFFFF",
        node_size: int = 500,
        edge_color: str = "#000000",
        edge_width: float = 2.0,
        label_color: str = "#000000",
        font_size: int = 12,
        show_atom_map: bool = False,
        bond_char: Dict[Optional[int], str] = None,
        symbol_key: str = "element",
        bond_key: str = "order",
        aam_key: str = "atom_map",
    ) -> None:
        """
        Plot a molecular graph on a given Matplotlib axes using either molecular
        coordinates or a NetworkX spring layout.

        Parameters:
        - g (nx.Graph): The molecular graph to be plotted.
        - ax (plt.Axes): Matplotlib axes where the graph will be plotted.
        - use_mol_coords (bool, optional): Use molecular coordinates if True, else a
          spring layout. The spring layout is also used if conversion to a molecule fails.
        - node_color (str, optional): Color code for the nodes.
        - node_size (int, optional): Size of the nodes.
        - edge_color (str, optional): Color code for the edges.
        - edge_width (float, optional): Thickness of the edges.
        - label_color (str, optional): Color for node and edge labels.
        - font_size (int, optional): Font size for node and edge labels.
        - show_atom_map (bool, optional): If True, appends the atom map number to node labels.
        - bond_char (Dict[Optional[int], str], optional): Mapping from bond order to label
          character. Defaults to ``{None: "∅", 1: "—", 2: "=", 3: "≡", 1.5: ":"}``.
        - symbol_key (str, optional): Node attribute key for element symbols.
        - bond_key (str, optional): Edge attribute key for bond orders.
        - aam_key (str, optional): Node attribute key for atom map numbers.

        Returns:
        - None
        """
        if bond_char is None:
            bond_char = {None: "∅", 1: "—", 2: "=", 3: "≡", 1.5: ":"}

        positions = None
        if use_mol_coords:
            mol = GraphToMol(self.node_attributes, self.edge_attributes).graph_to_mol(
                g, False
            )
            if mol is not None:
                # NOTE(review): positions are keyed by atom-map number, so this
                # assumes graph node ids equal the atom maps — confirm for callers.
                positions = self._positions_from_mol(mol)
        if positions is None:
            positions = nx.spring_layout(g)

        ax.axis("equal")
        ax.axis("off")

        nx.draw_networkx_edges(
            g, positions, edge_color=edge_color, width=edge_width, ax=ax
        )
        nx.draw_networkx_nodes(
            g, positions, node_color=node_color, node_size=node_size, ax=ax
        )

        # Node labels: element symbol plus a charge suffix ("+", "2+", "-", "2-").
        labels = {}
        for n, d in g.nodes(data=True):
            charge = d.get("charge", 0)
            if charge == 0:
                charge = ""
            elif charge > 0:
                charge = f"{charge}+" if charge > 1 else "+"
            else:
                charge = f"{-charge}-" if charge < -1 else "-"
            label = f"{d.get(symbol_key, '')}{charge}"
            if show_atom_map:
                label += f" ({d.get(aam_key, '')})"
            labels[n] = label

        # `.get` so edges without a bond order are labelled instead of raising.
        edge_labels = {
            (u, v): bond_char.get(d.get(bond_key), "∅")
            for u, v, d in g.edges(data=True)
        }

        nx.draw_networkx_labels(
            g,
            positions,
            labels=labels,
            font_color=label_color,
            font_size=font_size,
            ax=ax,
        )
        nx.draw_networkx_edge_labels(
            g,
            positions,
            edge_labels=edge_labels,
            font_color=label_color,
            font_size=font_size,
            ax=ax,
        )
@@ -0,0 +1,143 @@
1
+ """
2
+ This module comprises several functions adapted from the work of Klaus Weinbauer.
3
+ The original code can be found at his GitHub repository: https://github.com/klausweinbauer/FGUtils.
4
+ Adaptations were made to enhance functionality and integrate with other system components.
5
+ """
6
+
7
+ import matplotlib.pyplot as plt
8
+ from matplotlib.backends.backend_pdf import PdfPages
9
+ from typing import List, Callable, Union, Tuple, Optional
10
+ import tqdm
11
+
12
+
13
class PdfWriter:
    """
    A utility class to create PDF reports with plots from a list of figures or
    dynamically generated plots.

    The writer can be used as a context manager, which closes the PDF on exit::

        with PdfWriter("report.pdf", plot_fn=my_plot) as writer:
            writer.plot(entries)

    Parameters:
    - file (str): The file name of the output PDF.
    - plot_fn (Optional[Callable], optional): Function that draws one data entry (or
      one row). It is called as ``plot_fn(data_entry, axis, index=i, **kwargs)``,
      where ``axis`` is a single Axes (or the row's 1-D array of Axes when
      ``plot_per_row`` is True) and ``i`` is the entry's position in ``data``.
      Default is None.
    - plot_per_row (bool, optional): If True, calls `plot_fn` for an entire row instead
      of individual subplots. Default is False.
    - max_pages (int, optional): Maximum number of pages to create. When the data
      would need more pages, entries are sampled evenly. Default is 999.
    - rows (int, optional): Number of plot rows per page. Default is 7.
    - cols (int, optional): Number of plot columns per page. Default is 2.
    - pagesize (Tuple[float, float], optional): Size of a single page (in inches).
      Default is (21, 29.7).
    - width_ratios (Optional[List[float]], optional): Column width ratios. Default is None.
    - show_progress (bool, optional): If True, displays a progress bar using `tqdm`.
      Default is True.
    """

    def __init__(
        self,
        file: str,
        plot_fn: Optional[Callable] = None,
        plot_per_row: bool = False,
        max_pages: int = 999,
        rows: int = 7,
        cols: int = 2,
        pagesize: Tuple[float, float] = (21, 29.7),
        width_ratios: Optional[List[float]] = None,
        show_progress: bool = True,
    ):
        self.pdf_pages = PdfPages(file)
        self.plot_fn = plot_fn
        self.plot_per_row = plot_per_row
        self.max_pages = max_pages
        self.rows = rows
        self.cols = cols
        self.pagesize = pagesize
        self.width_ratios = width_ratios
        self.show_progress = show_progress
        # Tracks whether close() already ran, so repeated closes are harmless.
        self._closed = False

    def __enter__(self) -> "PdfWriter":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Always finalize the PDF; exceptions from the with-body still propagate.
        self.close()

    def plot(self, data: Union[List[plt.Figure], List], **kwargs):
        """
        Generate plots from data or save pre-generated figures to the PDF.

        Parameters:
        - data (Union[List[matplotlib.figure.Figure], List]): Input data or list of figures.
          If every item is a figure, they are saved directly. Otherwise, `plot_fn` is
          called for each (possibly subsampled) data entry.
        - **kwargs: Additional keyword arguments passed to `plot_fn`.

        Returns:
        - None

        Raises:
        - ValueError: If `plot_fn` is missing or `data` is not a list when plotting
          dynamically.
        """
        # Case 1: Pre-generated figures (an empty input also lands here: no-op).
        if all(isinstance(item, plt.Figure) for item in data):
            for fig in tqdm.tqdm(
                data, disable=not self.show_progress, desc="Saving Figures"
            ):
                self.save_figure(fig)
            return

        # Case 2: Generate plots dynamically using `plot_fn`.
        if self.plot_fn is None:
            raise ValueError(
                "plot_fn must be provided when input is not a list of figures."
            )

        if not isinstance(data, list):
            raise ValueError(
                "Data must be a list or a list of matplotlib.figure.Figure."
            )

        plots_per_page = self.rows if self.plot_per_row else self.rows * self.cols
        max_plots = self.max_pages * plots_per_page
        # step > 1 subsamples the data so the output fits in max_pages.
        step = max(len(data) / max_plots, 1)
        pages = int((len(data) / step + plots_per_page - 1) // plots_per_page)

        for p in tqdm.tqdm(
            range(pages), disable=not self.show_progress, desc="Generating Pages"
        ):
            fig, ax = plt.subplots(
                self.rows,
                self.cols,
                figsize=self.pagesize,
                squeeze=False,
                gridspec_kw=(
                    {"width_ratios": self.width_ratios} if self.width_ratios else None
                ),
            )
            try:
                done = self._fill_page(ax, p, plots_per_page, step, data, kwargs)
                # Lay out this page's figure explicitly; plot_fn may have changed
                # pyplot's "current" figure.
                fig.tight_layout()
                self.pdf_pages.savefig(fig, bbox_inches="tight", pad_inches=1)
            finally:
                # Release the figure even if plot_fn raised.
                plt.close(fig)
            if done:
                break

    def _fill_page(self, ax, page: int, plots_per_page: int, step: float, data: list, kwargs: dict) -> bool:
        """
        Draw the entries belonging to ``page`` into the 2-D Axes grid ``ax``.

        Returns:
        - bool: True once the data is exhausted (remaining slots stay empty).
        """
        for r in range(self.rows):
            if self.plot_per_row:
                idx = int((page * self.rows + r) * step)
                if idx >= len(data):
                    return True
                self.plot_fn(data[idx], ax[r, :], index=idx, **kwargs)
            else:
                for c in range(self.cols):
                    idx = int((page * plots_per_page + r * self.cols + c) * step)
                    # Indices grow monotonically, so the first overflow ends the page.
                    if idx >= len(data):
                        return True
                    self.plot_fn(data[idx], ax[r, c], index=idx, **kwargs)
        return False

    def save_figure(self, figure: plt.Figure):
        """
        Save a pre-generated matplotlib figure directly to the PDF.

        Parameters:
        - figure (matplotlib.figure.Figure): The figure to save.

        Returns:
        - None

        Raises:
        - ValueError: If `figure` is not a matplotlib Figure.
        """
        if not isinstance(figure, plt.Figure):
            raise ValueError("Input must be a matplotlib.figure.Figure.")
        self.pdf_pages.savefig(figure, bbox_inches="tight", pad_inches=1)

    def close(self):
        """
        Close the PDF file, ensuring all pages are written. Calling it again is a no-op.

        Returns:
        - None
        """
        if self._closed:
            return
        self.pdf_pages.close()
        self._closed = True
@@ -0,0 +1,169 @@
1
+ import networkx as nx
2
+ import matplotlib.pyplot as plt
3
+ from typing import Union, Tuple
4
+
5
+ from synkit.Vis.graph_visualizer import GraphVisualizer
6
+ from synkit.IO.chem_converter import rsmi_to_graph
7
+ from synkit.ITS.its_construction import ITSConstruction
8
+ from synkit.IO.gml_to_nx import GMLToNX
9
+
10
# Module-level GraphVisualizer (default node/edge attribute mappings), created
# once at import time and shared by three_graph_vis for all of its plots.
vis_graph = GraphVisualizer()
11
+
12
+
13
def three_graph_vis(
    input: Union[str, Tuple[nx.Graph, nx.Graph, nx.Graph]],
    sanitize: bool = False,
    figsize: Tuple[int, int] = (18, 5),
    orientation: str = "horizontal",
    show_titles: bool = True,
    show_atom_map: bool = False,
    titles: Tuple[str, str, str] = (
        "Reactants",
        "Imaginary Transition State",
        "Products",
    ),
    add_gridbox: bool = False,
    rule: bool = False,
) -> plt.Figure:
    """
    Visualize three related graphs (reactants, imaginary transition state, and
    products) side by side or stacked vertically in a single figure.

    Parameters:
    - input (Union[str, Tuple[nx.Graph, nx.Graph, nx.Graph]]): Either a reaction
      SMILES string or a tuple of three NetworkX graphs ordered
      (reactants, products, ITS). The ITS is drawn in the middle panel.
    - sanitize (bool, optional): If True, sanitizes the molecules parsed from a
      reaction SMILES. Default is False.
    - figsize (Tuple[int, int], optional): The size of the Matplotlib figure.
      Default is (18, 5).
    - orientation (str, optional): Layout of the subplots; 'horizontal' or 'vertical'.
      Default is 'horizontal'.
    - show_titles (bool, optional): If True, adds titles to each subplot.
      Default is True.
    - show_atom_map (bool, optional): If True, shows atom map numbers in node labels.
      Default is False.
    - titles (Tuple[str, str, str], optional): Titles for the three subplots, in
      panel order. Default is ('Reactants', 'Imaginary Transition State', 'Products').
    - add_gridbox (bool, optional): If True, draws a rectangular frame and dashed
      grid on each subplot. Default is False.
    - rule (bool, optional): Forwarded to ``GraphVisualizer.plot_its`` for the ITS
      panel; the ITS graph passed in is not modified. Default is False.

    Returns:
    - plt.Figure: The Matplotlib figure containing the three subplots.

    Raises:
    - RuntimeError: Wrapping any error raised while parsing or plotting; the
      original exception is chained as ``__cause__``.
    """
    fig = None
    try:
        # Parse input to determine graphs
        if isinstance(input, str):
            r, p = rsmi_to_graph(input, light_weight=True, sanitize=sanitize)
            its = ITSConstruction().ITSGraph(r, p)
        elif isinstance(input, tuple) and len(input) == 3:
            r, p, its = input
        else:
            raise ValueError(
                "Input must be a reaction SMILES string or a tuple of three graphs (r, p, its)."
            )

        # Set up subplots
        if orientation == "horizontal":
            fig, ax = plt.subplots(1, 3, figsize=figsize)
        elif orientation == "vertical":
            fig, ax = plt.subplots(3, 1, figsize=figsize)
        else:
            raise ValueError("Orientation must be 'horizontal' or 'vertical'.")

        mol_style = {
            "show_atom_map": show_atom_map,
            "font_size": 12,
            "node_size": 800,
            "edge_width": 2.0,
        }
        vis_graph.plot_as_mol(r, ax[0], **mol_style)
        # With rule=True plot_its may remove edges; pass a copy so a
        # caller-supplied ITS graph stays intact.
        vis_graph.plot_its(
            its.copy() if rule else its,
            ax[1],
            use_edge_color=True,
            show_atom_map=show_atom_map,
            rule=rule,
        )
        vis_graph.plot_as_mol(p, ax[2], **mol_style)

        if show_titles:
            for i, a in enumerate(ax):
                a.set_title(titles[i])

        if add_gridbox:
            for a in ax:
                # The plot helpers call ax.axis("off"), which suppresses spines
                # and gridlines; switch the axis back on but hide ticks and tick
                # labels so only the frame and the grid are drawn.
                a.set_axis_on()
                a.tick_params(
                    which="both",
                    bottom=False,
                    left=False,
                    labelbottom=False,
                    labelleft=False,
                )
                # Draw the axis (frame and grid) above the plotted artists.
                a.set_axisbelow(False)

                # Rectangular frame (gridbox) with thicker borders
                for spine in a.spines.values():
                    spine.set_visible(True)
                    spine.set_linewidth(2)
                    spine.set_color("black")

                # Light dashed gridlines
                a.grid(
                    True,
                    which="both",
                    axis="both",
                    linestyle="--",
                    color="gray",
                    alpha=0.5,
                )

        return fig

    except Exception as e:
        # Do not leave a half-built figure registered with pyplot.
        if fig is not None:
            plt.close(fig)
        raise RuntimeError(f"An error occurred during visualization: {str(e)}") from e
127
+
128
+
129
def rule_visualize(gml, rule=True, titles=None):
    """
    Visualize a graph-transformation rule given in GML as three panels
    (by default titled L, K and R) via ``three_graph_vis`` with a gridbox.

    Parameters:
    - gml (str): GML format string representing the rule.
    - rule (bool): Forwarded to ``three_graph_vis`` to enable edge filtering
      in the middle (ITS) panel.
    - titles (Sequence[str], optional): Exactly three subplot titles.
      Defaults to ['L', 'K', 'R'].

    Returns:
    - plt.Figure: Matplotlib figure containing the three panels.

    Raises:
    - RuntimeError: Wrapping any error (including an invalid ``titles`` length);
      the original exception is chained as ``__cause__``.
    """
    try:
        # Validate titles first so a bad argument is reported without
        # parsing the GML.
        if titles is None:
            titles = ["L", "K", "R"]
        if len(titles) != 3:
            raise ValueError(
                "The titles list must contain exactly three titles for the three graphs."
            )

        # Transform GML to NetworkX graphs
        r, p, its = GMLToNX(gml).transform()

        return three_graph_vis(
            (r, p, its),
            add_gridbox=True,
            titles=titles,
            rule=rule,
        )

    except Exception as e:
        raise RuntimeError(
            f"An error occurred during the visualization process: {str(e)}"
        ) from e
synkit/__init__.py ADDED
File without changes