molcraft 0.1.0a22__py3-none-any.whl → 0.1.0a23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- molcraft/__init__.py +2 -2
- molcraft/callbacks.py +1 -1
- molcraft/chem.py +33 -26
- molcraft/datasets.py +1 -0
- molcraft/descriptors.py +37 -5
- molcraft/diffusion.py +241 -0
- molcraft/features.py +4 -5
- molcraft/featurizers.py +1 -1
- molcraft/layers.py +20 -21
- molcraft/losses.py +1 -0
- molcraft/models.py +4 -1
- molcraft/ops.py +1 -0
- molcraft/records.py +3 -3
- molcraft/tensors.py +1 -0
- {molcraft-0.1.0a22.dist-info → molcraft-0.1.0a23.dist-info}/METADATA +2 -3
- molcraft-0.1.0a23.dist-info/RECORD +22 -0
- molcraft-0.1.0a22.dist-info/RECORD +0 -21
- {molcraft-0.1.0a22.dist-info → molcraft-0.1.0a23.dist-info}/WHEEL +0 -0
- {molcraft-0.1.0a22.dist-info → molcraft-0.1.0a23.dist-info}/licenses/LICENSE +0 -0
- {molcraft-0.1.0a22.dist-info → molcraft-0.1.0a23.dist-info}/top_level.txt +0 -0
molcraft/__init__.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
__version__ = '0.1.0a22'
|
|
1
|
+
__version__ = '0.1.0a23'
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
4
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
@@ -14,4 +14,4 @@ from molcraft import records
|
|
|
14
14
|
from molcraft import tensors
|
|
15
15
|
from molcraft import callbacks
|
|
16
16
|
from molcraft import datasets
|
|
17
|
-
from molcraft import losses
|
|
17
|
+
from molcraft import losses
|
molcraft/callbacks.py
CHANGED
molcraft/chem.py
CHANGED
|
@@ -3,6 +3,7 @@ import collections
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
|
|
5
5
|
from rdkit import Chem
|
|
6
|
+
from rdkit.Chem import AllChem
|
|
6
7
|
from rdkit.Chem import Lipinski
|
|
7
8
|
from rdkit.Chem import rdDistGeom
|
|
8
9
|
from rdkit.Chem import rdDepictor
|
|
@@ -31,10 +32,8 @@ class Mol(Chem.Mol):
|
|
|
31
32
|
|
|
32
33
|
@property
|
|
33
34
|
def encoding(self):
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
return None
|
|
37
|
-
|
|
35
|
+
return getattr(self, '_encoding', None)
|
|
36
|
+
|
|
38
37
|
@property
|
|
39
38
|
def bonds(self) -> list['Bond']:
|
|
40
39
|
if not hasattr(self, '_bonds'):
|
|
@@ -67,7 +66,7 @@ class Mol(Chem.Mol):
|
|
|
67
66
|
atom = atom.GetIdx()
|
|
68
67
|
return Atom.cast(self.GetAtomWithIdx(int(atom)))
|
|
69
68
|
|
|
70
|
-
def
|
|
69
|
+
def get_shortest_path_between_atoms(
|
|
71
70
|
self,
|
|
72
71
|
atom_i: int | Chem.Atom,
|
|
73
72
|
atom_j: int | Chem.Atom
|
|
@@ -107,13 +106,13 @@ class Mol(Chem.Mol):
|
|
|
107
106
|
|
|
108
107
|
def get_conformer(self, index: int = 0) -> 'Conformer':
|
|
109
108
|
if self.num_conformers == 0:
|
|
110
|
-
warnings.warn('
|
|
109
|
+
warnings.warn(f'{self} has no conformer. Returning None.')
|
|
111
110
|
return None
|
|
112
111
|
return Conformer.cast(self.GetConformer(index))
|
|
113
112
|
|
|
114
113
|
def get_conformers(self) -> list['Conformer']:
|
|
115
114
|
if self.num_conformers == 0:
|
|
116
|
-
warnings.warn('
|
|
115
|
+
warnings.warn(f'{self} has no conformers. Returning an empty list.')
|
|
117
116
|
return []
|
|
118
117
|
return [Conformer.cast(x) for x in self.GetConformers()]
|
|
119
118
|
|
|
@@ -124,7 +123,8 @@ class Mol(Chem.Mol):
|
|
|
124
123
|
return None
|
|
125
124
|
|
|
126
125
|
def __repr__(self) -> str:
|
|
127
|
-
|
|
126
|
+
encoding = self.encoding or self.canonical_smiles
|
|
127
|
+
return f'<{self.__class__.__name__} {encoding} at {hex(id(self))}>'
|
|
128
128
|
|
|
129
129
|
|
|
130
130
|
class Conformer(Chem.Conformer):
|
|
@@ -251,7 +251,10 @@ def sanitize_mol(
|
|
|
251
251
|
flag = Chem.SanitizeMol(mol, catchErrors=True)
|
|
252
252
|
if flag != Chem.SanitizeFlags.SANITIZE_NONE:
|
|
253
253
|
if strict:
|
|
254
|
-
|
|
254
|
+
raise ValueError(f'Could not sanitize {mol}.')
|
|
255
|
+
warnings.warn(
|
|
256
|
+
f'Could not sanitize {mol}. Proceeding with partial sanitization.'
|
|
257
|
+
)
|
|
255
258
|
# Sanitize mol, excluding the steps causing the error previously
|
|
256
259
|
Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL^flag)
|
|
257
260
|
if assign_stereo_chemistry:
|
|
@@ -411,13 +414,12 @@ def embed_conformers(
|
|
|
411
414
|
'KDG': rdDistGeom.KDG()
|
|
412
415
|
}
|
|
413
416
|
mol = Mol(mol)
|
|
414
|
-
encoding = mol.encoding or mol.canonical_smiles
|
|
415
417
|
embedding_method = available_embedding_methods.get(method)
|
|
416
418
|
if embedding_method is None:
|
|
417
|
-
|
|
418
|
-
f'
|
|
419
|
-
'`ETDG`, `ETKDG`, `ETKDGv2`, `ETKDGv3`, `srETKDGv3` or `KDG`.'
|
|
419
|
+
warnings.warn(
|
|
420
|
+
f'{method} is not available. Proceeding with ETKDGv3.'
|
|
420
421
|
)
|
|
422
|
+
embedding_method = available_embedding_methods['ETKDGv3']
|
|
421
423
|
|
|
422
424
|
for key, value in kwargs.items():
|
|
423
425
|
setattr(embedding_method, key, value)
|
|
@@ -438,8 +440,8 @@ def embed_conformers(
|
|
|
438
440
|
if num_successes < num_conformers:
|
|
439
441
|
warnings.warn(
|
|
440
442
|
f'Could only embed {num_successes} out of {num_conformers} conformer(s) for '
|
|
441
|
-
f'{
|
|
442
|
-
f'
|
|
443
|
+
f'{mol} using the specified method ({method}) and parameters. Attempting to '
|
|
444
|
+
f'embed the remaining {num_conformers-num_successes} using fallback methods.',
|
|
443
445
|
)
|
|
444
446
|
max_iters = 20 * mol.num_atoms # Doubling the number of iterations
|
|
445
447
|
for fallback_method in [method, 'ETDG', 'KDG']:
|
|
@@ -457,10 +459,13 @@ def embed_conformers(
|
|
|
457
459
|
break
|
|
458
460
|
else:
|
|
459
461
|
raise RuntimeError(
|
|
460
|
-
f'Could not embed {num_conformers} conformer(s) for {
|
|
462
|
+
f'Could not embed {num_conformers} conformer(s) for {mol}. '
|
|
461
463
|
)
|
|
462
464
|
return mol
|
|
463
465
|
|
|
466
|
+
|
|
467
|
+
import warnings
|
|
468
|
+
|
|
464
469
|
def optimize_conformers(
|
|
465
470
|
mol: Mol,
|
|
466
471
|
method: str = 'UFF',
|
|
@@ -469,14 +474,17 @@ def optimize_conformers(
|
|
|
469
474
|
ignore_interfragment_interactions: bool = True,
|
|
470
475
|
vdw_threshold: float = 10.0,
|
|
471
476
|
) -> Mol:
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
477
|
+
if mol.num_conformers == 0:
|
|
478
|
+
warnings.warn(
|
|
479
|
+
f'{mol} has no conformers to optimize. Proceeding without it.'
|
|
480
|
+
)
|
|
481
|
+
return Mol(mol)
|
|
482
|
+
available_force_field_methods = ['MMFF', 'MMFF94', 'MMFF94s', 'UFF']
|
|
475
483
|
if method not in available_force_field_methods:
|
|
476
|
-
|
|
477
|
-
f'
|
|
478
|
-
'`UFF`, `MMFF`, `MMFF94` or `MMFF94s`.'
|
|
484
|
+
warnings.warn(
|
|
485
|
+
f'{method} is not available. Proceeding with universal force field (UFF).'
|
|
479
486
|
)
|
|
487
|
+
method = 'UFF'
|
|
480
488
|
mol_optimized = Mol(mol)
|
|
481
489
|
try:
|
|
482
490
|
if method.startswith('MMFF'):
|
|
@@ -500,7 +508,7 @@ def optimize_conformers(
|
|
|
500
508
|
)
|
|
501
509
|
except RuntimeError as e:
|
|
502
510
|
warnings.warn(
|
|
503
|
-
f'{method} force field minimization
|
|
511
|
+
f'Unsuccessful {method} force field minimization for {mol}. Proceeding without it.',
|
|
504
512
|
)
|
|
505
513
|
return Mol(mol)
|
|
506
514
|
return mol_optimized
|
|
@@ -513,10 +521,9 @@ def prune_conformers(
|
|
|
513
521
|
) -> Mol:
|
|
514
522
|
if mol.num_conformers == 0:
|
|
515
523
|
warnings.warn(
|
|
516
|
-
'
|
|
517
|
-
'and optionally followed by `minimize()` to perform force field minimization.',
|
|
524
|
+
f'{mol} has no conformers to prune. Proceeding without it.'
|
|
518
525
|
)
|
|
519
|
-
return mol
|
|
526
|
+
return Chem.Mol(mol)
|
|
520
527
|
|
|
521
528
|
threshold = threshold or 0.0
|
|
522
529
|
deviations = conformer_deviations(mol)
|
molcraft/datasets.py
CHANGED
molcraft/descriptors.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
|
+
import warnings
|
|
1
2
|
import keras
|
|
2
3
|
import numpy as np
|
|
4
|
+
|
|
3
5
|
from rdkit.Chem import rdMolDescriptors
|
|
4
6
|
|
|
5
7
|
from molcraft import chem
|
|
@@ -12,9 +14,7 @@ class Descriptor(features.Feature):
|
|
|
12
14
|
def __call__(self, mol: chem.Mol) -> np.ndarray:
|
|
13
15
|
if not isinstance(mol, chem.Mol):
|
|
14
16
|
raise ValueError(
|
|
15
|
-
f'Input to {self.name}
|
|
16
|
-
'implements two properties that should be iterated over '
|
|
17
|
-
'to compute features: `atoms` and `bonds`.'
|
|
17
|
+
f'Input to {self.name} must be a `chem.Mol` object.'
|
|
18
18
|
)
|
|
19
19
|
descriptor = self.call(mol)
|
|
20
20
|
func = (
|
|
@@ -30,6 +30,23 @@ class Descriptor(features.Feature):
|
|
|
30
30
|
return np.concatenate(descriptors)
|
|
31
31
|
|
|
32
32
|
|
|
33
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
34
|
+
class Descriptor3D(Descriptor):
|
|
35
|
+
|
|
36
|
+
def __call__(self, mol: chem.Mol) -> np.ndarray:
|
|
37
|
+
if not isinstance(mol, chem.Mol):
|
|
38
|
+
raise ValueError(
|
|
39
|
+
f'Input to {self.name} must be a `chem.Mol` object.'
|
|
40
|
+
)
|
|
41
|
+
if mol.num_conformers == 0:
|
|
42
|
+
raise ValueError(
|
|
43
|
+
f'The inputted `chem.Mol` to {self.name} must embed a conformer. '
|
|
44
|
+
f'It is recommended that {self.name} is used as a molecule feature '
|
|
45
|
+
'for `MolGraphFeaturizer3D`, which by default embeds a conformer.'
|
|
46
|
+
)
|
|
47
|
+
return super().__call__(mol)
|
|
48
|
+
|
|
49
|
+
|
|
33
50
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
34
51
|
class MolWeight(Descriptor):
|
|
35
52
|
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
@@ -77,7 +94,7 @@ class NumHydrogenDonors(Descriptor):
|
|
|
77
94
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
78
95
|
class NumHydrogenAcceptors(Descriptor):
|
|
79
96
|
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
80
|
-
return rdMolDescriptors.CalcNumHBA(mol)
|
|
97
|
+
return rdMolDescriptors.CalcNumHBA(mol)
|
|
81
98
|
|
|
82
99
|
|
|
83
100
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
@@ -89,7 +106,7 @@ class NumRotatableBonds(Descriptor):
|
|
|
89
106
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
90
107
|
class NumRings(Descriptor):
|
|
91
108
|
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
92
|
-
return rdMolDescriptors.CalcNumRings(mol)
|
|
109
|
+
return rdMolDescriptors.CalcNumRings(mol)
|
|
93
110
|
|
|
94
111
|
|
|
95
112
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
@@ -105,3 +122,18 @@ class AtomCount(Descriptor):
|
|
|
105
122
|
if atom.GetSymbol() == self.atom_type:
|
|
106
123
|
count += 1
|
|
107
124
|
return count
|
|
125
|
+
|
|
126
|
+
def get_config(self) -> dict:
|
|
127
|
+
config = super().get_config()
|
|
128
|
+
config['atom_type'] = self.atom_type
|
|
129
|
+
return config
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
133
|
+
class ForceFieldEnergy(Descriptor3D):
|
|
134
|
+
"""Universal Force Field (UFF) Energy."""
|
|
135
|
+
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
136
|
+
mol_copy = chem.Mol(mol)
|
|
137
|
+
mol_copy = chem.add_hs(mol_copy)
|
|
138
|
+
return chem.conformer_energies(mol_copy, method="UFF")
|
|
139
|
+
|
molcraft/diffusion.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
import keras
|
|
3
|
+
import tensorflow as tf
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from molcraft import ops
|
|
7
|
+
from molcraft import tensors
|
|
8
|
+
from molcraft import layers
|
|
9
|
+
from molcraft import models
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# smiles = pd.read_csv('../../data/rt/RIKEN.csv')['smiles'].values
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# graph = featurizers.MolGraphFeaturizer3D(super_node=False)(smiles)
|
|
16
|
+
# graph.node['coordinate']
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# encoder = molcraft.models.GraphModel.from_layers(
|
|
20
|
+
# [
|
|
21
|
+
# diffusion.CoordinateNoise(),
|
|
22
|
+
# molcraft.layers.NodeEmbedding(128),
|
|
23
|
+
# molcraft.layers.EdgeEmbedding(128),
|
|
24
|
+
# molcraft.layers.AddContext('position'),
|
|
25
|
+
# molcraft.layers.MPConv(128),
|
|
26
|
+
# molcraft.layers.AddContext('position'),
|
|
27
|
+
# molcraft.layers.MPConv(128),
|
|
28
|
+
# molcraft.layers.AddContext('position'),
|
|
29
|
+
# ]
|
|
30
|
+
# )
|
|
31
|
+
|
|
32
|
+
# decoder = keras.Sequential([
|
|
33
|
+
# keras.layers.Dense(128, activation='relu'),
|
|
34
|
+
# keras.layers.Dense(3),
|
|
35
|
+
# ])
|
|
36
|
+
|
|
37
|
+
# model = diffusion.CoordinateNoisePredictor(encoder, decoder)
|
|
38
|
+
|
|
39
|
+
# model(graph)
|
|
40
|
+
|
|
41
|
+
# model.save('/tmp/model.keras')
|
|
42
|
+
# model = molcraft.models.load_model('/tmp/model.keras')
|
|
43
|
+
|
|
44
|
+
# model.compile(keras.optimizers.Adam(1e-3), 'mse')
|
|
45
|
+
# model.fit(graph, epochs=100)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# from rdkit.Geometry import Point3D
|
|
49
|
+
|
|
50
|
+
# def energy(smiles, coordinate):
|
|
51
|
+
# m = chem.Mol.from_encoding(smiles)
|
|
52
|
+
# m = chem.embed_conformers(m, 1)
|
|
53
|
+
# conf = m.GetConformer()
|
|
54
|
+
# for i in range(m.GetNumAtoms()):
|
|
55
|
+
# x, y, z = coordinate[i]
|
|
56
|
+
# conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z)))
|
|
57
|
+
# return m, chem.conformer_energies(m)[0]
|
|
58
|
+
|
|
59
|
+
# def denoise(
|
|
60
|
+
# graph: tensors.GraphTensor,
|
|
61
|
+
# model,
|
|
62
|
+
# ):
|
|
63
|
+
|
|
64
|
+
# print("----")
|
|
65
|
+
# print(energy(smiles[0], graph[0].node['coordinate'])[-1])
|
|
66
|
+
# print("----")
|
|
67
|
+
|
|
68
|
+
# beta = keras.ops.linspace(1e-4, 1e-2, 100)
|
|
69
|
+
# alpha = 1 - beta
|
|
70
|
+
# alpha_bar = keras.ops.cumprod(alpha)
|
|
71
|
+
# sigma = keras.ops.sqrt(beta[1:] * (1.0 - alpha_bar[:-1]) / (1.0 - alpha_bar[1:]))
|
|
72
|
+
|
|
73
|
+
# graph = graph.update(
|
|
74
|
+
# {
|
|
75
|
+
# 'context': {
|
|
76
|
+
# 'position': keras.ops.ones_like(graph.context['size']) * 99
|
|
77
|
+
# },
|
|
78
|
+
# 'node': {
|
|
79
|
+
# 'coordinate': keras.random.normal(graph.node['coordinate'].shape)
|
|
80
|
+
# }
|
|
81
|
+
# }
|
|
82
|
+
# )
|
|
83
|
+
|
|
84
|
+
# for t in reversed(range(100)):
|
|
85
|
+
# alpha_t = alpha[t]
|
|
86
|
+
# alpha_bar_t = alpha_bar[t]
|
|
87
|
+
|
|
88
|
+
# a = 1 / keras.ops.sqrt(alpha_t)
|
|
89
|
+
|
|
90
|
+
# b = (1 - alpha_t) / keras.ops.sqrt(1 - alpha_bar_t)
|
|
91
|
+
|
|
92
|
+
# if t > 0:
|
|
93
|
+
# z = keras.random.normal(()) * sigma[t-1]
|
|
94
|
+
# else:
|
|
95
|
+
# z = 0.0
|
|
96
|
+
|
|
97
|
+
# graph = graph.update({
|
|
98
|
+
# 'node': {
|
|
99
|
+
# 'coordinate': (
|
|
100
|
+
# a * (graph.node['coordinate'] - b * model(graph)) + z
|
|
101
|
+
# )
|
|
102
|
+
# }
|
|
103
|
+
# })
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
# print(energy(smiles[0], graph[0].node['coordinate'])[-1])
|
|
107
|
+
|
|
108
|
+
# return graph
|
|
109
|
+
|
|
110
|
+
# graph_updated = denoise(graph[:1], model)x
|
|
111
|
+
# mol, e = energy(smiles[0], graph_updated[0].node['coordinate'])
|
|
112
|
+
# print(e)
|
|
113
|
+
# Chem.Mol(mol)
|
|
114
|
+
|
|
115
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
116
|
+
class CoordinateNoisePredictor(models.GraphModel):
|
|
117
|
+
|
|
118
|
+
def __init__(self, encoder, decoder, *args, **kwargs):
|
|
119
|
+
super().__init__(*args, **kwargs)
|
|
120
|
+
self.encoder = encoder
|
|
121
|
+
self.decoder = decoder
|
|
122
|
+
|
|
123
|
+
def propagate(self, tensor):
|
|
124
|
+
return self.decoder(self.encoder(tensor).node['feature'])
|
|
125
|
+
|
|
126
|
+
def train_step(self, tensor: tensors.GraphTensor) -> dict[str, float]:
|
|
127
|
+
with tf.GradientTape() as tape:
|
|
128
|
+
tensor = self.encoder(tensor)
|
|
129
|
+
feature = tensor.node['feature']
|
|
130
|
+
noise_true = tensor.node['label']
|
|
131
|
+
noise_pred = self.decoder(feature)
|
|
132
|
+
loss = self.compute_loss(tensor, noise_true, noise_pred)
|
|
133
|
+
loss = self.optimizer.scale_loss(loss)
|
|
134
|
+
trainable_weights = self.trainable_weights
|
|
135
|
+
gradients = tape.gradient(loss, trainable_weights)
|
|
136
|
+
self.optimizer.apply_gradients(zip(gradients, trainable_weights))
|
|
137
|
+
return self.compute_metrics(tensor, noise_true, noise_pred)
|
|
138
|
+
|
|
139
|
+
def test_step(self, tensor: tensors.GraphTensor) -> dict[str, float]:
|
|
140
|
+
tensor = self.encoder(tensor)
|
|
141
|
+
feature = tensor.node['feature']
|
|
142
|
+
noise_true = tensor.node['label']
|
|
143
|
+
noise_pred = self.decoder(feature)
|
|
144
|
+
return self.compute_metrics(tensor, noise_true, noise_pred)
|
|
145
|
+
|
|
146
|
+
def get_config(self) -> dict:
|
|
147
|
+
config = super().get_config()
|
|
148
|
+
config['encoder'] = keras.saving.serialize_keras_object(self.encoder)
|
|
149
|
+
config['decoder'] = keras.saving.serialize_keras_object(self.decoder)
|
|
150
|
+
return config
|
|
151
|
+
|
|
152
|
+
@classmethod
|
|
153
|
+
def from_config(cls, config: dict):
|
|
154
|
+
config['encoder'] = keras.saving.deserialize_keras_object(config['encoder'])
|
|
155
|
+
config['decoder'] = keras.saving.deserialize_keras_object(config['decoder'])
|
|
156
|
+
return super().from_config(config)
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
160
|
+
class CoordinateNoise(layers.GraphLayer):
|
|
161
|
+
|
|
162
|
+
def __init__(
|
|
163
|
+
self,
|
|
164
|
+
beta: tuple[float, float] = (1e-4, 1e-2),
|
|
165
|
+
position_dim: int = 128,
|
|
166
|
+
max_timesteps: int = 100,
|
|
167
|
+
**kwargs
|
|
168
|
+
) -> None:
|
|
169
|
+
super().__init__(**kwargs)
|
|
170
|
+
self._beta = beta
|
|
171
|
+
self._max_timesteps = max_timesteps
|
|
172
|
+
beta = keras.ops.linspace(*self._beta, self._max_timesteps)
|
|
173
|
+
alpha = 1 - beta
|
|
174
|
+
alpha_cumprod = keras.ops.cumprod(alpha)
|
|
175
|
+
alpha_cumprod = keras.ops.expand_dims(alpha_cumprod, -1)
|
|
176
|
+
self._alpha_cumprod = alpha_cumprod
|
|
177
|
+
self._timestep_embedding = TimestepEmbedding(dim=position_dim)
|
|
178
|
+
|
|
179
|
+
def propagate(self, graph: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
180
|
+
if 'position' in graph.context:
|
|
181
|
+
return graph.update({'context': {'position': self._timestep_embedding(graph.context['position'])}})
|
|
182
|
+
|
|
183
|
+
timestep = keras.random.randint(
|
|
184
|
+
shape=(graph.num_subgraphs,), minval=0, maxval=self._max_timesteps
|
|
185
|
+
)
|
|
186
|
+
alpha_cumprod = ops.gather(
|
|
187
|
+
ops.gather(self._alpha_cumprod, timestep), graph.graph_indicator
|
|
188
|
+
)
|
|
189
|
+
epsilon = keras.random.normal(
|
|
190
|
+
shape=keras.ops.shape(graph.node['coordinate']), mean=0, stddev=1
|
|
191
|
+
)
|
|
192
|
+
noisy_coordinate = (
|
|
193
|
+
keras.ops.sqrt(alpha_cumprod) * graph.node['coordinate'] +
|
|
194
|
+
keras.ops.sqrt(1 - alpha_cumprod) * epsilon
|
|
195
|
+
)
|
|
196
|
+
timestep = self._timestep_embedding(timestep)
|
|
197
|
+
return graph.update(
|
|
198
|
+
{
|
|
199
|
+
'context': {
|
|
200
|
+
'position': timestep,
|
|
201
|
+
},
|
|
202
|
+
'node': {
|
|
203
|
+
'coordinate': noisy_coordinate,
|
|
204
|
+
'label': epsilon
|
|
205
|
+
},
|
|
206
|
+
}
|
|
207
|
+
)
|
|
208
|
+
|
|
209
|
+
def get_config(self) -> dict:
|
|
210
|
+
config = super().get_config()
|
|
211
|
+
config['beta'] = self._beta
|
|
212
|
+
config['max_timesteps'] = self._max_timesteps
|
|
213
|
+
return config
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
class TimestepEmbedding(keras.layers.Layer):
|
|
217
|
+
|
|
218
|
+
def __init__(self, dim: int, max_wavelength: int = 10000, **kwargs) -> None:
|
|
219
|
+
super().__init__(**kwargs)
|
|
220
|
+
self._dim = dim
|
|
221
|
+
self._max_wavelength = max_wavelength
|
|
222
|
+
|
|
223
|
+
def call(self, inputs: tf.Tensor) -> tf.Tensor:
|
|
224
|
+
timestep = keras.ops.cast(inputs, 'float32')
|
|
225
|
+
embedding = keras.ops.log(self._max_wavelength) / (self._dim // 2 - 1)
|
|
226
|
+
embedding = keras.ops.exp(
|
|
227
|
+
-embedding * keras.ops.arange(self._dim // 2, dtype='float32')
|
|
228
|
+
)
|
|
229
|
+
embedding = timestep[:, None] * embedding[None, :]
|
|
230
|
+
embedding = keras.ops.concatenate(
|
|
231
|
+
[keras.ops.sin(embedding), keras.ops.cos(embedding)], axis=-1
|
|
232
|
+
)
|
|
233
|
+
return embedding
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def get_config(self) -> dict:
|
|
237
|
+
config = super().get_config()
|
|
238
|
+
config['dim'] = self._dim
|
|
239
|
+
config['max_wavelength'] = self._max_wavelength
|
|
240
|
+
return config
|
|
241
|
+
|
molcraft/features.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
|
+
import warnings
|
|
1
2
|
import abc
|
|
2
3
|
import math
|
|
3
4
|
import keras
|
|
4
|
-
import warnings
|
|
5
5
|
import numpy as np
|
|
6
6
|
|
|
7
7
|
from molcraft import chem
|
|
@@ -41,14 +41,14 @@ class Feature(abc.ABC):
|
|
|
41
41
|
|
|
42
42
|
def __call__(self, mol: chem.Mol) -> np.ndarray:
|
|
43
43
|
if not isinstance(mol, chem.Mol):
|
|
44
|
-
raise TypeError(f'Input to {self.name} must be a `chem.Mol`
|
|
44
|
+
raise TypeError(f'Input to {self.name} must be a `chem.Mol` object.')
|
|
45
45
|
features = self.call(mol)
|
|
46
46
|
if len(features) != mol.num_atoms and len(features) != mol.num_bonds:
|
|
47
47
|
raise ValueError(
|
|
48
48
|
f'The number of features computed by {self.name} does not '
|
|
49
49
|
'match the number of atoms or bonds of the `chem.Mol` object. '
|
|
50
|
-
'Make sure to iterate over `atoms` or `bonds` of `chem.Mol` '
|
|
51
|
-
'when computing features.'
|
|
50
|
+
'Make sure to iterate over `atoms` or `bonds` of the `chem.Mol` '
|
|
51
|
+
'object when computing features.'
|
|
52
52
|
)
|
|
53
53
|
if len(features) == 0:
|
|
54
54
|
# Edge case: no atoms or bonds in the molecule.
|
|
@@ -109,7 +109,6 @@ class Feature(abc.ABC):
|
|
|
109
109
|
warnings.warn(
|
|
110
110
|
f'Found value of {self.name} to be non-finite. '
|
|
111
111
|
f'Value received: {value}. Converting it to a value of 0.',
|
|
112
|
-
stacklevel=2
|
|
113
112
|
)
|
|
114
113
|
value = 0.0
|
|
115
114
|
return np.asarray([value], dtype=self.dtype)
|
molcraft/featurizers.py
CHANGED
molcraft/layers.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
|
+
import warnings
|
|
1
2
|
import keras
|
|
2
3
|
import tensorflow as tf
|
|
3
|
-
import warnings
|
|
4
4
|
import functools
|
|
5
5
|
from keras.src.models import functional
|
|
6
6
|
|
|
@@ -350,11 +350,8 @@ class GraphConv(GraphLayer):
|
|
|
350
350
|
)
|
|
351
351
|
if self._project_residual:
|
|
352
352
|
warnings.warn(
|
|
353
|
-
'
|
|
354
|
-
'
|
|
355
|
-
'Automatically applying a projection layer to residual to '
|
|
356
|
-
'match input and output. ',
|
|
357
|
-
stacklevel=2,
|
|
353
|
+
'Found incompatible dim between input and output. Applying '
|
|
354
|
+
'a projection layer to residual to match input and output dim.',
|
|
358
355
|
)
|
|
359
356
|
self._residual_dense = self.get_dense(
|
|
360
357
|
self.units, name='residual_dense'
|
|
@@ -613,10 +610,8 @@ class GIConv(GraphConv):
|
|
|
613
610
|
if not self._update_edge_feature:
|
|
614
611
|
if (edge_feature_dim != node_feature_dim):
|
|
615
612
|
warnings.warn(
|
|
616
|
-
'Found edge feature dim to be incompatible
|
|
617
|
-
'
|
|
618
|
-
'the dim of node features.',
|
|
619
|
-
stacklevel=2,
|
|
613
|
+
'Found edge and node feature dim to be incompatible. Applying a '
|
|
614
|
+
'projection layer to edge features to match the dim of the node features.',
|
|
620
615
|
)
|
|
621
616
|
self._update_edge_feature = True
|
|
622
617
|
|
|
@@ -870,10 +865,10 @@ class MPConv(GraphConv):
|
|
|
870
865
|
self._project_previous_node_feature = node_feature_dim != self.units
|
|
871
866
|
if self._project_previous_node_feature:
|
|
872
867
|
warnings.warn(
|
|
873
|
-
'
|
|
874
|
-
'
|
|
875
|
-
'
|
|
876
|
-
|
|
868
|
+
'Inputted node feature dim does not match updated node feature dim, '
|
|
869
|
+
'which is required for the GRU update. Applying a projection layer to '
|
|
870
|
+
'the inputted node features prior to the GRU update, to match dim '
|
|
871
|
+
'of the updated node feature dim.'
|
|
877
872
|
)
|
|
878
873
|
self._previous_node_dense = self.get_dense(self.units)
|
|
879
874
|
|
|
@@ -1497,6 +1492,7 @@ class AddContext(GraphLayer):
|
|
|
1497
1492
|
|
|
1498
1493
|
def build(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
1499
1494
|
feature_dim = spec.node['feature'].shape[-1]
|
|
1495
|
+
self._has_super_node = 'super' in spec.node
|
|
1500
1496
|
if self._intermediate_dim is None:
|
|
1501
1497
|
self._intermediate_dim = feature_dim * 2
|
|
1502
1498
|
self._intermediate_dense = self.get_dense(
|
|
@@ -1515,9 +1511,14 @@ class AddContext(GraphLayer):
|
|
|
1515
1511
|
context = self._intermediate_dense(context)
|
|
1516
1512
|
context = self._intermediate_norm(context)
|
|
1517
1513
|
context = self._final_dense(context)
|
|
1518
|
-
|
|
1519
|
-
|
|
1520
|
-
|
|
1514
|
+
if self._has_super_node:
|
|
1515
|
+
node_feature = ops.scatter_add(
|
|
1516
|
+
tensor.node['feature'], tensor.node['super'], context
|
|
1517
|
+
)
|
|
1518
|
+
else:
|
|
1519
|
+
node_feature = (
|
|
1520
|
+
tensor.node['feature'] + ops.gather(context, tensor.graph_indicator)
|
|
1521
|
+
)
|
|
1521
1522
|
data = {'node': {'feature': node_feature}}
|
|
1522
1523
|
if self._drop:
|
|
1523
1524
|
data['context'] = {self._field: None}
|
|
@@ -1561,8 +1562,7 @@ class GraphNetwork(GraphLayer):
|
|
|
1561
1562
|
if self._update_node_feature:
|
|
1562
1563
|
warnings.warn(
|
|
1563
1564
|
'Node feature dim does not match `units` of the first layer. '
|
|
1564
|
-
'
|
|
1565
|
-
stacklevel=2
|
|
1565
|
+
'Applying a projection layer to node features to match `units`.',
|
|
1566
1566
|
)
|
|
1567
1567
|
self._node_dense = self.get_dense(units)
|
|
1568
1568
|
self._has_edge_feature = 'feature' in spec.edge
|
|
@@ -1572,8 +1572,7 @@ class GraphNetwork(GraphLayer):
|
|
|
1572
1572
|
if self._update_edge_feature:
|
|
1573
1573
|
warnings.warn(
|
|
1574
1574
|
'Edge feature dim does not match `units` of the first layer. '
|
|
1575
|
-
'
|
|
1576
|
-
stacklevel=2
|
|
1575
|
+
'Applying projection layer to edge features to match `units`.'
|
|
1577
1576
|
)
|
|
1578
1577
|
self._edge_dense = self.get_dense(units)
|
|
1579
1578
|
|
molcraft/losses.py
CHANGED
molcraft/models.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import warnings
|
|
1
2
|
import typing
|
|
2
3
|
import keras
|
|
3
4
|
import numpy as np
|
|
@@ -111,7 +112,7 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
|
|
|
111
112
|
def __new__(cls, *args, **kwargs):
|
|
112
113
|
if _functional_init_arguments(args, kwargs) and cls == GraphModel:
|
|
113
114
|
return FunctionalGraphModel(*args, **kwargs)
|
|
114
|
-
return
|
|
115
|
+
return super().__new__(cls)
|
|
115
116
|
|
|
116
117
|
def __init__(self, *args, **kwargs):
|
|
117
118
|
self._model_layers = kwargs.pop('model_layers', None)
|
|
@@ -137,6 +138,8 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
|
|
|
137
138
|
"""
|
|
138
139
|
if not tensors.is_graph(graph_layers[0]):
|
|
139
140
|
return cls(model_layers=graph_layers)
|
|
141
|
+
elif cls != GraphModel:
|
|
142
|
+
return cls(model_layers=graph_layers[1:])
|
|
140
143
|
inputs: dict = graph_layers.pop(0)
|
|
141
144
|
x = inputs
|
|
142
145
|
for layer in graph_layers:
|
molcraft/ops.py
CHANGED
molcraft/records.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
|
+
import warnings
|
|
1
2
|
import os
|
|
2
3
|
import math
|
|
3
4
|
import glob
|
|
4
5
|
import time
|
|
5
6
|
import typing
|
|
6
|
-
import warnings
|
|
7
7
|
import tensorflow as tf
|
|
8
8
|
import numpy as np
|
|
9
9
|
import pandas as pd
|
|
@@ -164,8 +164,8 @@ def _write_tfrecord(
|
|
|
164
164
|
_write_example(tensor)
|
|
165
165
|
except Exception as e:
|
|
166
166
|
warnings.warn(
|
|
167
|
-
f
|
|
168
|
-
f
|
|
167
|
+
f'Could not write record for index {i + start_index}, proceeding without it.'
|
|
168
|
+
f'Exception raised:\n{e}'
|
|
169
169
|
)
|
|
170
170
|
|
|
171
171
|
def _serialize_example(
|
molcraft/tensors.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: molcraft
|
|
3
|
-
Version: 0.1.0a22
|
|
3
|
+
Version: 0.1.0a23
|
|
4
4
|
Summary: Graph Neural Networks for Molecular Machine Learning
|
|
5
5
|
Author-email: Alexander Kensert <alexander.kensert@gmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -35,7 +35,6 @@ Requires-Python: >=3.10
|
|
|
35
35
|
Description-Content-Type: text/markdown
|
|
36
36
|
License-File: LICENSE
|
|
37
37
|
Requires-Dist: tensorflow>=2.16
|
|
38
|
-
Requires-Dist: tensorflow-text>=2.16
|
|
39
38
|
Requires-Dist: rdkit>=2023.9.5
|
|
40
39
|
Requires-Dist: pandas>=1.0.3
|
|
41
40
|
Requires-Dist: ipython>=8.12.0
|
|
@@ -43,7 +42,7 @@ Provides-Extra: gpu
|
|
|
43
42
|
Requires-Dist: tensorflow[and-cuda]>=2.16; extra == "gpu"
|
|
44
43
|
Dynamic: license-file
|
|
45
44
|
|
|
46
|
-
<img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo"
|
|
45
|
+
<img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo" width="90%">
|
|
47
46
|
|
|
48
47
|
**Deep Learning on Molecules**: A Minimalistic GNN package for Molecular ML.
|
|
49
48
|
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
molcraft/__init__.py,sha256=QAJQSS_jOBzLxGRL7ciskMY_kn2ARLCg7FVTWeF-D_I,432
|
|
2
|
+
molcraft/callbacks.py,sha256=B4gGWjVW_1ORrt38jfk1ZFI9c0rOpN5sgjGWVqs3Ess,3571
|
|
3
|
+
molcraft/chem.py,sha256=0Zni91J4fQJW16R6g3jlOX9Vm8FH0Z5NOx0s7_X-xQw,22232
|
|
4
|
+
molcraft/datasets.py,sha256=1rHccqra5chIBwo2pz9vduyv0i07uY3CABzmAqWiFBU,4161
|
|
5
|
+
molcraft/descriptors.py,sha256=uqMPeIKqfkHC04FgztxS1FsfC3zsFJhvniZO70D22l0,4553
|
|
6
|
+
molcraft/diffusion.py,sha256=HR1kp2MuCWyUtGoGXvEA6kXTdWMGD2w5EZEoKLI1ilM,7902
|
|
7
|
+
molcraft/features.py,sha256=q-wuRP9YjPu_v5czipsh00VEXEjgFaeuLk6dbgyD_VM,13505
|
|
8
|
+
molcraft/featurizers.py,sha256=nGdV9G-aO43-vgKPNFfEOESW2hVvIvixHu3EHjIRrgU,18097
|
|
9
|
+
molcraft/layers.py,sha256=ba0WdQC2IUNsLy9pV0mIm5BBO3MCvR2lYWWuq1-8M4M,64522
|
|
10
|
+
molcraft/losses.py,sha256=piu4XYAgjnK7k9LqA4Vkh-SooYZ31sWwRfG1cacCwyA,1081
|
|
11
|
+
molcraft/models.py,sha256=-at-yFWj8mIkGchVY39m9-HtTnKqAUDQrF6wDrQXNuQ,22040
|
|
12
|
+
molcraft/ops.py,sha256=Qf9l1oOg20HNi9L9nLgf_c_5v09GtXDCc-1fkEQqn54,6194
|
|
13
|
+
molcraft/records.py,sha256=dAaq5tr3B1jXNgd3tkvVxWgUW2Qa9gtXWE3TToeiKmQ,6283
|
|
14
|
+
molcraft/tensors.py,sha256=JILwU9l6kUsrtJJ9YSzmT90_G5kbZZC6LAShvsnZOrk,22493
|
|
15
|
+
molcraft/applications/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
16
|
+
molcraft/applications/chromatography.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
17
|
+
molcraft/applications/proteomics.py,sha256=BL3EtW-q-0j79pLYO7npC67mA2ApRhH-XI4rOaP8_wc,8407
|
|
18
|
+
molcraft-0.1.0a23.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
|
|
19
|
+
molcraft-0.1.0a23.dist-info/METADATA,sha256=RomzSM8GmqbILgyZMmzR4GjrBLWDPHGZcedb1VDmqxs,3892
|
|
20
|
+
molcraft-0.1.0a23.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
21
|
+
molcraft-0.1.0a23.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
|
|
22
|
+
molcraft-0.1.0a23.dist-info/RECORD,,
|
|
@@ -1,21 +0,0 @@
|
|
|
1
|
-
molcraft/__init__.py,sha256=O88EmicQAD8oz9oFMXk_IzFChQEbbU-BCs3IE-c9Dkk,431
|
|
2
|
-
molcraft/callbacks.py,sha256=x5HnkZhqcFRrW6xdApt_jZ4X08A-0fxcnFKfdmRKa0c,3571
|
|
3
|
-
molcraft/chem.py,sha256=ynrEpWZL2D370p7CqH2kE1KhBByq7IiuQbUNoKQt96I,22028
|
|
4
|
-
molcraft/datasets.py,sha256=Nd2lw5USUZE52vvAiNr-q-n03Y3--NlZlK0NzqHgp-E,4145
|
|
5
|
-
molcraft/descriptors.py,sha256=Cl3KnBPsTST7XLgRLktkX5LwY9MV0P_lUlrt8iPV5no,3508
|
|
6
|
-
molcraft/features.py,sha256=s0WeV8eZcDEypPgC1m37f4s9QkvWIlVgn-L43Cdsa14,13525
|
|
7
|
-
molcraft/featurizers.py,sha256=1yBz5-JA7IhNm0dGivvVm1nJ5QGck8VQXtwHPWFbTuQ,18091
|
|
8
|
-
molcraft/layers.py,sha256=H7XZru4XGJA6gbRO9V1BsGqh1mIrMdhzNCKS5o6oNok,64544
|
|
9
|
-
molcraft/losses.py,sha256=qnS2yC5g-O3n_zVea9MR6TNiFraW2yqRgePOisoUP4A,1065
|
|
10
|
-
molcraft/models.py,sha256=2Pc1htT9fCukGd8ZxrvE0rzEHsPBm0pluHw4FZXaUE4,21963
|
|
11
|
-
molcraft/ops.py,sha256=bQbdFDt9waxVCzF5-dkTB6vlpj9eoSt8I4Qg7ZGXbsU,6178
|
|
12
|
-
molcraft/records.py,sha256=sopYElKWC3A9QE5I8_957v3faLb2Wt5WILHZv_FLLds,6283
|
|
13
|
-
molcraft/tensors.py,sha256=vk-W8zZu-re1g18YevDEEoVQRxT4AdIiMdI-4EvtJI4,22477
|
|
14
|
-
molcraft/applications/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
15
|
-
molcraft/applications/chromatography.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
16
|
-
molcraft/applications/proteomics.py,sha256=BL3EtW-q-0j79pLYO7npC67mA2ApRhH-XI4rOaP8_wc,8407
|
|
17
|
-
molcraft-0.1.0a22.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
|
|
18
|
-
molcraft-0.1.0a22.dist-info/METADATA,sha256=1OHx3-Q94fFEi21l0p3bnMjU-Q0EHaZLm4PU1A6QbkU,3930
|
|
19
|
-
molcraft-0.1.0a22.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
20
|
-
molcraft-0.1.0a22.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
|
|
21
|
-
molcraft-0.1.0a22.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|