molcraft 0.1.0rc9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- molcraft/__init__.py +18 -0
- molcraft/callbacks.py +100 -0
- molcraft/chem.py +714 -0
- molcraft/datasets.py +132 -0
- molcraft/descriptors.py +149 -0
- molcraft/features.py +379 -0
- molcraft/featurizers.py +624 -0
- molcraft/layers.py +1910 -0
- molcraft/losses.py +37 -0
- molcraft/models.py +623 -0
- molcraft/ops.py +195 -0
- molcraft/records.py +187 -0
- molcraft/tensors.py +561 -0
- molcraft/trainers.py +212 -0
- molcraft-0.1.0rc9.dist-info/METADATA +118 -0
- molcraft-0.1.0rc9.dist-info/RECORD +19 -0
- molcraft-0.1.0rc9.dist-info/WHEEL +5 -0
- molcraft-0.1.0rc9.dist-info/licenses/LICENSE +21 -0
- molcraft-0.1.0rc9.dist-info/top_level.txt +1 -0
molcraft/chem.py
ADDED
|
@@ -0,0 +1,714 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
import collections
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from rdkit import Chem
|
|
6
|
+
from rdkit.Chem import AllChem
|
|
7
|
+
from rdkit.Chem import Lipinski
|
|
8
|
+
from rdkit.Chem import rdDistGeom
|
|
9
|
+
from rdkit.Chem import rdDepictor
|
|
10
|
+
from rdkit.Chem import rdMolAlign
|
|
11
|
+
from rdkit.Chem import rdMolTransforms
|
|
12
|
+
from rdkit.Chem import rdPartialCharges
|
|
13
|
+
from rdkit.Chem import rdMolDescriptors
|
|
14
|
+
from rdkit.Chem import rdForceFieldHelpers
|
|
15
|
+
from rdkit.Chem import rdFingerprintGenerator
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Alias for the plain RDKit molecule class; `Mol` below subclasses it.
RDKitMol = Chem.Mol
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Mol(RDKitMol):
    """RDKit molecule extended with convenience accessors.

    Instances are usually obtained by re-classing an existing RDKit molecule
    in place (see ``from_encoding`` and ``cast``).
    """

    @classmethod
    def from_encoding(cls, encoding: str, explicit_hs: bool = False, **kwargs) -> 'Mol':
        """Build a molecule from a SMILES or InChI string.

        Args:
            encoding: SMILES or InChI string, parsed by ``get_mol``.
            explicit_hs: If True, hydrogens are added with ``Chem.AddHs``.
            **kwargs: Forwarded to ``get_mol``.

        Returns:
            The parsed molecule, re-classed as ``cls`` and remembering the
            original ``encoding``.
        """
        rdkit_mol = get_mol(encoding, **kwargs)
        if explicit_hs:
            rdkit_mol = Chem.AddHs(rdkit_mol)
        rdkit_mol.__class__ = cls
        setattr(rdkit_mol, '_encoding', encoding)
        return rdkit_mol

    @classmethod
    def cast(cls, obj: RDKitMol) -> 'Mol':
        """Re-class ``obj`` as ``cls`` in place and return the same object."""
        obj.__class__ = cls
        return obj

    @staticmethod
    def _atom_index(atom: int | Chem.Atom) -> int:
        """Normalize an atom or an atom index to a plain ``int`` index."""
        if isinstance(atom, Chem.Atom):
            atom = atom.GetIdx()
        return int(atom)

    @property
    def canonical_smiles(self) -> str:
        """Canonical SMILES as produced by ``Chem.MolToSmiles``."""
        return Chem.MolToSmiles(self, canonical=True)

    @property
    def encoding(self) -> str | None:
        """Encoding stored by ``from_encoding``, or None if none was stored."""
        return getattr(self, '_encoding', None)

    @property
    def bonds(self) -> list['Bond']:
        """All bonds of the molecule, as returned by ``get_bonds``."""
        return get_bonds(self)

    @property
    def atoms(self) -> list['Atom']:
        """All atoms of the molecule, as returned by ``get_atoms``."""
        return get_atoms(self)

    @property
    def num_conformers(self) -> int:
        """Number of conformers attached to the molecule."""
        return int(self.GetNumConformers())

    @property
    def num_atoms(self) -> int:
        """Number of atoms reported by ``GetNumAtoms``."""
        return int(self.GetNumAtoms())

    @property
    def num_bonds(self) -> int:
        """Number of bonds reported by ``GetNumBonds``."""
        return int(self.GetNumBonds())

    def get_atom(
        self,
        atom: int | Chem.Atom
    ) -> 'Atom':
        """Return the atom at the given index (or matching the given atom's index)."""
        return Atom.cast(self.GetAtomWithIdx(self._atom_index(atom)))

    def get_shortest_path_between_atoms(
        self,
        atom_i: int | Chem.Atom,
        atom_j: int | Chem.Atom
    ) -> tuple[int, ...]:
        """Return ``Chem.rdmolops.GetShortestPath`` between two atoms.

        Both arguments may be atom indices or atom objects.
        """
        return Chem.rdmolops.GetShortestPath(
            self, self._atom_index(atom_i), self._atom_index(atom_j)
        )

    def get_bond_between_atoms(
        self,
        atom_i: int | Chem.Atom,
        atom_j: int | Chem.Atom,
    ) -> 'Bond | None':
        """Return the bond joining two atoms, or None if there is none.

        Both arguments may be atom indices or atom objects.
        """
        bond = self.GetBondBetweenAtoms(
            self._atom_index(atom_i), self._atom_index(atom_j)
        )
        # Casting a missing bond would fail on ``None.__class__ = ...``.
        if bond is None:
            return None
        return Bond.cast(bond)

    def adjacency(
        self,
        fill: str = 'upper',
        sparse: bool = True,
        self_loops: bool = False,
        dtype: str = 'int32',
    ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
        """Bond adjacency of the molecule; see ``get_adjacency_matrix``."""
        return get_adjacency_matrix(
            self, fill=fill, sparse=sparse, self_loops=self_loops, dtype=dtype
        )

    def get_conformer(self, index: int = 0) -> 'Conformer | None':
        """Return the conformer ``GetConformer(index)``.

        Warns and returns None when the molecule has no conformers.
        """
        if self.num_conformers == 0:
            warnings.warn(f'{self} has no conformer. Returning None.')
            return None
        return Conformer.cast(self.GetConformer(index))

    def get_conformers(self) -> list['Conformer']:
        """Return all conformers; warns and returns ``[]`` when there are none."""
        if self.num_conformers == 0:
            warnings.warn(f'{self} has no conformers. Returning an empty list.')
            return []
        return [Conformer.cast(x) for x in self.GetConformers()]

    def __len__(self) -> int:
        return int(self.GetNumAtoms())

    def _repr_png_(self) -> None:
        # NOTE(review): presumably overrides RDKit's rich PNG display so that
        # notebooks fall back to ``__repr__`` -- confirm.
        return None

    def __repr__(self) -> str:
        # Prefer the user-supplied encoding; otherwise show canonical SMILES.
        encoding = self.encoding or self.canonical_smiles
        return f'<{self.__class__.__name__} {encoding} at {hex(id(self))}>'
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class Conformer(Chem.Conformer):
    """RDKit conformer extended with convenience accessors."""

    @classmethod
    def cast(cls, obj: Chem.Conformer) -> 'Conformer':
        """Re-class ``obj`` as ``cls`` in place and return the same object."""
        obj.__class__ = cls
        return obj

    @property
    def index(self) -> int:
        """Conformer id as reported by ``GetId``."""
        return self.GetId()

    @property
    def coordinates(self) -> np.ndarray:
        """Atom positions as reported by ``GetPositions``."""
        return self.GetPositions()

    @property
    def distances(self) -> np.ndarray:
        """3D distance matrix of the owning molecule for this conformer."""
        # Pass this conformer's id explicitly; otherwise RDKit measures the
        # molecule's default conformer regardless of which one ``self`` is.
        return Chem.rdmolops.Get3DDistanceMatrix(
            self.GetOwningMol(), confId=self.GetId()
        )

    @property
    def centroid(self) -> np.ndarray:
        """Centroid computed by ``rdMolTransforms.ComputeCentroid``."""
        return np.asarray(rdMolTransforms.ComputeCentroid(self))

    def adjacency(
        self,
        fill: str = 'full',
        radius: float = None,
        sparse: bool = True,
        self_loops: bool = False,
        dtype: str = 'int32'
    ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
        """Spatial adjacency: atom pairs closer than ``radius``.

        Args:
            fill: 'lower' or 'upper' keeps only that triangle; any other
                value keeps the full matrix.
            radius: Distance cutoff (strict ``<``). A falsy value means no
                cutoff.
            sparse: If True, return ``(edge_source, edge_target)`` index
                arrays; otherwise return the dense matrix.
            self_loops: If True, each atom is adjacent to itself.
            dtype: dtype of the returned array(s).
        """
        radius = radius or np.inf
        distances = self.distances
        if not self_loops:
            # An infinite diagonal can never satisfy ``< radius``.
            np.fill_diagonal(distances, np.inf)
        within_radius = distances < radius
        # The diagonal is kept by the triangular masks; it is already False
        # unless self loops were requested, so they survive 'lower'/'upper'.
        if fill == 'lower':
            within_radius = np.tril(within_radius)
        elif fill == 'upper':
            within_radius = np.triu(within_radius)
        if sparse:
            edge_source, edge_target = np.where(within_radius)
            return edge_source.astype(dtype), edge_target.astype(dtype)
        return within_radius.astype(dtype)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class Atom(Chem.Atom):
    """RDKit atom extended with convenience accessors."""

    @classmethod
    def cast(cls, obj: Chem.Atom) -> 'Atom':
        """Re-class ``obj`` as ``cls`` in place and hand the same object back."""
        obj.__class__ = cls
        return obj

    @property
    def index(self) -> int:
        """Atom index (``GetIdx``) as a plain int."""
        idx = self.GetIdx()
        return int(idx)

    @property
    def neighbors(self) -> list['Atom']:
        """Neighboring atoms (``GetNeighbors``), each cast to ``Atom``."""
        adjacent = []
        for other in self.GetNeighbors():
            adjacent.append(Atom.cast(other))
        return adjacent

    @property
    def symbol(self) -> str:
        """Element symbol as reported by ``GetSymbol``."""
        return self.GetSymbol()

    @property
    def label(self):
        """Integer value of the 'molAtomMapNumber' property, or None if unset."""
        if not self.HasProp('molAtomMapNumber'):
            return None
        raw = self.GetProp('molAtomMapNumber')
        return int(raw)

    @label.setter
    def label(self, value: int) -> None:
        # The property is stored as the string form of ``value``.
        text = str(value)
        self.SetProp('molAtomMapNumber', text)

    def __repr__(self) -> str:
        symbol = self.GetSymbol()
        address = hex(id(self))
        return f'<Atom {symbol} at {address}>'
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
class Bond(Chem.Bond):
    """RDKit bond extended with convenience accessors."""

    @classmethod
    def cast(cls, obj: Chem.Bond) -> 'Bond':
        """Re-class ``obj`` as ``cls`` in place and hand the same object back."""
        obj.__class__ = cls
        return obj

    @property
    def index(self) -> int:
        """Bond index (``GetIdx``) as a plain int."""
        idx = self.GetIdx()
        return int(idx)

    def __repr__(self) -> str:
        bond_type = self.GetBondType().name
        address = hex(id(self))
        return f'<Bond {bond_type} at {address}>'
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def get_mol(
    encoding: str,
    strict: bool = True,
    assign_stereo_chemistry: bool = True,
) -> RDKitMol:
    """Parse a SMILES or InChI string and sanitize the result.

    Strings starting with 'InChI' are parsed as InChI, everything else as
    SMILES. Parsing is done without sanitization; ``sanitize_mol`` is then
    applied with the given ``strict`` and ``assign_stereo_chemistry``.

    Raises:
        ValueError: If ``encoding`` is not a string or no molecule could be
            obtained from it.
    """
    if not isinstance(encoding, str):
        raise ValueError(
            f'Input ({encoding}) is not a SMILES or InChI string.'
        )
    is_inchi = encoding.startswith('InChI')
    parser = Chem.MolFromInchi if is_inchi else Chem.MolFromSmiles
    parsed = parser(encoding, sanitize=False)
    sanitized = None
    if parsed is not None:
        sanitized = sanitize_mol(parsed, strict, assign_stereo_chemistry)
    if sanitized is None:
        raise ValueError(f'Could not obtain `chem.Mol` from {encoding}.')
    return sanitized
|
|
243
|
+
|
|
244
|
+
def get_adjacency_matrix(
    mol: RDKitMol,
    fill: str = 'full',
    sparse: bool = False,
    self_loops: bool = False,
    dtype: str = "int32",
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
    """Bond adjacency of ``mol`` from ``Chem.GetAdjacencyMatrix``.

    Args:
        mol: Molecule to read bonds from.
        fill: 'lower' or 'upper' keeps only that strict triangle; any other
            value keeps the full matrix.
        sparse: If True, return ``(edge_source, edge_target)`` index arrays
            of the non-zero entries; otherwise return the dense matrix.
        self_loops: If True, the identity is added after the triangular
            mask, so self loops are present for every ``fill``.
        dtype: dtype of the returned array(s).

    Returns:
        The dense adjacency matrix, or a pair of index arrays when
        ``sparse`` is True.
    """
    adjacency: np.ndarray = Chem.GetAdjacencyMatrix(mol)
    if fill == 'lower':
        adjacency = np.tril(adjacency, k=-1)
    elif fill == 'upper':
        adjacency = np.triu(adjacency, k=1)
    if self_loops:
        adjacency += np.eye(adjacency.shape[0], dtype=adjacency.dtype)
    if not sparse:
        return adjacency.astype(dtype)
    edge_source, edge_target = np.where(adjacency)
    return edge_source.astype(dtype), edge_target.astype(dtype)
|
|
262
|
+
|
|
263
|
+
def sanitize_mol(
    mol: RDKitMol,
    strict: bool = True,
    assign_stereo_chemistry: bool = True,
) -> Mol:
    """Sanitize a copy of ``mol`` and optionally assign stereochemistry.

    Args:
        mol: Molecule to sanitize; it is wrapped in a new ``Mol`` first.
        strict: If True, a failed sanitization raises. If False, a warning
            is issued and sanitization is repeated without the failed step.
        assign_stereo_chemistry: If True, ``Chem.AssignStereochemistry`` is
            run on the result.

    Raises:
        ValueError: If sanitization fails and ``strict`` is True.
    """
    mol = Mol(mol)
    failed_ops = Chem.SanitizeMol(mol, catchErrors=True)
    if failed_ops != Chem.SanitizeFlags.SANITIZE_NONE:
        if strict:
            raise ValueError(f'Could not sanitize {mol}.')
        warnings.warn(
            f'Could not sanitize {mol}. Proceeding with partial sanitization.'
        )
        # Re-run every sanitization step except the one reported as failing.
        remaining_ops = Chem.SanitizeFlags.SANITIZE_ALL ^ failed_ops
        Chem.SanitizeMol(mol, sanitizeOps=remaining_ops)
    if assign_stereo_chemistry:
        Chem.AssignStereochemistry(
            mol,
            cleanIt=True,
            force=True,
            flagPossibleStereoCenters=True,
        )
    return mol
|
|
282
|
+
|
|
283
|
+
def get_atoms(mol: Mol) -> list[Atom]:
    """Return every atom of ``mol`` in index order, cast to ``Atom``."""
    atoms = []
    for position in range(mol.GetNumAtoms()):
        atoms.append(Atom.cast(mol.GetAtomWithIdx(position)))
    return atoms
|
|
288
|
+
|
|
289
|
+
def get_bonds(mol: Mol) -> list[Bond]:
    """Return every bond of ``mol`` in index order, cast to ``Bond``."""
    bonds = []
    for position in range(mol.GetNumBonds()):
        bonds.append(Bond.cast(mol.GetBondWithIdx(int(position))))
    return bonds
|
|
294
|
+
|
|
295
|
+
def add_hs(mol: Mol) -> Mol:
    """Return ``Chem.AddHs(mol)``, re-classed to the class of ``mol``."""
    target_class = mol.__class__
    with_hs = Chem.AddHs(mol)
    with_hs.__class__ = target_class
    return with_hs
|
|
299
|
+
|
|
300
|
+
def remove_hs(mol: Mol) -> Mol:
    """Return ``Chem.RemoveHs(mol)``, re-classed to the class of ``mol``."""
    target_class = mol.__class__
    without_hs = Chem.RemoveHs(mol)
    without_hs.__class__ = target_class
    return without_hs
|
|
304
|
+
|
|
305
|
+
def get_distances(
    mol: Mol,
    fill: str = 'full',
    use_bond_order: bool = False,
    use_atom_weights: bool = False
) -> np.ndarray:
    """Distance matrix of ``mol`` from ``Chem.rdmolops.GetDistanceMatrix``.

    Args:
        mol: Molecule to compute distances for.
        fill: 'lower' or 'upper' keeps only that strict triangle; any other
            value keeps the full matrix.
        use_bond_order: Passed to RDKit as ``useBO``.
        use_atom_weights: Passed to RDKit as ``useAtomWts``.

    Returns:
        The distance matrix, with entries of at least 1e6 replaced by -1.
    """
    raw = Chem.rdmolops.GetDistanceMatrix(
        mol, useBO=use_bond_order, useAtomWts=use_atom_weights
    )
    # For disconnected nodes, a value of 1e8 is assigned by RDKit;
    # such large values are mapped to -1 here.
    # TODO: Add argument for filling disconnected node pairs.
    unreachable = raw >= 1e6
    distances = np.where(unreachable, -1, raw)
    if fill == 'lower':
        distances = np.tril(distances, k=-1)
    elif fill == 'upper':
        distances = np.triu(distances, k=1)
    return distances
|
|
325
|
+
|
|
326
|
+
def get_shortest_paths(
    mol: Mol,
    radius: int,
    self_loops: bool = False,
) -> list[list[int]]:
    """Collect breadth-first paths of at most ``radius`` bonds from each atom.

    Args:
        mol: Molecule whose ``atoms`` (with ``index`` and ``neighbors``)
            are traversed.
        radius: Maximum number of bonds in a path.
        self_loops: If True, the single-atom path ``[i]`` is included for
            every start atom.

    Returns:
        A list of paths, each a list of atom indices starting at the
        source atom.
    """
    longest = radius + 1  # maximum number of atoms in a path
    collected = []
    for source in mol.atoms:
        seen = {source.index}
        pending = collections.deque()
        pending.append((source, [source.index]))
        while pending:
            node, trail = pending.popleft()
            if len(trail) > longest:
                continue
            if self_loops or len(trail) > 1:
                collected.append(trail)
            for adjacent in node.neighbors:
                if adjacent.index not in seen:
                    seen.add(adjacent.index)
                    pending.append((adjacent, trail + [adjacent.index]))
    return collected
|
|
347
|
+
|
|
348
|
+
def get_periodic_table():
    """Return the object produced by ``Chem.GetPeriodicTable``."""
    periodic_table = Chem.GetPeriodicTable()
    return periodic_table
|
|
350
|
+
|
|
351
|
+
def partial_charges(mol: 'Mol') -> list[float]:
    """Compute Gasteiger charges on ``mol`` and return them per atom.

    The charges are read back from each atom's '_GasteigerCharge' property,
    which ``ComputeGasteigerCharges`` is called to populate.
    """
    rdPartialCharges.ComputeGasteigerCharges(mol)
    charges = []
    for atom in mol.atoms:
        charges.append(atom.GetDoubleProp("_GasteigerCharge"))
    return charges
|
|
354
|
+
|
|
355
|
+
def logp_contributions(mol: 'Mol') -> list[float]:
    """Return the first component of each ``_CalcCrippenContribs`` entry."""
    contributions = rdMolDescriptors._CalcCrippenContribs(mol)
    logp = []
    for entry in contributions:
        logp.append(entry[0])
    return logp
|
|
357
|
+
|
|
358
|
+
def molar_refractivity_contributions(mol: 'Mol') -> list[float]:
    """Return the second component of each ``_CalcCrippenContribs`` entry."""
    contributions = rdMolDescriptors._CalcCrippenContribs(mol)
    refractivity = []
    for entry in contributions:
        refractivity.append(entry[1])
    return refractivity
|
|
360
|
+
|
|
361
|
+
def total_polar_surface_area_contributions(mol: 'Mol') -> list[float]:
    """Return the values of ``_CalcTPSAContribs(mol)`` as a list."""
    contributions = rdMolDescriptors._CalcTPSAContribs(mol)
    return [*contributions]
|
|
363
|
+
|
|
364
|
+
def accessible_surface_area_contributions(mol: 'Mol') -> list[float]:
    """Return the first element of ``_CalcLabuteASAContribs(mol)`` as a list."""
    result = rdMolDescriptors._CalcLabuteASAContribs(mol)
    first_component = result[0]
    return [*first_component]
|
|
366
|
+
|
|
367
|
+
def hydrogen_acceptors(mol: 'Mol') -> list[bool]:
    """Flag, per atom, whether it is matched by ``Lipinski._HAcceptors``.

    Returns:
        One boolean per atom of ``mol``, in atom order.
    """
    # A set makes the per-atom membership test O(1).
    acceptor_indices = {match[0] for match in Lipinski._HAcceptors(mol)}
    return [atom.index in acceptor_indices for atom in mol.atoms]
|
|
370
|
+
|
|
371
|
+
def hydrogen_donors(mol: 'Mol') -> list[bool]:
    """Flag, per atom, whether it is matched by ``Lipinski._HDonors``.

    Returns:
        One boolean per atom of ``mol``, in atom order.
    """
    # A set makes the per-atom membership test O(1).
    donor_indices = {match[0] for match in Lipinski._HDonors(mol)}
    return [atom.index in donor_indices for atom in mol.atoms]
|
|
374
|
+
|
|
375
|
+
def hetero_atoms(mol: 'Mol') -> list[bool]:
    """Flag, per atom, whether it is matched by ``Lipinski._Heteroatoms``.

    Returns:
        One boolean per atom of ``mol``, in atom order.
    """
    # A set makes the per-atom membership test O(1); the local name no
    # longer shadows this function.
    hetero_indices = {match[0] for match in Lipinski._Heteroatoms(mol)}
    return [atom.index in hetero_indices for atom in mol.atoms]
|
|
378
|
+
|
|
379
|
+
def rotatable_bonds(mol: 'Mol') -> list[bool]:
    """Flag, per bond, whether it is matched by ``Lipinski._RotatableBonds``.

    A bond is flagged when the unordered pair of its begin and end atom
    indices equals the atom-index set of one of the matches.

    Returns:
        One boolean per bond of ``mol``, in bond order.
    """
    # Frozensets are hashable, so each bond is an O(1) lookup and the
    # comparison stays independent of atom order within a match.
    rotatable_pairs = {
        frozenset(match) for match in Lipinski._RotatableBonds(mol)
    }

    def is_rotatable(bond):
        atom_indices = frozenset(
            (bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())
        )
        return atom_indices in rotatable_pairs

    return [is_rotatable(bond) for bond in mol.bonds]
|
|
385
|
+
|
|
386
|
+
def conformer_deviations(mol: Mol, fill: str = 'full') -> np.ndarray:
    """Root mean squared deviation (RMSD) matrix between conformers.

    Args:
        mol: Molecule whose conformers are compared with
            ``rdMolAlign.GetAllConformerBestRMS``.
        fill: 'upper' fills only the upper triangle, 'lower' only the lower
            triangle; any other value fills both (symmetric matrix).

    Returns:
        A ``(num_conformers, num_conformers)`` array with a zero diagonal.
    """
    num_confs = mol.num_conformers
    matrix = np.zeros((num_confs, num_confs))
    if num_confs < 2:
        # No conformer pairs to compare.
        return matrix
    deviations = np.asarray(
        rdMolAlign.GetAllConformerBestRMS(mol), dtype=float
    )
    # RDKit documents the returned order as (1,0),(2,0),(2,1),(3,0),...,
    # i.e. the strict lower triangle in row-major order, which is exactly
    # the order produced by ``np.tril_indices``.
    rows, cols = np.tril_indices(num_confs, k=-1)
    if fill != 'upper':
        matrix[rows, cols] = deviations
    if fill != 'lower':
        matrix[cols, rows] = deviations
    return matrix
|
|
404
|
+
|
|
405
|
+
def conformer_energies(
    mol: Mol,
    method: str = 'UFF',
) -> list[float]:
    """Return one force-field energy per conformer of ``mol``.

    'UFF' uses ``_calc_uff_energies``. Any other value is passed to
    ``_calc_mmff_energies`` as the MMFF variant, with 'MMFF' expanded to
    'MMFF94'.
    """
    if method == 'UFF':
        return _calc_uff_energies(mol)
    variant = method + '94' if method == 'MMFF' else method
    return _calc_mmff_energies(mol, variant)
|
|
417
|
+
|
|
418
|
+
def embed_conformers(
    mol: Mol,
    num_conformers: int,
    method: str = 'ETKDGv3',
    timeout: int | None = None,
    random_seed: int | None = None,
    **kwargs
) -> Mol:
    """Embed ``num_conformers`` conformers into a copy of ``mol``.

    Args:
        mol: Molecule to embed; a copy is made and returned.
        num_conformers: Number of conformers requested.
        method: Name of the embedding parameter set ('ETDG', 'ETKDG',
            'ETKDGv2', 'ETKDGv3', 'srETKDGv3' or 'KDG'). Unknown names
            trigger a warning and fall back to 'ETKDGv3'.
        timeout: Value assigned to the parameters' ``timeout``; a falsy
            value is replaced by 0.
        random_seed: Value assigned to the parameters' ``randomSeed``;
            None is replaced by -1. A seed of 0 is passed through as is.
        **kwargs: Extra attributes set on the primary embedding parameters.

    Returns:
        The copy of ``mol`` holding the embedded conformers.

    Raises:
        RuntimeError: If fewer than ``num_conformers`` conformers could be
            embedded even after trying the fallback methods.
    """
    available_embedding_methods = {
        'ETDG': rdDistGeom.ETDG(),
        'ETKDG': rdDistGeom.ETKDG(),
        'ETKDGv2': rdDistGeom.ETKDGv2(),
        'ETKDGv3': rdDistGeom.ETKDGv3(),
        'srETKDGv3': rdDistGeom.srETKDGv3(),
        'KDG': rdDistGeom.KDG()
    }
    mol = Mol(mol)

    # Name of the parameter set actually in use; differs from ``method``
    # only when the requested one is unavailable.
    resolved_method = method
    embedding_method = available_embedding_methods.get(method)
    if embedding_method is None:
        warnings.warn(
            f'{method} is not available. Proceeding with ETKDGv3.'
        )
        resolved_method = 'ETKDGv3'
        embedding_method = available_embedding_methods[resolved_method]

    for key, value in kwargs.items():
        setattr(embedding_method, key, value)

    if not timeout:
        timeout = 0 # No timeout

    # Only a missing seed maps to -1; ``not random_seed`` would also have
    # discarded an explicit seed of 0.
    if random_seed is None:
        random_seed = -1 # No random seed

    embedding_method.randomSeed = random_seed
    embedding_method.timeout = timeout

    success = rdDistGeom.EmbedMultipleConfs(
        mol, numConfs=num_conformers, params=embedding_method
    )
    num_successes = len(success)
    if num_successes < num_conformers:
        warnings.warn(
            f'Could only embed {num_successes} out of {num_conformers} conformer(s) for '
            f'{mol} using the specified method ({method}) and parameters. Attempting to '
            f'embed the remaining {num_conformers-num_successes} using fallback methods.',
        )
        max_iters = 20 * mol.num_atoms # Doubling the number of iterations
        # Use the resolved name: indexing with an unavailable ``method``
        # would raise KeyError here.
        for fallback_method in [resolved_method, 'ETDG', 'KDG']:
            fallback_embedding_method = available_embedding_methods[fallback_method]
            fallback_embedding_method.useRandomCoords = True
            fallback_embedding_method.maxIterations = int(max_iters)
            # Keep the conformers embedded so far.
            fallback_embedding_method.clearConfs = False
            fallback_embedding_method.timeout = int(timeout)
            fallback_embedding_method.randomSeed = int(random_seed)
            success = rdDistGeom.EmbedMultipleConfs(
                mol, numConfs=(num_conformers - num_successes), params=fallback_embedding_method
            )
            num_successes += len(success)
            if num_successes == num_conformers:
                break
        else:
            # Every fallback was tried without reaching the requested count.
            raise RuntimeError(
                f'Could not embed {num_conformers} conformer(s) for {mol}. '
            )
    return mol
|
|
483
|
+
|
|
484
|
+
def optimize_conformers(
    mol: Mol,
    method: str = 'UFF',
    max_iter: int = 200,
    num_threads: int = 1,
    ignore_interfragment_interactions: bool = True,
    vdw_threshold: float = 10.0,
) -> Mol:
    """Force-field optimize the conformers of a copy of ``mol``.

    Args:
        mol: Molecule whose conformers are optimized; it is not modified.
        method: 'UFF', 'MMFF', 'MMFF94' or 'MMFF94s'. 'MMFF' means 'MMFF94'.
            Any other value triggers a warning and falls back to 'UFF'.
        max_iter: Maximum number of optimizer iterations.
        num_threads: Number of threads passed to the optimizer.
        ignore_interfragment_interactions: Passed to the optimizer.
        vdw_threshold: Passed to the UFF optimizer only.

    Returns:
        An optimized copy of ``mol``. An unoptimized copy is returned, with
        a warning, when ``mol`` has no conformers or the optimizer raises
        ``RuntimeError``.
    """
    if mol.num_conformers == 0:
        warnings.warn(
            f'{mol} has no conformers to optimize. Proceeding without it.'
        )
        return Mol(mol)
    available_force_field_methods = ['MMFF', 'MMFF94', 'MMFF94s', 'UFF']
    if method not in available_force_field_methods:
        warnings.warn(
            f'{method} is not available. Proceeding with universal force field (UFF).'
        )
        method = 'UFF'
    mol_optimized = Mol(mol)
    try:
        if method.startswith('MMFF'):
            variant = method
            if variant == 'MMFF':
                variant += '94'
            _mmff_optimize_conformers(
                mol_optimized,
                num_threads=num_threads,
                max_iter=max_iter,
                variant=variant,
                ignore_interfragment_interactions=ignore_interfragment_interactions,
            )
        else:
            _uff_optimize_conformers(
                mol_optimized,
                num_threads=num_threads,
                max_iter=max_iter,
                vdw_threshold=vdw_threshold,
                ignore_interfragment_interactions=ignore_interfragment_interactions,
            )
    except RuntimeError:
        warnings.warn(
            f'Unsuccessful {method} force field minimization for {mol}. Proceeding without it.',
        )
        # Discard the possibly half-optimized copy.
        return Mol(mol)
    return mol_optimized
|
|
530
|
+
|
|
531
|
+
def prune_conformers(
    mol: Mol,
    keep: int = 1,
    threshold: float = 0.0,
    energy_force_field: str = 'UFF',
) -> Mol:
    """Keep up to ``keep`` low-energy, mutually distinct conformers.

    Conformers are visited in order of increasing energy. The lowest-energy
    one is always kept; each further one is kept only if its RMSD to every
    conformer kept so far is at least ``threshold``.

    Args:
        mol: Molecule with conformers; it is not modified.
        keep: Maximum number of conformers to keep.
        threshold: Minimum RMSD to all kept conformers; a falsy value
            means 0.0.
        energy_force_field: Method passed to ``conformer_energies``.

    Returns:
        A copy of ``mol`` holding only the selected conformers, or a plain
        copy (with a warning) when ``mol`` has no conformers.
    """
    if mol.num_conformers == 0:
        warnings.warn(
            f'{mol} has no conformers to prune. Proceeding without it.'
        )
        # Return the annotated type, like the other conformer helpers do.
        return Mol(mol)

    threshold = threshold or 0.0
    deviations = conformer_deviations(mol)
    energies = conformer_energies(mol, method=energy_force_field)
    sorted_indices = np.argsort(energies)

    selected = [int(sorted_indices[0])]

    for target in sorted_indices[1:]:
        if len(selected) >= keep:
            break
        if np.all(deviations[target, selected] >= threshold):
            selected.append(int(target))

    # ``selected`` holds positions in the conformer sequence (that is how
    # the deviation matrix is indexed), so pick conformers by position.
    conformers = mol.get_conformers()
    mol_copy = Mol(mol)
    mol_copy.RemoveAllConformers()
    for position in selected:
        mol_copy.AddConformer(conformers[position], assignId=True)

    return mol_copy
|
|
563
|
+
|
|
564
|
+
def _uff_optimize_conformers(
    mol: Mol,
    num_threads: int = 1,
    max_iter: int = 200,
    vdw_threshold: float = 10.0,
    ignore_interfragment_interactions: bool = True,
    **kwargs,
) -> tuple[list[float], list[bool]]:
    """Universal Force Field minimization of all conformers of ``mol``.

    Returns:
        ``(energies, converged)``: for each result of
        ``UFFOptimizeMoleculeConfs``, its second item and whether its
        first item equals 0.
    """
    outcomes = rdForceFieldHelpers.UFFOptimizeMoleculeConfs(
        mol,
        numThreads=num_threads,
        maxIters=max_iter,
        vdwThresh=vdw_threshold,
        ignoreInterfragInteractions=ignore_interfragment_interactions,
    )
    energies = []
    for outcome in outcomes:
        energies.append(outcome[1])
    converged = []
    for outcome in outcomes:
        converged.append(outcome[0] == 0)
    return energies, converged
|
|
584
|
+
|
|
585
|
+
def _mmff_optimize_conformers(
    mol: Mol,
    num_threads: int = 1,
    max_iter: int = 200,
    variant: str = 'MMFF94',
    ignore_interfragment_interactions: bool = True,
    **kwargs,
) -> tuple[list[float], list[bool]]:
    """Merck Molecular Force Field minimization of all conformers of ``mol``.

    Args:
        mol: Molecule whose conformers are optimized in place.
        num_threads: Number of threads passed to RDKit.
        max_iter: Maximum number of optimizer iterations.
        variant: MMFF variant passed to RDKit as ``mmffVariant``.
        ignore_interfragment_interactions: Passed to RDKit.
        **kwargs: Accepted and ignored.

    Returns:
        ``(energies, converged)``: for each result of
        ``MMFFOptimizeMoleculeConfs``, its second item and whether its
        first item equals 0.

    Raises:
        ValueError: If ``MMFFHasAllMoleculeParams`` is false for ``mol``.
    """
    if not rdForceFieldHelpers.MMFFHasAllMoleculeParams(mol):
        raise ValueError("Cannot minimize molecule using MMFF.")
    rdForceFieldHelpers.MMFFSanitizeMolecule(mol)
    results = rdForceFieldHelpers.MMFFOptimizeMoleculeConfs(
        mol,
        # RDKit's keyword is ``numThreads`` (as in the UFF counterpart);
        # ``num_threads`` is not a parameter of this RDKit function.
        numThreads=num_threads,
        maxIters=max_iter,
        mmffVariant=variant,
        ignoreInterfragInteractions=ignore_interfragment_interactions,
    )
    energies = [r[1] for r in results]
    converged = [r[0] == 0 for r in results]
    return energies, converged
|
|
608
|
+
|
|
609
|
+
def _calc_uff_energies(
    mol: Mol,
) -> list[float]:
    """Return the UFF energy of each conformer of ``mol``.

    Returns:
        One energy per conformer, in conformer order. NaN is stored for a
        conformer whose force field could not be built or evaluated.
    """
    energies = []
    # Use each conformer's real id instead of assuming ids are 0..n-1.
    for conformer in mol.GetConformers():
        try:
            force_field = rdForceFieldHelpers.UFFGetMoleculeForceField(
                mol, confId=conformer.GetId()
            )
            energies.append(force_field.CalcEnergy())
        except Exception:
            # Best effort: a failing conformer must not abort the others.
            energies.append(float('nan'))
    return energies
|
|
620
|
+
|
|
621
|
+
def _calc_mmff_energies(
    mol: Mol,
    variant: str = 'MMFF94',
) -> list[float]:
    """Return the MMFF energy of each conformer of ``mol``.

    Args:
        mol: Molecule with conformers.
        variant: MMFF variant passed to RDKit as ``mmffVariant``.

    Returns:
        One energy per conformer, in conformer order. NaN is stored for a
        conformer whose force field could not be built or evaluated.

    Raises:
        ValueError: If ``MMFFHasAllMoleculeParams`` is false for ``mol``.
    """
    energies = []
    if not rdForceFieldHelpers.MMFFHasAllMoleculeParams(mol):
        raise ValueError("Cannot compute MMFF energies for this molecule.")
    props = rdForceFieldHelpers.MMFFGetMoleculeProperties(mol, mmffVariant=variant)
    # Use each conformer's real id instead of assuming ids are 0..n-1.
    for conformer in mol.GetConformers():
        try:
            force_field = rdForceFieldHelpers.MMFFGetMoleculeForceField(
                mol, props, confId=conformer.GetId()
            )
            energies.append(force_field.CalcEnergy())
        except Exception:
            # Best effort: a failing conformer must not abort the others.
            energies.append(float('nan'))
    return energies
|
|
636
|
+
|
|
637
|
+
def unpack_conformers(mol: Mol) -> list[Mol]:
    """Split ``mol`` into one single-conformer molecule per conformer.

    Each returned molecule is a copy of ``mol`` holding exactly one of its
    conformers and carrying the class of ``mol``.
    """
    target_class = mol.__class__

    def single_conformer_copy(conformer):
        clone = RDKitMol(mol)
        clone.RemoveAllConformers()
        clone.AddConformer(conformer, assignId=True)
        clone.__class__ = target_class
        return clone

    return [single_conformer_copy(conformer) for conformer in mol.get_conformers()]
|
|
646
|
+
|
|
647
|
+
# Maps a fingerprint type name to the rdFingerprintGenerator factory that
# builds the corresponding generator; looked up by `_get_fingerprint`.
_fingerprint_types = {
    'rdkit': rdFingerprintGenerator.GetRDKitFPGenerator,
    'morgan': rdFingerprintGenerator.GetMorganGenerator,
    'topological_torsion': rdFingerprintGenerator.GetTopologicalTorsionGenerator,
    'atom_pair': rdFingerprintGenerator.GetAtomPairGenerator,
}
|
|
653
|
+
|
|
654
|
+
def _get_fingerprint(
    mol: Mol,
    fp_type: str = 'morgan',
    binary: bool = True,
    dtype: str = 'float32',
    **kwargs,
) -> np.ndarray:
    """Compute a fingerprint of ``mol`` as a NumPy array.

    Args:
        mol: An RDKit molecule (``Mol`` or plain ``RDKitMol``), or an
            encoding accepted by ``Mol.from_encoding``.
        fp_type: Key of ``_fingerprint_types`` selecting the generator.
        binary: If True, use ``GetFingerprintAsNumPy``; otherwise use
            ``GetCountFingerprintAsNumPy``.
        dtype: dtype of the returned array.
        **kwargs: Forwarded to the generator factory.

    Raises:
        KeyError: If ``fp_type`` is not a key of ``_fingerprint_types``.
    """
    fingerprint: rdFingerprintGenerator.FingerprintGenerator64 = (
        _fingerprint_types[fp_type](**kwargs)
    )
    # Any RDKit molecule is used as is; only non-molecule input is parsed.
    # Testing against ``Mol`` alone sent plain RDKit molecules to the
    # string parser, which rejects them.
    if not isinstance(mol, RDKitMol):
        mol = Mol.from_encoding(mol)
    if binary:
        fp: np.ndarray = fingerprint.GetFingerprintAsNumPy(mol)
    else:
        fp: np.ndarray = fingerprint.GetCountFingerprintAsNumPy(mol)
    return fp.astype(dtype)
|
|
671
|
+
|
|
672
|
+
def _rdkit_fingerprint(
    mol: RDKitMol,
    size: int = 2048,
    *,
    min_path: int = 1,
    max_path: int = 7,
    binary: bool = True,
    dtype: str = 'float32',
) -> np.ndarray:
    """RDKit fingerprint of ``mol`` via ``_get_fingerprint``.

    ``size``, ``min_path`` and ``max_path`` are passed to the generator as
    ``fpSize``, ``minPath`` and ``maxPath``.
    """
    return _get_fingerprint(
        mol,
        fp_type='rdkit',
        binary=binary,
        dtype=dtype,
        fpSize=size,
        minPath=min_path,
        maxPath=max_path,
    )
|
|
683
|
+
|
|
684
|
+
def _morgan_fingerprint(
    mol: RDKitMol,
    size: int = 2048,
    *,
    radius: int = 3,
    binary: bool = True,
    dtype: str = 'float32',
) -> np.ndarray:
    """Morgan fingerprint of ``mol`` via ``_get_fingerprint``.

    ``radius`` and ``size`` are passed to the generator as ``radius`` and
    ``fpSize``.
    """
    return _get_fingerprint(
        mol,
        fp_type='morgan',
        binary=binary,
        dtype=dtype,
        radius=radius,
        fpSize=size,
    )
|
|
694
|
+
|
|
695
|
+
def _topological_torsion_fingerprint(
    mol: RDKitMol,
    size: int = 2048,
    *,
    binary: bool = True,
    dtype: str = 'float32',
) -> np.ndarray:
    """Topological-torsion fingerprint of ``mol`` via ``_get_fingerprint``.

    ``size`` is passed to the generator as ``fpSize``.
    """
    return _get_fingerprint(
        mol,
        fp_type='topological_torsion',
        binary=binary,
        dtype=dtype,
        fpSize=size,
    )
|
|
704
|
+
|
|
705
|
+
def _atom_pair_fingerprint(
    mol: RDKitMol,
    size: int = 2048,
    *,
    binary: bool = True,
    dtype: str = 'float32',
) -> np.ndarray:
    """Atom-pair fingerprint of ``mol`` via ``_get_fingerprint``.

    ``size`` is passed to the generator as ``fpSize``.
    """
    return _get_fingerprint(
        mol,
        fp_type='atom_pair',
        binary=binary,
        dtype=dtype,
        fpSize=size,
    )
|
|
714
|
+
|