rdworks 0.25.8__py3-none-any.whl → 0.36.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rdworks/display.py CHANGED
@@ -1,135 +1,296 @@
1
- import io
2
- import os
3
1
  import numpy as np
4
- from typing import Optional, List, Tuple
5
2
 
6
3
  from PIL import Image, ImageChops
4
+ from io import BytesIO
7
5
 
8
- from rdkit import Chem, Geometry
6
+ from rdkit import Chem
9
7
  from rdkit.Chem import AllChem, Draw, rdDepictor, rdMolTransforms
10
8
  from rdkit.Chem.Draw import rdMolDraw2D
9
+ from rdkit.Chem.Draw import MolsMatrixToGridImage # new in RDKit 2023.09.1
10
+
11
+ # SVG optimization
12
+ from scour.scour import scourString
11
13
 
12
14
 
13
15
  # https://greglandrum.github.io/rdkit-blog/posts/2023-05-26-drawing-options-explained.html
14
16
 
15
17
 
16
- def twod_depictor(rdmol:Chem.Mol, index:bool=False, coordgen:bool=False) -> Chem.Mol:
17
- """Sets up for 2D depiction.
18
+ def trim_png(img:Image.Image) -> Image.Image:
19
+ """Removes white margin around molecular drawing.
18
20
 
19
21
  Args:
20
- rdmol (Chem.Mol): input molecule.
21
- index (bool, optional): whether to show atom index. Defaults to False.
22
- coordgen (bool, optional): whether to set rdDepictor.SetPreferCoordGen. Defaults to False.
22
+ img (Image.Image): input PIL Image object.
23
+
24
+ Returns:
25
+ Image.Image: output PIL Image object.
26
+ """
27
+ bg = Image.new(img.mode, img.size, img.getpixel((0,0)))
28
+ diff = ImageChops.difference(img,bg)
29
+ diff = ImageChops.add(diff, diff, 2.0, -100)
30
+ bbox = diff.getbbox()
31
+
32
+ if bbox:
33
+ return img.crop(bbox)
34
+
35
+ return img
36
+
37
+
38
+ def get_highlight_bonds(rdmol: Chem.Mol, atom_indices: list[int]) -> list[int] | None:
39
+ """Get bond indices for bonds between atom indices.
40
+
41
+ Args:
42
+ rdmol (Chem.Mol): rdkit Chem.Mol object.
43
+ atom_indices (list[int]): atom indices.
23
44
 
24
45
  Returns:
