molcraft 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- molcraft/__init__.py +16 -0
- molcraft/callbacks.py +21 -0
- molcraft/chem.py +600 -0
- molcraft/conformers.py +155 -0
- molcraft/descriptors.py +90 -0
- molcraft/experimental/__init__.py +1 -0
- molcraft/experimental/peptides.py +303 -0
- molcraft/features.py +387 -0
- molcraft/featurizers.py +693 -0
- molcraft/layers.py +1224 -0
- molcraft/models.py +441 -0
- molcraft/ops.py +129 -0
- molcraft/records.py +169 -0
- molcraft/tensors.py +527 -0
- molcraft-0.1.0a1.dist-info/METADATA +58 -0
- molcraft-0.1.0a1.dist-info/RECORD +19 -0
- molcraft-0.1.0a1.dist-info/WHEEL +5 -0
- molcraft-0.1.0a1.dist-info/licenses/LICENSE +21 -0
- molcraft-0.1.0a1.dist-info/top_level.txt +1 -0
molcraft/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
__version__ = '0.1.0a1'

import os

# Silence TensorFlow's C++ logging by default. ``setdefault`` keeps any level
# the user already exported (e.g. while debugging) instead of overwriting it at
# import time. Presumably this must run before the submodule imports that
# follow, in case they load a TensorFlow backend.
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")
|
|
5
|
+
|
|
6
|
+
from molcraft import chem
|
|
7
|
+
from molcraft import features
|
|
8
|
+
from molcraft import descriptors
|
|
9
|
+
from molcraft import conformers
|
|
10
|
+
from molcraft import featurizers
|
|
11
|
+
from molcraft import layers
|
|
12
|
+
from molcraft import models
|
|
13
|
+
from molcraft import ops
|
|
14
|
+
from molcraft import records
|
|
15
|
+
from molcraft import tensors
|
|
16
|
+
from molcraft import callbacks
|
molcraft/callbacks.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
import keras
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class TensorBoard(keras.callbacks.TensorBoard):
    """Keras ``TensorBoard`` callback that tags weight summaries by ``weight.path``."""

    def _log_weights(self, epoch):
        """Write a histogram (and, if ``write_images``, an image) per layer weight.

        Args:
            epoch: Step value attached to every summary written.
        """
        writer = self._train_writer
        with writer.as_default():
            for layer in self.model.layers:
                for weight in layer.weights:
                    # ``weight.path`` is used instead of ``weight.name`` so that
                    # weights belonging to different layers get distinct tags.
                    tag_prefix = weight.path
                    self.summary.histogram(
                        tag_prefix + "/histogram", weight, step=epoch
                    )
                    if self.write_images:
                        self._log_weight_as_image(
                            weight, tag_prefix + "/image", epoch
                        )
            writer.flush()
