molcraft 0.1.0rc10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
molcraft/chem.py ADDED
@@ -0,0 +1,714 @@
1
+ import warnings
2
+ import collections
3
+ import numpy as np
4
+
5
+ from rdkit import Chem
6
+ from rdkit.Chem import AllChem
7
+ from rdkit.Chem import Lipinski
8
+ from rdkit.Chem import rdDistGeom
9
+ from rdkit.Chem import rdDepictor
10
+ from rdkit.Chem import rdMolAlign
11
+ from rdkit.Chem import rdMolTransforms
12
+ from rdkit.Chem import rdPartialCharges
13
+ from rdkit.Chem import rdMolDescriptors
14
+ from rdkit.Chem import rdForceFieldHelpers
15
+ from rdkit.Chem import rdFingerprintGenerator
16
+
17
+
18
# Alias for RDKit's molecule class; used in annotations and as the base of `Mol`.
RDKitMol = Chem.Mol


def _as_atom_index(atom: int | Chem.Atom) -> int:
    """Return ``atom`` as a plain ``int`` index (accepts an index or an RDKit atom)."""
    if isinstance(atom, Chem.Atom):
        return int(atom.GetIdx())
    return int(atom)


class Mol(RDKitMol):
    """Convenience wrapper around :class:`rdkit.Chem.Mol`.

    Instances are normally obtained by re-assigning ``__class__`` on an
    existing RDKit molecule (see :meth:`from_encoding` and :meth:`cast`).
    """

    @classmethod
    def from_encoding(cls, encoding: str, explicit_hs: bool = False, **kwargs) -> 'Mol':
        """Build a molecule from a SMILES or InChI string.

        Args:
            encoding: SMILES string, or InChI string (detected by its
                ``InChI`` prefix in :func:`get_mol`).
            explicit_hs: If True, add explicit hydrogens after parsing.
            **kwargs: Forwarded to :func:`get_mol` (``strict``,
                ``assign_stereo_chemistry``).

        Returns:
            The parsed molecule; ``encoding`` is remembered and exposed via
            the :attr:`encoding` property.

        Raises:
            ValueError: If the string cannot be parsed (see :func:`get_mol`).
        """
        rdkit_mol = get_mol(encoding, **kwargs)
        if explicit_hs:
            rdkit_mol = Chem.AddHs(rdkit_mol)
        rdkit_mol.__class__ = cls
        setattr(rdkit_mol, '_encoding', encoding)
        return rdkit_mol

    @classmethod
    def cast(cls, obj: RDKitMol) -> 'Mol':
        """Re-class ``obj`` in place as :class:`Mol` and return it (no copy)."""
        obj.__class__ = cls
        return obj

    @property
    def canonical_smiles(self) -> str:
        """Canonical SMILES as produced by ``Chem.MolToSmiles``."""
        return Chem.MolToSmiles(self, canonical=True)

    @property
    def encoding(self):
        """The string this molecule was built from, or None if unknown."""
        return getattr(self, '_encoding', None)

    @property
    def bonds(self) -> list['Bond']:
        """All bonds, ordered by bond index."""
        return get_bonds(self)

    @property
    def atoms(self) -> list['Atom']:
        """All atoms, ordered by atom index."""
        return get_atoms(self)

    @property
    def num_conformers(self) -> int:
        return int(self.GetNumConformers())

    @property
    def num_atoms(self) -> int:
        return int(self.GetNumAtoms())

    @property
    def num_bonds(self) -> int:
        return int(self.GetNumBonds())

    def get_atom(
        self,
        atom: int | Chem.Atom
    ) -> 'Atom':
        """Return the atom at ``atom`` (an index or an RDKit atom) as :class:`Atom`."""
        return Atom.cast(self.GetAtomWithIdx(_as_atom_index(atom)))

    def get_shortest_path_between_atoms(
        self,
        atom_i: int | Chem.Atom,
        atom_j: int | Chem.Atom
    ) -> tuple[int]:
        """Atom indices along the shortest bond path from ``atom_i`` to ``atom_j``."""
        return Chem.rdmolops.GetShortestPath(
            self, _as_atom_index(atom_i), _as_atom_index(atom_j)
        )

    def get_bond_between_atoms(
        self,
        atom_i: int | Chem.Atom,
        atom_j: int | Chem.Atom,
    ) -> 'Bond | None':
        """Return the bond joining ``atom_i`` and ``atom_j``.

        Returns:
            The bond as :class:`Bond`, or None when the atoms are not
            directly bonded.
        """
        bond = self.GetBondBetweenAtoms(
            _as_atom_index(atom_i), _as_atom_index(atom_j)
        )
        # RDKit returns None for unbonded pairs; `Bond.cast(None)` would fail
        # trying to assign `None.__class__`.
        if bond is None:
            return None
        return Bond.cast(bond)

    def adjacency(
        self,
        fill: str = 'upper',
        sparse: bool = True,
        self_loops: bool = False,
        dtype: str= 'int32',
    ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
        """Bond adjacency; see :func:`get_adjacency_matrix` for the arguments."""
        return get_adjacency_matrix(
            self, fill=fill, sparse=sparse, self_loops=self_loops, dtype=dtype
        )

    def get_conformer(self, index: int = 0) -> 'Conformer':
        """Return the conformer with id ``index``, or None (with a warning) if
        the molecule has no conformers."""
        if self.num_conformers == 0:
            warnings.warn(f'{self} has no conformer. Returning None.')
            return None
        return Conformer.cast(self.GetConformer(index))

    def get_conformers(self) -> list['Conformer']:
        """Return all conformers, or an empty list (with a warning) if none."""
        if self.num_conformers == 0:
            warnings.warn(f'{self} has no conformers. Returning an empty list.')
            return []
        return [Conformer.cast(x) for x in self.GetConformers()]

    def __len__(self) -> int:
        return int(self.GetNumAtoms())

    def _repr_png_(self) -> None:
        # Returning None means no PNG rich-display is offered for this object.
        return None

    def __repr__(self) -> str:
        encoding = self.encoding or self.canonical_smiles
        return f'<{self.__class__.__name__} {encoding} at {hex(id(self))}>'
129
+
130
+
131
class Conformer(Chem.Conformer):
    """Convenience wrapper around :class:`rdkit.Chem.Conformer`."""

    @classmethod
    def cast(cls, obj: Chem.Conformer) -> 'Conformer':
        """Re-class ``obj`` in place as :class:`Conformer` and return it."""
        obj.__class__ = cls
        return obj

    @property
    def index(self) -> int:
        """The conformer id (``GetId``)."""
        return self.GetId()

    @property
    def coordinates(self) -> np.ndarray:
        """Atom positions as returned by ``GetPositions``."""
        return self.GetPositions()

    @property
    def distances(self) -> np.ndarray:
        """Pairwise 3D distance matrix for THIS conformer."""
        # Without confId RDKit uses -1 (the default conformer), which made every
        # conformer of a multi-conformer molecule report the same distances.
        return Chem.rdmolops.Get3DDistanceMatrix(
            self.GetOwningMol(), confId=self.GetId()
        )

    @property
    def centroid(self) -> np.ndarray:
        """Geometric centroid of the conformer as a numpy array."""
        return np.asarray(rdMolTransforms.ComputeCentroid(self))

    def adjacency(
        self,
        fill: str = 'full',
        radius: float = None,
        sparse: bool = True,
        self_loops: bool = False,
        dtype: str = 'int32'
    ) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
        """Distance-based adjacency: atoms closer than ``radius`` are connected.

        Args:
            fill: 'full', 'upper' or 'lower' triangle of the matrix.
            radius: Distance cutoff; a falsy value means no cutoff.
            sparse: If True return ``(source, target)`` index arrays,
                otherwise the dense matrix.
            self_loops: Whether each atom is connected to itself.
            dtype: Output dtype.
        """
        radius = radius or np.inf
        distances = self.distances
        if not self_loops:
            np.fill_diagonal(distances, np.inf)
        within_radius = distances < radius
        # Keep the main diagonal (k=0) when self loops are requested, as
        # `get_adjacency_matrix` does; otherwise exclude it.
        offset = 0 if self_loops else 1
        if fill == 'lower':
            within_radius = np.tril(within_radius, k=-offset)
        elif fill == 'upper':
            within_radius = np.triu(within_radius, k=offset)
        if sparse:
            edge_source, edge_target = np.where(within_radius)
            return edge_source.astype(dtype), edge_target.astype(dtype)
        return within_radius.astype(dtype)
175
+
176
+
177
class Atom(Chem.Atom):
    """Convenience wrapper around :class:`rdkit.Chem.Atom`."""

    @classmethod
    def cast(cls, obj: Chem.Atom) -> 'Atom':
        """Re-class ``obj`` in place as :class:`Atom` and return it."""
        obj.__class__ = cls
        return obj

    @property
    def index(self) -> int:
        """Index of the atom within its molecule."""
        idx = self.GetIdx()
        return int(idx)

    @property
    def neighbors(self) -> list['Atom']:
        """Directly bonded atoms, each re-classed as :class:`Atom`."""
        return list(map(Atom.cast, self.GetNeighbors()))

    @property
    def symbol(self) -> str:
        """Element symbol."""
        return self.GetSymbol()

    @property
    def label(self):
        """Atom-map number as an int, or None when no map number is set."""
        key = 'molAtomMapNumber'
        return int(self.GetProp(key)) if self.HasProp(key) else None

    @label.setter
    def label(self, value: int) -> None:
        self.SetProp('molAtomMapNumber', str(value))

    def __repr__(self) -> str:
        symbol = self.GetSymbol()
        address = hex(id(self))
        return f'<Atom {symbol} at {address}>'
208
+
209
+
210
class Bond(Chem.Bond):
    """Convenience wrapper around :class:`rdkit.Chem.Bond`."""

    @classmethod
    def cast(cls, obj: Chem.Bond) -> 'Bond':
        """Re-class ``obj`` in place as :class:`Bond` and return it."""
        obj.__class__ = cls
        return obj

    @property
    def index(self) -> int:
        """Index of the bond within its molecule."""
        idx = self.GetIdx()
        return int(idx)

    def __repr__(self) -> str:
        bond_type = self.GetBondType().name
        address = hex(id(self))
        return f'<Bond {bond_type} at {address}>'
223
+
224
+
225
def get_mol(
    encoding: str,
    strict: bool = True,
    assign_stereo_chemistry: bool = True,
) -> RDKitMol:
    """Parse a SMILES or InChI string and sanitize the result.

    Strings starting with ``InChI`` are parsed as InChI, everything else as
    SMILES. Parsing is done unsanitized and then handed to
    :func:`sanitize_mol`.

    Raises:
        ValueError: If ``encoding`` is not a string or cannot be parsed.
    """
    if not isinstance(encoding, str):
        raise ValueError(
            f'Input ({encoding}) is not a SMILES or InChI string.'
        )
    parser = Chem.MolFromInchi if encoding.startswith('InChI') else Chem.MolFromSmiles
    parsed = parser(encoding, sanitize=False)
    if parsed is not None:
        parsed = sanitize_mol(parsed, strict, assign_stereo_chemistry)
    if parsed is None:
        raise ValueError(f'Could not obtain `chem.Mol` from {encoding}.')
    return parsed
243
+
244
def get_adjacency_matrix(
    mol: RDKitMol,
    fill: str = 'full',
    sparse: bool = False,
    self_loops: bool = False,
    dtype: str = "int32",
) -> tuple[np.ndarray, np.ndarray]:
    """Bond adjacency of ``mol``.

    Args:
        fill: 'full', 'upper' (strict upper triangle) or 'lower'.
        sparse: If True return ``(source, target)`` index arrays instead of
            the dense matrix.
        self_loops: Add 1 on the diagonal after triangular filtering.
        dtype: Output dtype.
    """
    matrix: np.ndarray = Chem.GetAdjacencyMatrix(mol)
    if fill == 'lower':
        matrix = np.tril(matrix, k=-1)
    elif fill == 'upper':
        matrix = np.triu(matrix, k=1)
    if self_loops:
        matrix += np.eye(matrix.shape[0], dtype=matrix.dtype)
    if sparse:
        sources, targets = np.nonzero(matrix)
        return sources.astype(dtype), targets.astype(dtype)
    return matrix.astype(dtype)
262
+
263
def sanitize_mol(
    mol: RDKitMol,
    strict: bool = True,
    assign_stereo_chemistry: bool = True,
) -> Mol:
    """Return a sanitized :class:`Mol` copy of ``mol``.

    Args:
        mol: Molecule to sanitize (copied; the input is not modified).
        strict: If True, any sanitization failure raises. If False, failing
            sanitization steps are skipped with a warning.
        assign_stereo_chemistry: Run ``Chem.AssignStereochemistry`` afterwards.

    Raises:
        ValueError: In strict mode, when sanitization fails.
    """
    mol = Mol(mol)
    flag = Chem.SanitizeMol(mol, catchErrors=True)
    if flag != Chem.SanitizeFlags.SANITIZE_NONE:
        if strict:
            raise ValueError(f'Could not sanitize {mol}.')
        warnings.warn(
            f'Could not sanitize {mol}. Proceeding with partial sanitization.'
        )
        # Sanitize again without the failing step. A later step can fail in
        # turn, so keep excluding failing steps until a pass succeeds instead
        # of letting that later failure raise in non-strict mode.
        sanitize_ops = Chem.SanitizeFlags.SANITIZE_ALL
        while flag != Chem.SanitizeFlags.SANITIZE_NONE:
            if not (sanitize_ops & flag):
                # Defensive: the reported step is already excluded; stop
                # rather than loop forever.
                break
            sanitize_ops ^= flag
            flag = Chem.SanitizeMol(
                mol, sanitizeOps=sanitize_ops, catchErrors=True
            )
    if assign_stereo_chemistry:
        Chem.AssignStereochemistry(
            mol, cleanIt=True, force=True, flagPossibleStereoCenters=True)
    return mol
282
+
283
def get_atoms(mol: Mol) -> list[Atom]:
    """All atoms of ``mol`` in index order, each re-classed as :class:`Atom`."""
    atoms = []
    for idx in range(mol.GetNumAtoms()):
        atoms.append(Atom.cast(mol.GetAtomWithIdx(idx)))
    return atoms
288
+
289
def get_bonds(mol: Mol) -> list[Bond]:
    """All bonds of ``mol`` in index order, each re-classed as :class:`Bond`."""
    bonds = []
    for idx in range(mol.GetNumBonds()):
        bonds.append(Bond.cast(mol.GetBondWithIdx(int(idx))))
    return bonds
294
+
295
def add_hs(mol: Mol) -> Mol:
    """Return a copy of ``mol`` with explicit hydrogens, keeping its class."""
    with_hydrogens = Chem.AddHs(mol)
    with_hydrogens.__class__ = mol.__class__
    return with_hydrogens
299
+
300
def remove_hs(mol: Mol) -> Mol:
    """Return a copy of ``mol`` with hydrogens removed, keeping its class."""
    without_hydrogens = Chem.RemoveHs(mol)
    without_hydrogens.__class__ = mol.__class__
    return without_hydrogens
304
+
305
def get_distances(
    mol: Mol,
    fill: str = 'full',
    use_bond_order: bool = False,
    use_atom_weights: bool = False
) -> np.ndarray:
    """Topological (bond-count) distance matrix of ``mol``.

    Pairs that RDKit reports as disconnected (very large values) are set
    to -1. ``fill`` selects the 'full' matrix, the strict 'upper' or the
    strict 'lower' triangle.
    """
    raw = Chem.rdmolops.GetDistanceMatrix(
        mol, useBO=use_bond_order, useAtomWts=use_atom_weights
    )
    # Disconnected pairs come back as a huge sentinel; map them to -1.
    # TODO: Add argument for filling disconnected node pairs.
    distances = np.where(raw >= 1e6, -1, raw)
    if fill == 'lower':
        return np.tril(distances, k=-1)
    if fill == 'upper':
        return np.triu(distances, k=1)
    return distances
325
+
326
def get_shortest_paths(
    mol: Mol,
    radius: int,
    self_loops: bool = False,
) -> list[list[int]]:
    """Breadth-first shortest paths of up to ``radius`` bonds from every atom.

    For each source atom, one shortest path (as a list of atom indices,
    source first) is produced for every atom reachable within ``radius``
    bonds. Single-atom paths are included only when ``self_loops`` is True.
    """
    max_length = radius + 1
    collected = []
    for source in mol.atoms:
        seen = {source.index}
        frontier = collections.deque([(source, [source.index])])
        while frontier:
            node, route = frontier.popleft()
            if len(route) > max_length:
                continue
            if self_loops or len(route) > 1:
                collected.append(route)
            for nbr in node.neighbors:
                if nbr.index not in seen:
                    seen.add(nbr.index)
                    frontier.append((nbr, route + [nbr.index]))
    return collected
347
+
348
def get_periodic_table():
    """Return RDKit's periodic table object."""
    table = Chem.GetPeriodicTable()
    return table
350
+
351
def partial_charges(mol: 'Mol') -> list[float]:
    """Gasteiger partial charge per atom (computed in place on ``mol``)."""
    rdPartialCharges.ComputeGasteigerCharges(mol)
    charges = []
    for atom in mol.atoms:
        charges.append(atom.GetDoubleProp("_GasteigerCharge"))
    return charges
354
+
355
def logp_contributions(mol: 'Mol') -> list[float]:
    """Per-atom Crippen logP contributions (first item of each contribution)."""
    contributions = rdMolDescriptors._CalcCrippenContribs(mol)
    return [entry[0] for entry in contributions]
357
+
358
def molar_refractivity_contributions(mol: 'Mol') -> list[float]:
    """Per-atom Crippen molar-refractivity contributions (second item)."""
    contributions = rdMolDescriptors._CalcCrippenContribs(mol)
    return [entry[1] for entry in contributions]
360
+
361
def total_polar_surface_area_contributions(mol: 'Mol') -> list[float]:
    """Per-atom topological polar surface area contributions."""
    contributions = rdMolDescriptors._CalcTPSAContribs(mol)
    return [value for value in contributions]
363
+
364
def accessible_surface_area_contributions(mol: 'Mol') -> list[float]:
    """Per-atom Labute ASA contributions (first element of RDKit's result)."""
    atom_contributions = rdMolDescriptors._CalcLabuteASAContribs(mol)[0]
    return list(atom_contributions)
366
+
367
def hydrogen_acceptors(mol: 'Mol') -> list[bool]:
    """Per-atom flag: True if the atom is the first atom of a Lipinski
    H-acceptor match."""
    acceptor_indices = {match[0] for match in Lipinski._HAcceptors(mol)}
    return [atom.index in acceptor_indices for atom in mol.atoms]
370
+
371
def hydrogen_donors(mol: 'Mol') -> list[bool]:
    """Per-atom flag: True if the atom is the first atom of a Lipinski
    H-donor match."""
    donor_indices = {match[0] for match in Lipinski._HDonors(mol)}
    return [atom.index in donor_indices for atom in mol.atoms]
374
+
375
def hetero_atoms(mol: 'Mol') -> list[bool]:
    """Per-atom flag: True if the atom is matched by Lipinski's hetero-atom
    pattern."""
    hetero_indices = {match[0] for match in Lipinski._Heteroatoms(mol)}
    return [atom.index in hetero_indices for atom in mol.atoms]
378
+
379
def rotatable_bonds(mol: 'Mol') -> list[bool]:
    """Per-bond flag: True if the bond's atom pair matches Lipinski's
    rotatable-bond pattern (order of the two atoms is ignored)."""
    rotatable_pairs = {frozenset(match) for match in Lipinski._RotatableBonds(mol)}
    flags = []
    for bond in mol.bonds:
        pair = frozenset((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()))
        flags.append(pair in rotatable_pairs)
    return flags
385
+
386
def conformer_deviations(mol: Mol, fill: str = 'full') -> np.ndarray:
    """Root mean squared deviation (RMSD) matrix between all conformers.

    Args:
        mol: Molecule with conformers.
        fill: 'upper' or 'lower' to populate only that strict triangle;
            any other value fills both (symmetric matrix).

    Returns:
        ``(num_conformers, num_conformers)`` float matrix whose rows/columns
        follow the order of ``mol.GetConformers()``; the diagonal is 0.
    """
    num_confs = mol.num_conformers
    deviations = np.asarray(rdMolAlign.GetAllConformerBestRMS(mol), dtype=float)
    matrix = np.zeros((num_confs, num_confs))
    # RDKit returns the strict lower triangle flattened row by row:
    # (1,0), (2,0), (2,1), (3,0), ... -- exactly the order of
    # np.tril_indices(n, k=-1). The previous (0,1), (0,2), (0,3), ...
    # walk misassigned values for four or more conformers.
    rows, cols = np.tril_indices(num_confs, k=-1)
    if fill != 'upper':
        matrix[rows, cols] = deviations
    if fill != 'lower':
        matrix[cols, rows] = deviations
    return matrix
404
+
405
def conformer_energies(
    mol: Mol,
    method: str = 'UFF',
) -> list[float]:
    """Force-field energy of every conformer.

    ``'UFF'`` uses the universal force field; any other value is treated as
    an MMFF variant name, with ``'MMFF'`` meaning ``'MMFF94'``.
    """
    if method == 'UFF':
        return _calc_uff_energies(mol)
    variant = method + '94' if method == 'MMFF' else method
    return _calc_mmff_energies(mol, variant)
417
+
418
def embed_conformers(
    mol: Mol,
    num_conformers: int,
    method: str = 'ETKDGv3',
    timeout: int | None = None,
    random_seed: int | None = None,
    **kwargs
) -> Mol:
    """Embed ``num_conformers`` 3D conformers into a copy of ``mol``.

    Args:
        mol: Input molecule (copied; existing conformers are discarded by
            the first embedding pass).
        num_conformers: Number of conformers requested.
        method: Name of an RDKit distance-geometry parameter set. Unknown
            names fall back to 'ETKDGv3' with a warning.
        timeout: Embedding timeout passed to RDKit; falsy means 0.
        random_seed: Seed for reproducible embedding; None means unseeded
            (-1). A seed of 0 is honoured.
        **kwargs: Extra attributes set on the chosen parameter object.

    Returns:
        A new :class:`Mol` carrying the embedded conformers.

    Raises:
        RuntimeError: If not all conformers could be embedded even after
            the fallback methods.
    """
    available_embedding_methods = {
        'ETDG': rdDistGeom.ETDG(),
        'ETKDG': rdDistGeom.ETKDG(),
        'ETKDGv2': rdDistGeom.ETKDGv2(),
        'ETKDGv3': rdDistGeom.ETKDGv3(),
        'srETKDGv3': rdDistGeom.srETKDGv3(),
        'KDG': rdDistGeom.KDG()
    }
    mol = Mol(mol)
    embedding_method = available_embedding_methods.get(method)
    if embedding_method is None:
        warnings.warn(
            f'{method} is not available. Proceeding with ETKDGv3.'
        )
        # Rebind `method` too: the fallback loop below indexes the dict by
        # name and would otherwise raise KeyError on the unknown name.
        method = 'ETKDGv3'
        embedding_method = available_embedding_methods[method]

    for key, value in kwargs.items():
        setattr(embedding_method, key, value)

    if not timeout:
        timeout = 0 # No timeout

    # `is None` (not falsiness) so that 0 remains a usable seed.
    if random_seed is None:
        random_seed = -1 # No random seed

    embedding_method.randomSeed = random_seed
    embedding_method.timeout = timeout

    success = rdDistGeom.EmbedMultipleConfs(
        mol, numConfs=num_conformers, params=embedding_method
    )
    num_successes = len(success)
    if num_successes < num_conformers:
        warnings.warn(
            f'Could only embed {num_successes} out of {num_conformers} conformer(s) for '
            f'{mol} using the specified method ({method}) and parameters. Attempting to '
            f'embed the remaining {num_conformers-num_successes} using fallback methods.',
        )
        max_iters = 20 * mol.num_atoms # Doubling the number of iterations
        # dict.fromkeys drops duplicates (e.g. method == 'ETDG') but keeps order.
        for fallback_method in dict.fromkeys([method, 'ETDG', 'KDG']):
            fallback_embedding_method = available_embedding_methods[fallback_method]
            fallback_embedding_method.useRandomCoords = True
            fallback_embedding_method.maxIterations = int(max_iters)
            # Keep the conformers embedded so far; only add the missing ones.
            fallback_embedding_method.clearConfs = False
            fallback_embedding_method.timeout = int(timeout)
            fallback_embedding_method.randomSeed = int(random_seed)
            success = rdDistGeom.EmbedMultipleConfs(
                mol, numConfs=(num_conformers - num_successes), params=fallback_embedding_method
            )
            num_successes += len(success)
            if num_successes == num_conformers:
                break
        else:
            raise RuntimeError(
                f'Could not embed {num_conformers} conformer(s) for {mol}. '
            )
    return mol
483
+
484
def optimize_conformers(
    mol: Mol,
    method: str = 'UFF',
    max_iter: int = 200,
    num_threads: int = 1,
    ignore_interfragment_interactions: bool = True,
    vdw_threshold: float = 10.0,
) -> Mol:
    """Force-field minimize all conformers of a copy of ``mol``.

    Args:
        mol: Molecule with conformers (not modified).
        method: 'UFF', 'MMFF' (= 'MMFF94'), 'MMFF94' or 'MMFF94s'. Unknown
            values fall back to 'UFF' with a warning.
        max_iter: Maximum minimizer iterations per conformer.
        num_threads: Threads passed to RDKit's optimizer.
        ignore_interfragment_interactions: Forwarded to RDKit.
        vdw_threshold: UFF van der Waals threshold (unused for MMFF).

    Returns:
        An optimized copy, or an unmodified copy (with a warning) when the
        molecule has no conformers or the minimization fails.
    """
    if mol.num_conformers == 0:
        warnings.warn(
            f'{mol} has no conformers to optimize. Proceeding without it.'
        )
        return Mol(mol)
    available_force_field_methods = ['MMFF', 'MMFF94', 'MMFF94s', 'UFF']
    if method not in available_force_field_methods:
        warnings.warn(
            f'{method} is not available. Proceeding with universal force field (UFF).'
        )
        method = 'UFF'
    mol_optimized = Mol(mol)
    try:
        if method.startswith('MMFF'):
            variant = 'MMFF94' if method == 'MMFF' else method
            _mmff_optimize_conformers(
                mol_optimized,
                num_threads=num_threads,
                max_iter=max_iter,
                variant=variant,
                ignore_interfragment_interactions=ignore_interfragment_interactions,
            )
        else:
            _uff_optimize_conformers(
                mol_optimized,
                num_threads=num_threads,
                max_iter=max_iter,
                vdw_threshold=vdw_threshold,
                ignore_interfragment_interactions=ignore_interfragment_interactions,
            )
    # ValueError is what `_mmff_optimize_conformers` raises when MMFF lacks
    # parameters; it belongs on the same warn-and-continue path as RDKit's
    # RuntimeError.
    except (RuntimeError, ValueError):
        warnings.warn(
            f'Unsuccessful {method} force field minimization for {mol}. Proceeding without it.',
        )
        return Mol(mol)
    return mol_optimized
530
+
531
def prune_conformers(
    mol: Mol,
    keep: int = 1,
    threshold: float = 0.0,
    energy_force_field: str = 'UFF',
) -> Mol:
    """Keep up to ``keep`` low-energy, mutually distinct conformers.

    Conformers are visited from lowest to highest energy. The lowest-energy
    one is always kept. A later one is kept only if its RMSD to every
    already-kept conformer is at least ``threshold``.

    Returns:
        A :class:`Mol` copy holding the selected conformers (new ids assigned).
        If ``mol`` has no conformers, an unchanged :class:`Mol` copy is
        returned with a warning.
    """
    if mol.num_conformers == 0:
        warnings.warn(
            f'{mol} has no conformers to prune. Proceeding without it.'
        )
        # Return Mol (not a bare RDKitMol) for consistency with the other
        # conformer helpers and with the non-empty path.
        return Mol(mol)

    threshold = threshold or 0.0
    deviations = conformer_deviations(mol)
    energies = conformer_energies(mol, method=energy_force_field)
    sorted_indices = np.argsort(energies)

    selected = [int(sorted_indices[0])]

    for target in sorted_indices[1:]:
        if len(selected) >= keep:
            break
        if np.all(deviations[target, selected] >= threshold):
            selected.append(int(target))

    # `selected` holds positions in the conformer sequence (the order the
    # energy/RMSD arrays are assumed to follow), not conformer ids, so index
    # the sequence instead of looking conformers up by id.
    conformers = list(mol.GetConformers())
    mol_copy = Mol(mol)
    mol_copy.RemoveAllConformers()
    for position in selected:
        mol_copy.AddConformer(conformers[position], assignId=True)

    return mol_copy
563
+
564
def _uff_optimize_conformers(
    mol: Mol,
    num_threads: int = 1,
    max_iter: int = 200,
    vdw_threshold: float = 10.0,
    ignore_interfragment_interactions: bool = True,
    **kwargs,
) -> tuple[list[float], list[bool]]:
    """Universal Force Field minimization of all conformers, in place.

    Returns:
        ``(energies, converged)``: one entry per conformer, in RDKit's result
        order; ``converged`` is True where RDKit's status flag is 0.
    """
    outcome = rdForceFieldHelpers.UFFOptimizeMoleculeConfs(
        mol,
        numThreads=num_threads,
        maxIters=max_iter,
        vdwThresh=vdw_threshold,
        ignoreInterfragInteractions=ignore_interfragment_interactions,
    )
    energies, converged = [], []
    for status, energy in outcome:
        energies.append(energy)
        converged.append(status == 0)
    return energies, converged
584
+
585
def _mmff_optimize_conformers(
    mol: Mol,
    num_threads: int = 1,
    max_iter: int = 200,
    variant: str = 'MMFF94',
    ignore_interfragment_interactions: bool = True,
    **kwargs,
) -> tuple[list[float], list[bool]]:
    """Merck Molecular Force Field minimization of all conformers, in place.

    Returns:
        ``(energies, converged)``: one entry per conformer; ``converged`` is
        True where RDKit's status flag is 0.

    Raises:
        ValueError: If MMFF parameters are missing for the molecule.
    """
    if not rdForceFieldHelpers.MMFFHasAllMoleculeParams(mol):
        raise ValueError("Cannot minimize molecule using MMFF.")
    rdForceFieldHelpers.MMFFSanitizeMolecule(mol)
    results = rdForceFieldHelpers.MMFFOptimizeMoleculeConfs(
        mol,
        # RDKit's keyword is `numThreads`; `num_threads=` is rejected by the
        # Boost.Python signature, so every MMFF optimization used to fail.
        numThreads=num_threads,
        maxIters=max_iter,
        mmffVariant=variant,
        ignoreInterfragInteractions=ignore_interfragment_interactions,
    )
    energies = [r[1] for r in results]
    converged = [r[0] == 0 for r in results]
    return energies, converged
608
+
609
def _calc_uff_energies(
    mol: Mol,
) -> list[float]:
    """UFF energy of every conformer, in ``mol.GetConformers()`` order.

    A conformer whose energy cannot be computed gets NaN instead of
    aborting the whole computation.
    """
    energies = []
    # Use each conformer's actual id: ids are not guaranteed to be
    # 0..n-1, so `range(num_conformers)` could hit missing ids and skip others.
    for conformer in mol.GetConformers():
        try:
            force_field = rdForceFieldHelpers.UFFGetMoleculeForceField(
                mol, confId=conformer.GetId()
            )
            energies.append(force_field.CalcEnergy())
        except Exception:
            # Deliberate best-effort: any failure becomes NaN for this conformer.
            energies.append(float('nan'))
    return energies
620
+
621
def _calc_mmff_energies(
    mol: Mol,
    variant: str = 'MMFF94',
) -> list[float]:
    """MMFF energy of every conformer, in ``mol.GetConformers()`` order.

    Raises:
        ValueError: If MMFF parameters are missing for the molecule.
    """
    energies = []
    if not rdForceFieldHelpers.MMFFHasAllMoleculeParams(mol):
        raise ValueError("Cannot compute MMFF energies for this molecule.")
    props = rdForceFieldHelpers.MMFFGetMoleculeProperties(mol, mmffVariant=variant)
    # Use each conformer's actual id (ids need not be 0..n-1).
    for conformer in mol.GetConformers():
        try:
            force_field = rdForceFieldHelpers.MMFFGetMoleculeForceField(
                mol, props, confId=conformer.GetId()
            )
            energies.append(force_field.CalcEnergy())
        except Exception:
            # Deliberate best-effort: any failure becomes NaN for this conformer.
            energies.append(float('nan'))
    return energies
636
+
637
def unpack_conformers(mol: Mol) -> list[Mol]:
    """Split ``mol`` into one molecule per conformer, each keeping ``mol``'s class."""
    target_class = mol.__class__

    def _single_conformer_copy(conformer):
        # Copy the molecule, drop all conformers, then re-add just this one.
        copy = RDKitMol(mol)
        copy.RemoveAllConformers()
        copy.AddConformer(conformer, assignId=True)
        copy.__class__ = target_class
        return copy

    return [_single_conformer_copy(conf) for conf in mol.get_conformers()]
646
+
647
# Fingerprint name -> RDKit generator factory (called with generator kwargs).
_fingerprint_types = dict(
    rdkit=rdFingerprintGenerator.GetRDKitFPGenerator,
    morgan=rdFingerprintGenerator.GetMorganGenerator,
    topological_torsion=rdFingerprintGenerator.GetTopologicalTorsionGenerator,
    atom_pair=rdFingerprintGenerator.GetAtomPairGenerator,
)
653
+
654
def _get_fingerprint(
    mol: Mol,
    fp_type: str = 'morgan',
    binary: bool = True,
    dtype: str = 'float32',
    **kwargs,
) -> np.ndarray:
    """Compute a fingerprint as a numpy array.

    Args:
        mol: An RDKit molecule (plain ``Chem.Mol`` or :class:`Mol`), or a
            SMILES/InChI string, which is parsed with :meth:`Mol.from_encoding`.
        fp_type: Key of ``_fingerprint_types``.
        binary: Bit fingerprint if True, count fingerprint otherwise.
        dtype: Output dtype.
        **kwargs: Passed to the RDKit generator factory.

    Raises:
        KeyError: For an unknown ``fp_type``.
        ValueError: If ``mol`` is neither an RDKit molecule nor a parsable
            string.
    """
    fingerprint = _fingerprint_types[fp_type](**kwargs)
    # Any RDKit molecule works directly; only non-molecules go through
    # parsing. Checking for `Mol` alone sent plain Chem.Mol inputs (the
    # wrappers' annotated type) to from_encoding, which rejects non-strings.
    if not isinstance(mol, RDKitMol):
        mol = Mol.from_encoding(mol)
    if binary:
        fp: np.ndarray = fingerprint.GetFingerprintAsNumPy(mol)
    else:
        fp: np.ndarray = fingerprint.GetCountFingerprintAsNumPy(mol)
    return fp.astype(dtype)
671
+
672
def _rdkit_fingerprint(
    mol: RDKitMol,
    size: int = 2048,
    *,
    min_path: int = 1,
    max_path: int = 7,
    binary: bool = True,
    dtype: str = 'float32',
) -> np.ndarray:
    """RDKit path-based fingerprint of length ``size``."""
    return _get_fingerprint(
        mol,
        'rdkit',
        binary,
        dtype,
        fpSize=size,
        minPath=min_path,
        maxPath=max_path,
    )
683
+
684
def _morgan_fingerprint(
    mol: RDKitMol,
    size: int = 2048,
    *,
    radius: int = 3,
    binary: bool = True,
    dtype: str = 'float32',
) -> np.ndarray:
    """Morgan (circular) fingerprint of length ``size`` and the given radius."""
    return _get_fingerprint(
        mol, 'morgan', binary, dtype, radius=radius, fpSize=size
    )
694
+
695
def _topological_torsion_fingerprint(
    mol: RDKitMol,
    size: int = 2048,
    *,
    binary: bool = True,
    dtype: str = 'float32',
) -> np.ndarray:
    """Topological-torsion fingerprint of length ``size``."""
    return _get_fingerprint(
        mol, 'topological_torsion', binary, dtype, fpSize=size
    )
704
+
705
def _atom_pair_fingerprint(
    mol: RDKitMol,
    size: int = 2048,
    *,
    binary: bool = True,
    dtype: str = 'float32',
) -> np.ndarray:
    """Atom-pair fingerprint of length ``size``."""
    return _get_fingerprint(
        mol, 'atom_pair', binary, dtype, fpSize=size
    )
714
+