25
- Chem.Mol: a copy of rdkit.Chem.Mol object.
46
+ list[int]: bond indices.
26
47
  """
27
- if coordgen:
28
- rdDepictor.SetPreferCoordGen(True)
48
+ bond_indices = []
49
+ for bond in rdmol.GetBonds():
50
+ if bond.GetBeginAtomIdx() in atom_indices and bond.GetEndAtomIdx() in atom_indices:
51
+ bond_indices.append(bond.GetIdx())
52
+
53
+ if bond_indices:
54
+ return bond_indices
29
55
  else:
30
- rdDepictor.SetPreferCoordGen(False)
56
+ return None
57
+
58
+
59
+ def render_2D_mol(rdmol:Chem.Mol,
60
+ moldrawer:rdMolDraw2D,
61
+ redraw: bool = False,
62
+ coordgen: bool = False,
63
+ legend: str = '',
64
+ atom_index: bool = False,
65
+ highlight_atoms: list[int] | None = None,
66
+ highlight_bonds: list[int] | None = None,
67
+ ) -> str:
31
68
 
32
69
  rdmol_2d = Chem.Mol(rdmol)
33
- rdDepictor.Compute2DCoords(rdmol_2d)
34
- rdDepictor.StraightenDepiction(rdmol_2d)
35
70
 
36
- for atom in rdmol_2d.GetAtoms():
37
- for key in atom.GetPropsAsDict():
38
- atom.ClearProp(key)
39
-
40
- if index: # index hides polar hydrogens
41
- for atom in rdmol_2d.GetAtoms():
42
- atom.SetProp("atomLabel", str(atom.GetIdx()))
43
- # atom.SetProp("atomNote", str(atom.GetIdx()))
44
- # atom.SetProp("molAtomMapNumber", str(atom.GetIdx()))
71
+ if redraw or rdmol_2d.GetNumConformers() == 0:
72
+ rdDepictor.SetPreferCoordGen(coordgen)
73
+ rdmol_2d = Chem.RemoveHs(rdmol_2d)
74
+ rdDepictor.Compute2DCoords(rdmol_2d)
75
+
76
+ rdDepictor.StraightenDepiction(rdmol_2d)
77
+
78
+ if (highlight_bonds is None) and (highlight_atoms is not None):
79
+ # highlight bonds between the highlighted atoms
80
+ highlight_bonds = get_highlight_bonds(rdmol_2d, highlight_atoms)
81
+
82
+ draw_options = moldrawer.drawOptions()
83
+
84
+ draw_options.addAtomIndices = atom_index
85
+ # draw_options.setHighlightColour((0,.9,.9,.8)) # Cyan highlight
86
+ # draw_options.addBondIndices = True
87
+ # draw_options.noAtomLabels = True
88
+ draw_options.atomLabelDeuteriumTritium = True # D, T
89
+ # draw_options.explicitMethyl = True
90
+ draw_options.singleColourWedgeBonds = True
91
+ draw_options.addStereoAnnotation = True
92
+ # draw_options.fillHighlights = False
93
+ # draw_options.highlightRadius = .4
94
+ # draw_options.highlightBondWidthMultiplier = 12
95
+ # draw_options.variableAtomRadius = 0.2
96
+ # draw_options.variableBondWidthMultiplier = 40
97
+ # draw_options.setVariableAttachmentColour((.5,.5,1))
98
+ # draw_options.baseFontSize = 1.0 # default is 0.6
99
+ # draw_options.annotationFontScale = 1
100
+ # draw_options.rotate = 30 # rotation angle in degrees
101
+ # draw_options.padding = 0.2 # default is 0.05
102
+
103
+ # for atom in rdmol_2d.GetAtoms():
104
+ # for key in atom.GetPropsAsDict():
105
+ # atom.ClearProp(key)
106
+ # if index: # index hides polar hydrogens
107
+ # for atom in rdmol_2d.GetAtoms():
108
+ # atom.SetProp("atomLabel", str(atom.GetIdx()))
109
+ # # # atom.SetProp("atomNote", str(atom.GetIdx()))
110
+ # # # atom.SetProp("molAtomMapNumber", str(atom.GetIdx()))
111
+
112
+ moldrawer.DrawMolecule(rdmol_2d,
113
+ legend=legend,
114
+ highlightAtoms=highlight_atoms,
115
+ highlightBonds=highlight_bonds)
116
+ moldrawer.FinishDrawing()
45
117
 
46
- return rdmol_2d
118
+ return moldrawer.GetDrawingText()
47
119
 
48
120
 
49
- def svg(rdmol:Chem.Mol,
50
- width:int=300,
51
- height:int=300,
52
- legend:str='',
53
- index:bool=False,
54
- highlight:list[int] | None = None,
55
- coordgen:bool = False) -> str:
56
- """Returns string SVG output of a molecule.
121
+ def render_svg(rdmol: Chem.Mol,
122
+ width: int = 300,
123
+ height: int = 300,
124
+ legend: str = '',
125
+ atom_index: bool = False,
126
+ highlight_atoms: list[int] | None = None,
127
+ highlight_bonds: list[int] | None = None,
128
+ redraw: bool = False,
129
+ coordgen: bool = False,
130
+ optimize: bool = True) -> str:
131
+ """Draw 2D molecule in SVG format.
57
132
 
58
133
  Examples:
134
+ For Jupyternotebook, wrap the output with SVG:
135
+
59
136
  >>> from IPython.display import SVG
60
137
  >>> SVG(libr[0].to_svg())
61
138
 
62
139
  Args:
63
- rdmol (Chem.Mol): input molecule.
64
- width (int): width. Defaults to 300.
65
- height (int): height. Defaults to 300.
66
- legend (str): title of molecule. Defaults to ''.
67
- index (bool): whether to show atom indexes. Defaults to False.
68
- highlight (list[int]): list of atom indices to highlight. Defaults to None.
69
- coordgen (bool): whether to use rdDepictor.SetPreferCoordGen. Defaults to False.
140
+ rdmol (Chem.Mol): rdkit Chem.Mol object.
141
+ width (int, optional): width. Defaults to 300.
142
+ height (int, optional): height. Defaults to 300.
143
+ legend (str, optional): legend. Defaults to ''.
144
+ atom_index (bool, optional): whether to show atom index. Defaults to False.
145
+ highlight_atoms (list[int] | None, optional): atom(s) to highlight. Defaults to None.
146
+ highlight_bonds (list[int] | None, optional): bond(s) to highlight. Defaults to None.
147
+ redraw (bool, optional): whether to redraw. Defaults to False.
148
+ coordgen (bool, optional): whether to use coordgen. Defaults to False.
149
+ optimize (bool, optional): whether to optimize SVG string. Defaults to True.
70
150
 
71
151
  Returns:
72
- str: SVG text
152
+ str: SVG string
73
153
  """