|
molcraft/chem.py
ADDED
|
@@ -0,0 +1,600 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
import collections
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from rdkit import Chem
|
|
6
|
+
from rdkit.Chem import Lipinski
|
|
7
|
+
from rdkit.Chem import rdDistGeom
|
|
8
|
+
from rdkit.Chem import rdDepictor
|
|
9
|
+
from rdkit.Chem import rdMolAlign
|
|
10
|
+
from rdkit.Chem import rdMolTransforms
|
|
11
|
+
from rdkit.Chem import rdPartialCharges
|
|
12
|
+
from rdkit.Chem import rdMolDescriptors
|
|
13
|
+
from rdkit.Chem import rdForceFieldHelpers
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Mol(Chem.Mol):
    """RDKit ``Chem.Mol`` subclass with cached atom/bond views and helpers.

    RDKit objects are re-typed in place (``obj.__class__ = cls``) rather than
    copied, so the ``Atom``/``Bond``/``Conformer`` views returned here wrap the
    same underlying RDKit objects.
    """

    @classmethod
    def from_encoding(cls, encoding: str, explicit_hs: bool = False, **kwargs) -> 'Mol':
        """Build a ``Mol`` from a SMILES or InChI string.

        Args:
            encoding: SMILES or InChI string. An existing ``Chem.Mol`` is also
                accepted; ``get_mol`` returns it unchanged, so (without
                ``explicit_hs``) that same object is re-typed in place.
            explicit_hs: If True, add explicit hydrogens with ``Chem.AddHs``.
            **kwargs: Forwarded to ``get_mol`` (``strict``,
                ``assign_stereo_chemistry``).

        Returns:
            The molecule as an instance of ``cls``, or None when ``get_mol``
            returns None (e.g. sanitization failed with ``strict=True``).

        Raises:
            ValueError: Propagated from ``get_mol`` for unparsable strings.
        """
        rdkit_mol = get_mol(encoding, **kwargs)
        # Explicit None check: a ``Mol`` passed straight through defines
        # ``__len__``, so a truthiness test would reject atom-less molecules.
        if rdkit_mol is None:
            return None
        if explicit_hs:
            rdkit_mol = Chem.AddHs(rdkit_mol)
        rdkit_mol.__class__ = cls
        return rdkit_mol

    @property
    def canonical_smiles(self) -> str:
        """Canonical SMILES string of this molecule."""
        return Chem.MolToSmiles(self, canonical=True)

    @property
    def bonds(self) -> list['Bond']:
        """All bonds in index order (computed once, then cached on the instance)."""
        if not hasattr(self, '_bonds'):
            self._bonds = get_bonds(self)
        return self._bonds

    @property
    def atoms(self) -> list['Atom']:
        """All atoms in index order (computed once, then cached on the instance)."""
        if not hasattr(self, '_atoms'):
            self._atoms = get_atoms(self)
        return self._atoms

    @property
    def num_conformers(self) -> int:
        """Number of conformers attached to the molecule."""
        return int(self.GetNumConformers())

    @property
    def num_atoms(self) -> int:
        """Number of atoms in the molecule."""
        return int(self.GetNumAtoms())

    @property
    def num_bonds(self) -> int:
        """Number of bonds in the molecule."""
        return int(self.GetNumBonds())

    def get_atom(
        self,
        atom: int | Chem.Atom
    ) -> 'Atom':
        """Return the atom at the given index (or the atom matching ``atom``'s index)."""
        if isinstance(atom, Chem.Atom):
            atom = atom.GetIdx()
        return Atom.cast(self.GetAtomWithIdx(int(atom)))

    def get_path_between_atoms(
        self,
        atom_i: int | Chem.Atom,
        atom_j: int | Chem.Atom
    ) -> tuple[int, ...]:
        """Atom indices along the shortest bond path from ``atom_i`` to ``atom_j``."""
        if isinstance(atom_i, Chem.Atom):
            atom_i = atom_i.GetIdx()
        if isinstance(atom_j, Chem.Atom):
            atom_j = atom_j.GetIdx()
        return Chem.rdmolops.GetShortestPath(
            self, int(atom_i), int(atom_j)
        )

    def get_bond_between_atoms(
        self,
        atom_i: int | Chem.Atom,
        atom_j: int | Chem.Atom,
    ) -> 'Bond | None':
        """Bond connecting ``atom_i`` and ``atom_j``, or None if they are not bonded."""
        if isinstance(atom_i, Chem.Atom):
            atom_i = atom_i.GetIdx()
        if isinstance(atom_j, Chem.Atom):
            atom_j = atom_j.GetIdx()
        bond = self.GetBondBetweenAtoms(int(atom_i), int(atom_j))
        # RDKit returns None for non-bonded pairs; casting None would raise.
        if bond is None:
            return None
        return Bond.cast(bond)

    def adjacency(
        self,
        fill: str = 'upper',
        sparse: bool = True,
        self_loops: bool = False,
        dtype: str = 'int32',
        cache: bool = True
    ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
        """Bond adjacency of the molecule (see ``get_adjacency_matrix``).

        Results are cached per argument combination, so calls with different
        ``fill``/``sparse``/``self_loops``/``dtype`` never return each other's
        result. Pass ``cache=False`` to force recomputation. The most recent
        result is also stored in ``self._adjacency``.
        """
        key = (fill, sparse, self_loops, np.dtype(dtype).str)
        if not hasattr(self, '_adjacency_cache'):
            self._adjacency_cache = {}
        if not cache or key not in self._adjacency_cache:
            self._adjacency_cache[key] = get_adjacency_matrix(
                self, fill=fill, sparse=sparse, self_loops=self_loops, dtype=dtype
            )
        self._adjacency = self._adjacency_cache[key]
        return self._adjacency

    def get_conformer(self, index: int = 0) -> 'Conformer':
        """Conformer with ID ``index``, or None (with a warning) if none are embedded."""
        if self.num_conformers == 0:
            warn(
                'Molecule has no conformer. To embed conformer(s), invoke the `embed` method, '
                'and optionally followed by `minimize()` to perform force field minimization.'
            )
            return None
        return Conformer.cast(self.GetConformer(index))

    def get_conformers(self) -> list['Conformer']:
        """All conformers, or an empty list (with a warning) if none are embedded."""
        if self.num_conformers == 0:
            warn(
                'Molecule has no conformers. To embed conformers, invoke the `embed` method, '
                'and optionally followed by `minimize()` to perform force field minimization.'
            )
            return []
        return [Conformer.cast(x) for x in self.GetConformers()]

    def __len__(self) -> int:
        """Number of atoms."""
        return int(self.GetNumAtoms())

    def _repr_png_(self) -> None:
        # Returning None means "no PNG representation" to IPython; presumably
        # meant to suppress RDKit's inherited image rendering in notebooks.
        return None

    def __repr__(self) -> str:
        return f'<{self.__class__.__name__} {self.canonical_smiles} at {hex(id(self))}>'
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class Conformer(Chem.Conformer):
    """RDKit ``Chem.Conformer`` subclass exposing coordinates and distance helpers."""

    @classmethod
    def cast(cls, obj: Chem.Conformer) -> 'Conformer':
        """Re-type ``obj`` in place as ``cls`` and return it."""
        obj.__class__ = cls
        return obj

    @property
    def index(self) -> int:
        """Conformer ID (``GetId()``)."""
        return self.GetId()

    @property
    def coordinates(self) -> np.ndarray:
        """Atom positions of this conformer (``GetPositions()``)."""
        return self.GetPositions()

    @property
    def distances(self) -> np.ndarray:
        """Pairwise 3D distance matrix between the atoms of *this* conformer."""
        # Pass this conformer's ID explicitly: without it RDKit uses the
        # default conformer (confId=-1), so every conformer of a molecule
        # would report the same distances.
        return Chem.rdmolops.Get3DDistanceMatrix(
            self.GetOwningMol(), confId=self.GetId()
        )

    @property
    def centroid(self) -> np.ndarray:
        """Centroid of the conformer's atom positions."""
        return np.asarray(rdMolTransforms.ComputeCentroid(self))

    def adjacency(
        self,
        fill: str = 'full',
        radius: float = None,
        sparse: bool = True,
        self_loops: bool = False,
        dtype: str = 'int32'
    ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
        """Spatial adjacency: atom pairs closer than ``radius``.

        Args:
            fill: 'full' keeps both directions; 'lower'/'upper' keep only the
                lower/upper triangle.
            radius: Distance cutoff (strict ``<``). None means no cutoff. A
                value of 0 is honoured as a cutoff rather than meaning "none".
            sparse: Return ``(edge_source, edge_target)`` index arrays instead
                of a dense matrix.
            self_loops: Keep the diagonal (also in 'lower'/'upper' mode,
                matching ``get_adjacency_matrix``).
            dtype: Output dtype.
        """
        radius = np.inf if radius is None else radius
        distances = self.distances
        if not self_loops:
            np.fill_diagonal(distances, np.inf)
        within_radius = distances < radius
        # Offset 0 keeps the diagonal in triangular modes when self-loops
        # were requested; offset 1 excludes it.
        offset = 0 if self_loops else 1
        if fill == 'lower':
            within_radius = np.tril(within_radius, k=-offset)
        elif fill == 'upper':
            within_radius = np.triu(within_radius, k=offset)
        if sparse:
            edge_source, edge_target = np.where(within_radius)
            return edge_source.astype(dtype), edge_target.astype(dtype)
        return within_radius.astype(dtype)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class Atom(Chem.Atom):
    """Thin ``Chem.Atom`` subclass adding an index property and typed neighbors."""

    @classmethod
    def cast(cls, obj: Chem.Atom) -> 'Atom':
        """Re-type an RDKit atom as ``cls`` in place and hand it back."""
        obj.__class__ = cls
        return obj

    @property
    def index(self) -> int:
        """Index of the atom within its molecule."""
        idx = self.GetIdx()
        return int(idx)

    @property
    def neighbors(self) -> list['Atom']:
        """Bonded neighbor atoms, each re-typed as ``Atom``."""
        return list(map(Atom.cast, self.GetNeighbors()))

    def __repr__(self) -> str:
        symbol = self.GetSymbol()
        address = hex(id(self))
        return f'<Atom {symbol} at {address}>'
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
class Bond(Chem.Bond):
    """Thin ``Chem.Bond`` subclass adding an index property and a short repr."""

    @classmethod
    def cast(cls, obj: Chem.Bond) -> 'Bond':
        """Re-type an RDKit bond as ``cls`` in place and hand it back."""
        obj.__class__ = cls
        return obj

    @property
    def index(self) -> int:
        """Index of the bond within its molecule."""
        idx = self.GetIdx()
        return int(idx)

    def __repr__(self) -> str:
        bond_type = self.GetBondType().name
        address = hex(id(self))
        return f'<Bond {bond_type} at {address}>'
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def get_mol(
    encoding: str,
    strict: bool = True,
    assign_stereo_chemistry: bool = True,
) -> Chem.Mol:
    """Parse a SMILES or InChI string into a sanitized RDKit molecule.

    Args:
        encoding: SMILES string, or InChI string (detected by the ``'InChI'``
            prefix). A ``Chem.Mol`` is returned unchanged.
        strict: Forwarded to ``sanitize_mol``; when True a molecule that fails
            sanitization yields None.
        assign_stereo_chemistry: Forwarded to ``sanitize_mol``.

    Raises:
        ValueError: If RDKit cannot parse ``encoding``.
    """
    if isinstance(encoding, Chem.Mol):
        return encoding
    parse = Chem.MolFromInchi if encoding.startswith('InChI') else Chem.MolFromSmiles
    parsed = parse(encoding, sanitize=False)
    if parsed is None:
        raise ValueError(
            f"{encoding} is invalid; "
            f"make sure {encoding} is a valid SMILES or InChI string."
        )
    return sanitize_mol(parsed, strict, assign_stereo_chemistry)
|
|
227
|
+
|
|
228
|
+
def get_adjacency_matrix(
    mol: Chem.Mol,
    fill: str = 'full',
    sparse: bool = False,
    self_loops: bool = False,
    dtype: str = "int32",
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
    """Bond adjacency of ``mol`` as a dense matrix or as edge index arrays.

    Args:
        mol: RDKit molecule.
        fill: 'full' keeps both directions; 'lower'/'upper' keep only the
            strict lower/upper triangle, i.e. each bond once.
        sparse: If True return ``(edge_source, edge_target)``; otherwise the
            dense matrix.
        self_loops: Add ones on the diagonal. This happens after triangle
            filtering, so self-loops survive 'lower'/'upper'.
        dtype: Output dtype.

    Returns:
        Dense adjacency matrix when ``sparse`` is False, otherwise a tuple of
        source and target index arrays.
    """
    adjacency: np.ndarray = Chem.GetAdjacencyMatrix(mol)
    if fill == 'lower':
        adjacency = np.tril(adjacency, k=-1)
    elif fill == 'upper':
        adjacency = np.triu(adjacency, k=1)
    if self_loops:
        adjacency += np.eye(adjacency.shape[0], dtype=adjacency.dtype)
    if not sparse:
        return adjacency.astype(dtype)
    edge_source, edge_target = np.where(adjacency)
    return edge_source.astype(dtype), edge_target.astype(dtype)
|
|
246
|
+
|
|
247
|
+
def sanitize_mol(
    mol: Chem.Mol,
    strict: bool = True,
    assign_stereo_chemistry: bool = True,
) -> Chem.Mol:
    """Sanitize ``mol`` in place, optionally tolerating failing sanitization steps.

    Args:
        mol: Unsanitized RDKit molecule (modified in place).
        strict: If True, return None when any sanitization step fails.
            Otherwise re-run sanitization with the failing step(s) excluded.
        assign_stereo_chemistry: Run ``Chem.AssignStereochemistry`` afterwards.

    Returns:
        The same molecule object, or None in strict mode on failure.
    """
    failed_ops = Chem.SanitizeMol(mol, catchErrors=True)
    had_failure = failed_ops != Chem.SanitizeFlags.SANITIZE_NONE
    if had_failure and strict:
        return None
    if had_failure:
        # Retry, skipping only the step(s) reported as failing above.
        retry_ops = Chem.SanitizeFlags.SANITIZE_ALL ^ failed_ops
        Chem.SanitizeMol(mol, sanitizeOps=retry_ops)
    if assign_stereo_chemistry:
        Chem.AssignStereochemistry(
            mol, cleanIt=True, force=True, flagPossibleStereoCenters=True
        )
    return mol
|
|
262
|
+
|
|
263
|
+
def get_atoms(mol: Mol) -> list[Atom]:
    """All atoms of ``mol`` in index order, each re-typed as ``Atom``."""
    atom_count = mol.GetNumAtoms()
    return [Atom.cast(mol.GetAtomWithIdx(idx)) for idx in range(atom_count)]
|
|
268
|
+
|
|
269
|
+
def get_bonds(mol: Mol) -> list[Bond]:
    """All bonds of ``mol`` in index order, each re-typed as ``Bond``."""
    bond_count = mol.GetNumBonds()
    return [Bond.cast(mol.GetBondWithIdx(idx)) for idx in range(bond_count)]
|
|
274
|
+
|
|
275
|
+
def add_hs(mol: Mol) -> Mol:
    """Return a new molecule with explicit hydrogens, keeping ``mol``'s Python class."""
    with_hydrogens = Chem.AddHs(mol)
    with_hydrogens.__class__ = type(mol)
    return with_hydrogens
|
|
279
|
+
|
|
280
|
+
def remove_hs(mol: Mol) -> Mol:
    """Return a new molecule without explicit hydrogens, keeping ``mol``'s Python class."""
    without_hydrogens = Chem.RemoveHs(mol)
    without_hydrogens.__class__ = type(mol)
    return without_hydrogens
|
|
284
|
+
|
|
285
|
+
def get_distances(
    mol: Mol,
    fill: str = 'full',
    use_bond_order: bool = False,
    use_atom_weights: bool = False,
    disconnected_value: float = -1,
) -> np.ndarray:
    """Topological distance matrix of ``mol`` (``Chem.rdmolops.GetDistanceMatrix``).

    Args:
        mol: RDKit molecule.
        fill: 'full' returns the whole matrix; 'lower'/'upper' zero out
            everything but the strict lower/upper triangle.
        use_bond_order: Forwarded to RDKit as ``useBO``.
        use_atom_weights: Forwarded to RDKit as ``useAtomWts``.
        disconnected_value: Value used for atom pairs with no connecting path
            (e.g. different fragments). Defaults to -1, the previous fixed
            behaviour.
    """
    dist_matrix = Chem.rdmolops.GetDistanceMatrix(
        mol, useBO=use_bond_order, useAtomWts=use_atom_weights
    )
    # RDKit assigns a very large value (1e8) to disconnected atom pairs;
    # anything at or above 1e6 is treated as "disconnected".
    dist_matrix = np.where(
        dist_matrix >= 1e6, disconnected_value, dist_matrix
    )
    if fill == 'lower':
        return np.tril(dist_matrix, k=-1)
    elif fill == 'upper':
        return np.triu(dist_matrix, k=1)
    return dist_matrix
|
|
305
|
+
|
|
306
|
+
def get_shortest_paths(
    mol: Mol,
    radius: int,
    self_loops: bool = False,
) -> list[list[int]]:
    """Breadth-first shortest paths from every atom to atoms within ``radius`` bonds.

    For each source atom (in ``mol.atoms`` order) one shortest path is emitted
    per reachable atom at bond distance 1..``radius``, in BFS order. Each path
    is a list of atom indices starting at the source. With ``self_loops`` the
    trivial one-atom path ``[source]`` is emitted first as well.
    """
    max_path_len = radius + 1
    paths = []
    for source in mol.atoms:
        seen = {source.index}
        frontier = collections.deque()
        frontier.append((source, [source.index]))
        while frontier:
            node, path = frontier.popleft()
            if len(path) > max_path_len:
                continue
            if self_loops or len(path) > 1:
                paths.append(path)
            for nbr in node.neighbors:
                if nbr.index not in seen:
                    seen.add(nbr.index)
                    frontier.append((nbr, path + [nbr.index]))
    return paths
|
|
327
|
+
|
|
328
|
+
def get_periodic_table():
    """Return RDKit's periodic table object (``Chem.GetPeriodicTable()``)."""
    table = Chem.GetPeriodicTable()
    return table
|
|
330
|
+
|
|
331
|
+
def gasteiger_charges(mol: 'Mol') -> list[float]:
    """Compute Gasteiger partial charges on ``mol`` and return them per atom, in atom order."""
    rdPartialCharges.ComputeGasteigerCharges(mol)
    charges = []
    for atom in mol.atoms:
        charges.append(atom.GetDoubleProp("_GasteigerCharge"))
    return charges
|
|
334
|
+
|
|
335
|
+
def logp_contributions(mol: 'Mol') -> list[float]:
    """Per-atom Crippen logP contributions (first element of each Crippen pair)."""
    crippen_pairs = rdMolDescriptors._CalcCrippenContribs(mol)
    return [pair[0] for pair in crippen_pairs]
|
|
337
|
+
|
|
338
|
+
def molar_refractivity_contribution(mol: 'Mol') -> list[float]:
    """Per-atom Crippen molar refractivity contributions (second element of each pair)."""
    crippen_pairs = rdMolDescriptors._CalcCrippenContribs(mol)
    return [pair[1] for pair in crippen_pairs]
|
|
340
|
+
|
|
341
|
+
def tpsa_contribution(mol: 'Mol') -> list[float]:
    """Per-atom topological polar surface area contributions, as a list."""
    contributions = rdMolDescriptors._CalcTPSAContribs(mol)
    return list(contributions)
|
|
343
|
+
|
|
344
|
+
def asa_contribution(mol: 'Mol') -> list[float]:
    """Per-atom Labute approximate surface area contributions (first item of RDKit's result)."""
    labute_result = rdMolDescriptors._CalcLabuteASAContribs(mol)
    return list(labute_result[0])
|
|
346
|
+
|
|
347
|
+
def hydrogen_acceptors(mol: 'Mol') -> list[bool]:
    """Per-atom flags: True if the atom starts a Lipinski H-bond acceptor match."""
    acceptor_indices = {match[0] for match in Lipinski._HAcceptors(mol)}
    return [atom.index in acceptor_indices for atom in mol.atoms]
|
|
350
|
+
|
|
351
|
+
def hydrogen_donors(mol: 'Mol') -> list[bool]:
    """Per-atom flags: True if the atom starts a Lipinski H-bond donor match."""
    donor_indices = {match[0] for match in Lipinski._HDonors(mol)}
    return [atom.index in donor_indices for atom in mol.atoms]
|
|
354
|
+
|
|
355
|
+
def hetero_atoms(mol: 'Mol') -> list[bool]:
    """Per-atom flags: True if the atom starts a Lipinski heteroatom match."""
    hetero_indices = {match[0] for match in Lipinski._Heteroatoms(mol)}
    return [atom.index in hetero_indices for atom in mol.atoms]
|
|
358
|
+
|
|
359
|
+
def rotatable_bonds(mol: 'Mol') -> list[bool]:
    """Per-bond flags: True if the bond's atom pair matches Lipinski's rotatable-bond pattern."""
    rotatable_pairs = {frozenset(match) for match in Lipinski._RotatableBonds(mol)}
    return [
        frozenset((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())) in rotatable_pairs
        for bond in mol.bonds
    ]
|
|
365
|
+
|
|
366
|
+
def conformer_deviations(mol: Mol, fill: str = 'full') -> np.ndarray:
    """Root mean squared deviation (RMSD) matrix between the conformers of ``mol``.

    Args:
        mol: Molecule with conformers.
        fill: 'full' fills both triangles; 'lower'/'upper' fill only that
            triangle. The diagonal is always 0.

    Returns:
        ``(num_conformers, num_conformers)`` float array indexed by conformer
        position.
    """
    num_confs = mol.num_conformers
    matrix = np.zeros((num_confs, num_confs))
    if num_confs < 2:
        return matrix
    # GetAllConformerBestRMS returns the lower triangle flattened row by row:
    # (1, 0), (2, 0), (2, 1), (3, 0), ... -- the same order that
    # np.tril_indices(n, k=-1) yields. Reading it as the upper triangle
    # (0, 1), (0, 2), (0, 3), ... mis-assigns values for >= 4 conformers.
    deviations = np.asarray(rdMolAlign.GetAllConformerBestRMS(mol), dtype=float)
    rows, cols = np.tril_indices(num_confs, k=-1)
    if fill != 'upper':
        matrix[rows, cols] = deviations
    if fill != 'lower':
        matrix[cols, rows] = deviations
    return matrix
|
|
384
|
+
|
|
385
|
+
def conformer_energies(
    mol: Mol,
    method: str = 'UFF',
) -> list[float]:
    """Force-field energy of each conformer of ``mol``.

    'UFF' uses the Universal Force Field. Any other value is treated as an
    MMFF variant, with the bare 'MMFF' expanded to 'MMFF94'.
    """
    if method == 'UFF':
        return _calc_uff_energies(mol)
    variant = 'MMFF94' if method == 'MMFF' else method
    return _calc_mmff_energies(mol, variant)
|
|
397
|
+
|
|
398
|
+
def embed_conformers(
    mol: Mol,
    num_conformers: int,
    method: str = 'ETKDGv3',
    force: bool = True,
    **kwargs
) -> Mol:
    """Embed ``num_conformers`` conformers into a copy of ``mol``.

    Args:
        mol: Input molecule (not modified; a ``Mol`` copy is embedded).
        num_conformers: Number of conformers requested from RDKit.
        method: Name of an ``rdDistGeom`` parameter preset. Unknown names fall
            back to 'ETKDGv3' with a warning.
        force: If embedding yields no conformers, fall back to 2D coordinates
            (``rdDepictor.Compute2DCoords``) instead of only warning.
        **kwargs: Set as attributes on the embedding parameters object
            (e.g. RDKit parameter names such as ``randomSeed``).

    Returns:
        The embedded copy.
    """
    # Store factories, not instances, so only the selected preset is built.
    available_embedding_methods = {
        'ETDG': rdDistGeom.ETDG,
        'ETKDG': rdDistGeom.ETKDG,
        'ETKDGv2': rdDistGeom.ETKDGv2,
        'ETKDGv3': rdDistGeom.ETKDGv3,
        'srETKDGv3': rdDistGeom.srETKDGv3,
        'KDG': rdDistGeom.KDG,
    }
    default_embedding_method = 'ETKDGv3'
    mol = Mol(mol)
    make_params = available_embedding_methods.get(method)
    if make_params is None:
        warn(
            f"Could not find `method` {method}. "
            f"Automatically setting method to {default_embedding_method}."
        )
        make_params = available_embedding_methods[default_embedding_method]
    params = make_params()
    for key, value in kwargs.items():
        setattr(params, key, value)

    conformer_ids = rdDistGeom.EmbedMultipleConfs(mol, numConfs=num_conformers, params=params)
    if not len(conformer_ids):
        warning = 'Could not embed conformer(s).'
        if not force:
            warn(warning)
        else:
            solution = ' Embedding a conformer (in 3D space) using (x, y) coordinates.'
            warn(warning + solution)
            rdDepictor.Compute2DCoords(mol)
    return mol
|
|
435
|
+
|
|
436
|
+
def optimize_conformers(
    mol: Mol,
    method: str = 'UFF',
    max_iter: int = 200,
    num_threads: int = 1,
    ignore_interfragment_interactions: bool = True,
    vdw_threshold: float = 10.0,
) -> Mol:
    """Force-field minimize every conformer of a copy of ``mol``.

    Args:
        mol: Input molecule (not modified; a ``Mol`` copy is optimized).
        method: 'UFF', 'MMFF' (= 'MMFF94'), 'MMFF94' or 'MMFF94s'. Other names
            trigger a warning. Names starting with 'MMFF' still go to MMFF
            as-is; anything else falls back to UFF.
        max_iter: Maximum optimizer iterations.
        num_threads: Threads used by RDKit.
        ignore_interfragment_interactions: Forwarded to RDKit.
        vdw_threshold: UFF van der Waals threshold (unused for MMFF).

    Returns:
        The optimized copy. If RDKit raises ``RuntimeError``, a warning is
        emitted and the copy is returned unoptimized. ``ValueError`` from the
        MMFF parameter check propagates.
    """
    available_force_field_methods = [
        'MMFF', 'MMFF94', 'MMFF94s', 'UFF'
    ]
    if method not in available_force_field_methods:
        warn(
            f'Unknown force field method {method!r}; expected one of '
            f'{available_force_field_methods}.'
        )
    mol = Mol(mol)
    try:
        if method.startswith('MMFF'):
            variant = 'MMFF94' if method == 'MMFF' else method
            _mmff_optimize_conformers(
                mol,
                num_threads=num_threads,
                max_iter=max_iter,
                variant=variant,
                ignore_interfragment_interactions=ignore_interfragment_interactions,
            )
        else:
            _uff_optimize_conformers(
                mol,
                num_threads=num_threads,
                max_iter=max_iter,
                vdw_threshold=vdw_threshold,
                ignore_interfragment_interactions=ignore_interfragment_interactions,
            )
    except RuntimeError as e:
        warn(
            f'{method} force field minimization raised {e}. '
            '\nProceeding without force field minimization...'
        )
    return mol
|
|
474
|
+
|
|
475
|
+
def prune_conformers(
    mol: Mol,
    keep: int = 1,
    threshold: float = 0.0,
    energy_force_field: str = 'UFF',
):
    """Greedily keep up to ``keep`` low-energy, mutually distinct conformers.

    Conformers are visited in ascending energy. The lowest-energy one is
    always kept, so at least one survives even if ``keep < 1``. Each further
    conformer is kept only if its RMSD to every already-kept conformer is
    ``>= threshold``.

    NOTE(review): positions in the energy/RMSD arrays are used directly as
    conformer IDs for ``get_conformer``, which assumes IDs are 0..n-1.

    Returns:
        A new ``Mol`` holding the selected conformers (re-numbered), or ``mol``
        itself (with a warning) when it has no conformers.
    """
    if mol.num_conformers == 0:
        warn(
            'Molecule has no conformers. To embed conformers, invoke the `embed` method, '
            'and optionally followed by `minimize()` to perform force field minimization.'
        )
        return mol

    threshold = threshold or 0.0
    rmsd = conformer_deviations(mol)
    by_energy = np.argsort(conformer_energies(mol, method=energy_force_field))

    kept = [int(by_energy[0])]
    for candidate in by_energy[1:]:
        if len(kept) >= keep:
            break
        if np.all(rmsd[candidate, kept] >= threshold):
            kept.append(int(candidate))

    pruned = Mol(mol)
    pruned.RemoveAllConformers()
    for conf_id in kept:
        pruned.AddConformer(mol.get_conformer(conf_id), assignId=True)
    return pruned
|
|
508
|
+
|
|
509
|
+
def _uff_optimize_conformers(
    mol: Mol,
    num_threads: int = 1,
    max_iter: int = 200,
    vdw_threshold: float = 10.0,
    ignore_interfragment_interactions: bool = True,
    **kwargs,
) -> tuple[list[float], list[bool]]:
    """Universal Force Field (UFF) minimization of all conformers of ``mol``, in place.

    Returns:
        ``(energies, converged)``: one final energy and one convergence flag
        per conformer. RDKit reports status 0 for a converged optimization.
    """
    results = rdForceFieldHelpers.UFFOptimizeMoleculeConfs(
        mol,
        numThreads=num_threads,
        maxIters=max_iter,
        vdwThresh=vdw_threshold,
        ignoreInterfragInteractions=ignore_interfragment_interactions,
    )
    energies = [r[1] for r in results]
    converged = [r[0] == 0 for r in results]
    return energies, converged
|
|
529
|
+
|
|
530
|
+
def _mmff_optimize_conformers(
    mol: Mol,
    num_threads: int = 1,
    max_iter: int = 200,
    variant: str = 'MMFF94',
    ignore_interfragment_interactions: bool = True,
    **kwargs,
) -> tuple[list[float], list[bool]]:
    """Merck Molecular Force Field (MMFF) minimization of all conformers of ``mol``, in place.

    Returns:
        ``(energies, converged)``: one final energy and one convergence flag
        per conformer. RDKit reports status 0 for a converged optimization.

    Raises:
        ValueError: If MMFF lacks parameters for some atom of ``mol``.
    """
    if not rdForceFieldHelpers.MMFFHasAllMoleculeParams(mol):
        raise ValueError("Cannot minimize molecule using MMFF.")
    rdForceFieldHelpers.MMFFSanitizeMolecule(mol)
    results = rdForceFieldHelpers.MMFFOptimizeMoleculeConfs(
        mol,
        # RDKit's keyword is ``numThreads``; ``num_threads`` is rejected.
        numThreads=num_threads,
        maxIters=max_iter,
        mmffVariant=variant,
        ignoreInterfragInteractions=ignore_interfragment_interactions,
    )
    energies = [r[1] for r in results]
    converged = [r[0] == 0 for r in results]
    return energies, converged
|
|
553
|
+
|
|
554
|
+
def _calc_uff_energies(
    mol: Mol,
) -> list[float]:
    """UFF energy per conformer ID 0..n-1; NaN where the force field cannot be built/evaluated."""
    def _energy_of(conf_id):
        try:
            field = rdForceFieldHelpers.UFFGetMoleculeForceField(mol, confId=conf_id)
            return field.CalcEnergy()
        except Exception:
            return float('nan')

    return [_energy_of(conf_id) for conf_id in range(mol.num_conformers)]
|
|
565
|
+
|
|
566
|
+
def _calc_mmff_energies(
    mol: Mol,
    variant: str = 'MMFF94',
) -> list[float]:
    """MMFF energy per conformer ID 0..n-1; NaN where evaluation fails.

    Raises:
        ValueError: If MMFF lacks parameters for some atom of ``mol``.
    """
    if not rdForceFieldHelpers.MMFFHasAllMoleculeParams(mol):
        raise ValueError("Cannot compute MMFF energies for this molecule.")
    mmff_props = rdForceFieldHelpers.MMFFGetMoleculeProperties(mol, mmffVariant=variant)

    def _energy_of(conf_id):
        try:
            field = rdForceFieldHelpers.MMFFGetMoleculeForceField(mol, mmff_props, confId=conf_id)
            return field.CalcEnergy()
        except Exception:
            return float('nan')

    return [_energy_of(conf_id) for conf_id in range(mol.num_conformers)]
|
|
581
|
+
|
|
582
|
+
|
|
583
|
+
def _split_mol_by_confs(mol: Mol) -> list[Mol]:
    """One single-conformer copy of ``mol`` per conformer, keeping ``mol``'s Python class."""
    def _with_only(conf):
        single = Chem.Mol(mol)
        single.RemoveAllConformers()
        single.AddConformer(conf, assignId=True)
        single.__class__ = mol.__class__
        return single

    return [_with_only(conf) for conf in mol.get_conformers()]
|
|
592
|
+
|
|
593
|
+
def warn(message: str, stacklevel: int = 2) -> None:
    """Emit ``message`` as a ``UserWarning``.

    Args:
        message: Warning text.
        stacklevel: Frame the warning is attributed to, counted from the
            function that calls ``warn``: 1 = that function, 2 (default) =
            its caller. The previous hard-coded ``stacklevel=1`` always
            pointed at this helper's own line, which is useless for locating
            the offending call.
    """
    warnings.warn(
        message=message,
        category=UserWarning,
        # +1 skips this helper's own frame.
        stacklevel=stacklevel + 1,
    )
|
|
599
|
+
|
|
600
|
+
|