74
- d2d_svg = rdMolDraw2D.MolDraw2DSVG(width, height)
75
- rdmol_2d = twod_depictor(rdmol, index, coordgen)
76
- if highlight:
77
- d2d_svg.DrawMolecule(rdmol_2d, legend=legend, highlightAtoms=highlight)
78
- else:
79
- d2d_svg.DrawMolecule(rdmol_2d, legend=legend)
80
- #rdMolDraw2D.PrepareAndDrawMolecule(d2d_svg, rdmol_2d, highlightAtoms=highlight, legend=legend)
81
- d2d_svg.FinishDrawing()
82
- return d2d_svg.GetDrawingText()
154
+
155
+ svg_string = render_2D_mol(rdmol,
156
+ moldrawer = rdMolDraw2D.MolDraw2DSVG(width, height),
157
+ redraw = redraw,
158
+ coordgen = coordgen,
159
+ legend = legend,
160
+ atom_index = atom_index,
161
+ highlight_atoms = highlight_atoms,
162
+ highlight_bonds = highlight_bonds,
163
+ )
164
+
165
+ if optimize:
166
+ scour_options = {
167
+ 'strip_comments': True,
168
+ 'strip_ids': True,
169
+ 'shorten_ids': True,
170
+ 'compact_paths': True,
171
+ 'indent_type': 'none',
172
+ }
173
+ svg_string = scourString(svg_string, options=scour_options)
174
+
175
+ return svg_string
83
176
 
84
177
 
85
- def png(rdmol:Chem.Mol, width:int=300, height:int=300, legend:str='',
86
- index:bool=False, highlight:Optional[List[int]]=None, coordgen:bool=False) -> Image.Image:
87
- """Returns a trimmed PIL Image object of a molecule.
178
+ def render_png(rdmol: Chem.Mol,
179
+ width: int = 300,
180
+ height: int = 300,
181
+ legend: str = '',
182
+ atom_index: bool = False,
183
+ highlight_atoms: list[int] | None = None,
184
+ highlight_bonds: list[int] | None = None,
185
+ redraw: bool = False,
186
+ coordgen: bool = False,
187
+ trim: bool = True) -> Image.Image:
188
+ """Draw 2D molecule in PNG format.
88
189
 
89
190
  Args:
90
- rdmol (Chem.Mol): input molecule.
91
- width (int): width. Defaults to 300.
92
- height (int): height. Defaults to 300.
93
- legend (str): title of molecule. Defaults to ''.
94
- index (bool): whether to show atom indexes. Defaults to False.
95
- highlight (list): list of atom indices to highlight. Defaults to None.
96
- coordgen (bool): whether to use rdDepictor.SetPreferCoordGen. Defaults to False.
191
+ rdmol (Chem.Mol): rdkit Chem.Mol object.
192
+ width (int, optional): width. Defaults to 300.
193
+ height (int, optional): height. Defaults to 300.
194
+ legend (str, optional): legend. Defaults to ''.
195
+ atom_index (bool, optional): whether to show atom index. Defaults to False.
196
+ highlight_atoms (list[int] | None, optional): atom(s) to highlight. Defaults to None.
197
+ highlight_bonds (list[int] | None, optional): bond(s) to highlight. Defaults to None.
198
+ redraw (bool, optional): whether to redraw. Defaults to False.
199
+ coordgen (bool, optional): whether to use coordgen. Defaults to False.
97
200
 
98
201
  Returns:
99
202
  Image.Image: output PIL Image object.
100
203
  """
101
- rdmol_2d = twod_depictor(rdmol, index, coordgen)
102
- img = Draw.MolToImage(rdmol_2d,
103
- size=(width,height),
104
- highlightAtoms=highlight,
105
- kekulize=True,
106
- wedgeBonds=True,
107
- fitImage=False,
108
- )
109
- # highlightAtoms: list of atoms to highlight (default [])
110
- # highlightBonds: list of bonds to highlight (default [])
111
- # highlightColor: RGB color as tuple (default [1, 0, 0])
112
204
 
113
- return trim_png(img)
205
+ png_string = render_2D_mol(rdmol,
206
+ moldrawer = rdMolDraw2D.MolDraw2DCairo(width, height),
207
+ redraw = redraw,
208
+ coordgen = coordgen,
209
+ legend = legend,
210
+ atom_index = atom_index,
211
+ highlight_atoms = highlight_atoms,
212
+ highlight_bonds = highlight_bonds,
213
+ )
114
214
 
215
+ img = Image.open(BytesIO(png_string))
216
+
217
+ if trim:
218
+ img = trim_png(img)
219
+
220
+ return img
115
221
 
116
- def trim_png(img:Image.Image) -> Image.Image:
117
- """Removes white margin around molecular drawing.
222
+
223
+ def render_matrix_grid(rdmol: list[Chem.Mol],
224
+ legend: list[str] | None,
225
+ highlight_atoms: list[list[int]] | None = None,
226
+ highlight_bonds: list[list[int]] | None = None,
227
+ mols_per_row: int = 5,
228
+ width: int = 200,
229
+ height: int = 200,
230
+ atom_index: bool = False,
231
+ redraw: bool = False,
232
+ coordgen: bool = False,
233
+ svg: bool = True,
234
+ ) -> str | Image.Image:
235
+ """Rendering a grid image from a list of molecules.
118
236
 
119
237
  Args:
120
- img (Image.Image): input PIL Image object.
238
+ rdmol (list[Chem.Mol]): list of rdkit Chem.Mol objects.
239
+ legend (list[str]): list of legends
240
+ highlight_atoms (list[list[int]] | None, optional): list of atom(s) to highlight. Defaults to None.
241
+ highlight_bonds (list[list[int]] | None, optional): list of bond(s) to highlight. Defaults to None.
242
+ mols_per_row (int, optional): molecules per row. Defaults to 5.
243
+ width (int, optional): width. Defaults to 200.
244
+ height (int, optional): height. Defaults to 200.
245
+ atom_index (bool, optional): whether to show atom index. Defaults to False.
246
+ redraw (bool, optional): whether to redraw 2D. Defaults to False.
247
+ coordgen (bool, optional): whether to use coordgen to depict. Defaults to False.
121
248
 
122
249
  Returns:
123
- Image.Image: output PIL Image object.
250
+ str | Image.Image: SVG string or PIL Image object.
251
+
252
+ Reference:
253
+ https://greglandrum.github.io/rdkit-blog/posts/2023-10-25-molsmatrixtogridimage.html
124
254
  """
125
- bg = Image.new(img.mode, img.size, img.getpixel((0,0)))
126
- diff = ImageChops.difference(img,bg)
127
- diff = ImageChops.add(diff, diff, 2.0, -100)
128
- bbox = diff.getbbox()
129
- if bbox:
130
- return img.crop(bbox)
131
- return img
132
255
 
256
+ n = len(rdmol)
257
+
258
+ if isinstance(legend, list):
259
+ assert len(legend) == n, "number of legends and molecules must be the same"
260
+ elif legend is None:
261
+ legend = ['',] * n
262
+
263
+ if isinstance(highlight_atoms, list):
264
+ assert len(highlight_atoms) == n, "number of highlights and molecules must be the same"
265
+ elif highlight_atoms is None:
266
+ highlight_atoms = [ (), ] * n
267
+
268
+ if isinstance(highlight_bonds, list):
269
+ assert len(highlight_bonds) == n, "number of highlights and molecules must be the same"
270
+ elif highlight_bonds is None:
271
+ highlight_bonds = [ (), ] * n
272
+
273
+ rdmol_matrix = []
274
+ legend_matrix = []
275
+ highlight_atoms_matrix = []
276
+ highlight_bonds_matrix = []
277
+
278
+ for i in range(0, n, mols_per_row):
279
+ rdmol_matrix.append(rdmol[i:(i + mols_per_row)])
280
+ legend_matrix.append(legend[i:(i + mols_per_row)])
281
+ highlight_atoms_matrix.append(highlight_atoms[i:(i + mols_per_row)])
282
+ highlight_bonds_matrix.append(highlight_bonds[i:(i + mols_per_row)])
283
+
284
+ return MolsMatrixToGridImage(
285
+ molsMatrix = rdmol_matrix,
286
+ subImgSize = (width, height),
287
+ legendsMatrix = legend_matrix,
288
+ highlightAtomListsMatrix = highlight_atoms_matrix,
289
+ highlightBondListsMatrix = highlight_bonds_matrix,
290
+ useSVG = svg,
291
+ returnPNG = False # whether to return PNG data (True) or a PIL object (False)
292
+ )
293
+
133
294
 
134
295
  def rescale(rdmol:Chem.Mol, factor:float=1.5) -> Chem.Mol:
135
296
  """Returns a copy of `rdmol` by a `factor`.