molcraft 0.1.0a2__tar.gz → 0.1.0a4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/PKG-INFO +6 -6
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/README.md +5 -5
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/__init__.py +2 -1
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/callbacks.py +12 -0
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/descriptors.py +24 -23
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/experimental/peptides.py +96 -79
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/features.py +5 -3
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/featurizers.py +61 -38
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/layers.py +1004 -425
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/models.py +47 -3
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/ops.py +14 -3
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/tensors.py +3 -3
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft.egg-info/PKG-INFO +6 -6
- molcraft-0.1.0a4/tests/test_featurizers.py +197 -0
- molcraft-0.1.0a4/tests/test_layers.py +287 -0
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/tests/test_models.py +41 -1
- molcraft-0.1.0a2/tests/test_featurizers.py +0 -111
- molcraft-0.1.0a2/tests/test_layers.py +0 -143
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/LICENSE +0 -0
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/chem.py +0 -0
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/conformers.py +0 -0
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/datasets.py +0 -0
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/experimental/__init__.py +0 -0
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft/records.py +0 -0
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft.egg-info/SOURCES.txt +0 -0
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft.egg-info/dependency_links.txt +0 -0
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft.egg-info/requires.txt +0 -0
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/molcraft.egg-info/top_level.txt +0 -0
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/pyproject.toml +0 -0
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/setup.cfg +0 -0
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/tests/test_chem.py +0 -0
- {molcraft-0.1.0a2 → molcraft-0.1.0a4}/tests/test_tensors.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: molcraft
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.0a4
|
|
4
4
|
Summary: Graph Neural Networks for Molecular Machine Learning
|
|
5
5
|
Author-email: Alexander Kensert <alexander.kensert@gmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -51,11 +51,11 @@ Dynamic: license-file
|
|
|
51
51
|
|
|
52
52
|
## Highlights
|
|
53
53
|
- Compatible with **Keras 3**
|
|
54
|
-
-
|
|
55
|
-
-
|
|
56
|
-
-
|
|
57
|
-
-
|
|
58
|
-
-
|
|
54
|
+
- Customizable and serializable **featurizers**
|
|
55
|
+
- Customizable and serializable **layers** and **models**
|
|
56
|
+
- Customizable **GraphTensor**
|
|
57
|
+
- Fast and efficient featurization of molecular graphs
|
|
58
|
+
- Efficient and easy-to-use input pipelines using TF **records**
|
|
59
59
|
|
|
60
60
|
## Examples
|
|
61
61
|
|
|
@@ -7,11 +7,11 @@
|
|
|
7
7
|
|
|
8
8
|
## Highlights
|
|
9
9
|
- Compatible with **Keras 3**
|
|
10
|
-
-
|
|
11
|
-
-
|
|
12
|
-
-
|
|
13
|
-
-
|
|
14
|
-
-
|
|
10
|
+
- Customizable and serializable **featurizers**
|
|
11
|
+
- Customizable and serializable **layers** and **models**
|
|
12
|
+
- Customizable **GraphTensor**
|
|
13
|
+
- Fast and efficient featurization of molecular graphs
|
|
14
|
+
- Efficient and easy-to-use input pipelines using TF **records**
|
|
15
15
|
|
|
16
16
|
## Examples
|
|
17
17
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
__version__ = '0.1.
|
|
1
|
+
__version__ = '0.1.0a4'
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
4
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
@@ -14,3 +14,4 @@ from molcraft import ops
|
|
|
14
14
|
from molcraft import records
|
|
15
15
|
from molcraft import tensors
|
|
16
16
|
from molcraft import callbacks
|
|
17
|
+
from molcraft import datasets
|
|
@@ -19,3 +19,15 @@ class TensorBoard(keras.callbacks.TensorBoard):
|
|
|
19
19
|
weight, image_weight_name, epoch
|
|
20
20
|
)
|
|
21
21
|
self._train_writer.flush()
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class LearningRateDecay(keras.callbacks.LearningRateScheduler):
|
|
25
|
+
|
|
26
|
+
def __init__(self, rate: float, delay: int = 0, **kwargs):
|
|
27
|
+
|
|
28
|
+
def lr_schedule(epoch: int, lr: float):
|
|
29
|
+
if epoch < delay:
|
|
30
|
+
return float(lr)
|
|
31
|
+
return float(lr * keras.ops.exp(-rate))
|
|
32
|
+
|
|
33
|
+
super().__init__(schedule=lr_schedule, **kwargs)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import keras
|
|
2
2
|
import numpy as np
|
|
3
|
-
from rdkit.Chem import
|
|
3
|
+
from rdkit.Chem import rdMolDescriptors
|
|
4
4
|
|
|
5
5
|
from molcraft import chem
|
|
6
6
|
from molcraft import features
|
|
@@ -8,9 +8,6 @@ from molcraft import features
|
|
|
8
8
|
|
|
9
9
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
10
10
|
class Descriptor(features.Feature):
|
|
11
|
-
def __init__(self, scale: float | None = None, **kwargs):
|
|
12
|
-
super().__init__(**kwargs)
|
|
13
|
-
self.scale = scale
|
|
14
11
|
|
|
15
12
|
def __call__(self, mol: chem.Mol) -> np.ndarray:
|
|
16
13
|
if not isinstance(mol, chem.Mol):
|
|
@@ -24,67 +21,71 @@ class Descriptor(features.Feature):
|
|
|
24
21
|
self._featurize_categorical if self.vocab else
|
|
25
22
|
self._featurize_floating
|
|
26
23
|
)
|
|
27
|
-
scale_value = self.scale and not self.vocab
|
|
28
24
|
if not isinstance(descriptor, (tuple, list, np.ndarray)):
|
|
29
25
|
descriptor = [descriptor]
|
|
30
26
|
|
|
31
27
|
descriptors = []
|
|
32
28
|
for value in descriptor:
|
|
33
|
-
if scale_value:
|
|
34
|
-
value /= self.scale
|
|
35
29
|
descriptors.append(func(value))
|
|
36
30
|
return np.concatenate(descriptors)
|
|
37
31
|
|
|
38
|
-
def get_config(self):
|
|
39
|
-
config = super().get_config()
|
|
40
|
-
config['scale'] = self.scale
|
|
41
|
-
return config
|
|
42
|
-
|
|
43
32
|
|
|
44
33
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
45
34
|
class MolWeight(Descriptor):
|
|
46
35
|
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
47
|
-
return
|
|
36
|
+
return rdMolDescriptors.CalcExactMolWt(mol)
|
|
48
37
|
|
|
49
38
|
|
|
50
39
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
51
|
-
class
|
|
40
|
+
class TPSA(Descriptor):
|
|
52
41
|
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
53
|
-
return
|
|
42
|
+
return rdMolDescriptors.CalcTPSA(mol)
|
|
43
|
+
|
|
54
44
|
|
|
45
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
46
|
+
class CrippenLogP(Descriptor):
|
|
47
|
+
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
48
|
+
return rdMolDescriptors.CalcCrippenDescriptors(mol)[0]
|
|
49
|
+
|
|
55
50
|
|
|
56
51
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
57
|
-
class
|
|
52
|
+
class CrippenMolarRefractivity(Descriptor):
|
|
58
53
|
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
59
|
-
return
|
|
54
|
+
return rdMolDescriptors.CalcCrippenDescriptors(mol)[1]
|
|
60
55
|
|
|
61
56
|
|
|
62
57
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
63
58
|
class NumHeavyAtoms(Descriptor):
|
|
64
59
|
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
65
|
-
return
|
|
60
|
+
return rdMolDescriptors.CalcNumHeavyAtoms(mol)
|
|
66
61
|
|
|
67
62
|
|
|
63
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
64
|
+
class NumHeteroAtoms(Descriptor):
|
|
65
|
+
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
66
|
+
return rdMolDescriptors.CalcNumHeteroatoms(mol)
|
|
67
|
+
|
|
68
|
+
|
|
68
69
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
69
70
|
class NumHydrogenDonors(Descriptor):
|
|
70
71
|
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
71
|
-
return
|
|
72
|
+
return rdMolDescriptors.CalcNumHBD(mol)
|
|
72
73
|
|
|
73
74
|
|
|
74
75
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
75
76
|
class NumHydrogenAcceptors(Descriptor):
|
|
76
77
|
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
77
|
-
return
|
|
78
|
-
|
|
78
|
+
return rdMolDescriptors.CalcNumHBA(mol)
|
|
79
|
+
|
|
79
80
|
|
|
80
81
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
81
82
|
class NumRotatableBonds(Descriptor):
|
|
82
83
|
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
83
|
-
return
|
|
84
|
+
return rdMolDescriptors.CalcNumRotatableBonds(mol)
|
|
84
85
|
|
|
85
86
|
|
|
86
87
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
87
88
|
class NumRings(Descriptor):
|
|
88
89
|
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
89
|
-
return
|
|
90
|
+
return rdMolDescriptors.CalcNumRings(mol)
|
|
90
91
|
|
|
@@ -2,6 +2,7 @@ import re
|
|
|
2
2
|
import keras
|
|
3
3
|
import numpy as np
|
|
4
4
|
import tensorflow as tf
|
|
5
|
+
import tensorflow_text as tf_text
|
|
5
6
|
from rdkit import Chem
|
|
6
7
|
|
|
7
8
|
from molcraft import ops
|
|
@@ -10,12 +11,14 @@ from molcraft import features
|
|
|
10
11
|
from molcraft import featurizers
|
|
11
12
|
from molcraft import tensors
|
|
12
13
|
from molcraft import descriptors
|
|
14
|
+
from molcraft import layers
|
|
13
15
|
|
|
14
16
|
|
|
15
17
|
def Graph(
|
|
16
18
|
inputs,
|
|
17
19
|
atom_features: list[features.Feature] | str | None = 'auto',
|
|
18
20
|
bond_features: list[features.Feature] | str | None = 'auto',
|
|
21
|
+
molecule_features: list[descriptors.Descriptor] | str | None = 'auto',
|
|
19
22
|
super_atom: bool = True,
|
|
20
23
|
radius: int | float | None = None,
|
|
21
24
|
self_loops: bool = False,
|
|
@@ -25,7 +28,7 @@ def Graph(
|
|
|
25
28
|
featurizer = featurizers.MolGraphFeaturizer(
|
|
26
29
|
atom_features=atom_features,
|
|
27
30
|
bond_features=bond_features,
|
|
28
|
-
molecule_features=
|
|
31
|
+
molecule_features=molecule_features,
|
|
29
32
|
super_atom=super_atom,
|
|
30
33
|
radius=radius,
|
|
31
34
|
self_loops=self_loops,
|
|
@@ -33,41 +36,55 @@ def Graph(
|
|
|
33
36
|
**kwargs,
|
|
34
37
|
)
|
|
35
38
|
|
|
36
|
-
|
|
37
|
-
residues[
|
|
39
|
+
tensor_list: list[tensors.GraphTensor] = [
|
|
40
|
+
featurizer(residues[tag]).update({'context': {'tag': tag}}) for tag in inputs
|
|
38
41
|
]
|
|
39
|
-
tensor_list = [featurizer(x) for x in inputs]
|
|
40
42
|
return tf.stack(tensor_list, axis=0)
|
|
43
|
+
|
|
41
44
|
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
lookup = GraphLookupLayer()
|
|
45
|
+
def Lookup(graph: tensors.GraphTensor) -> 'LookupLayer':
|
|
46
|
+
lookup = LookupLayer()
|
|
45
47
|
lookup._build(graph)
|
|
46
48
|
return lookup
|
|
47
49
|
|
|
48
50
|
|
|
49
51
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
50
|
-
class
|
|
52
|
+
class SequenceSplit(keras.layers.Layer):
|
|
53
|
+
|
|
54
|
+
_pattern = "|".join([
|
|
55
|
+
r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])', # N-term mod + mod
|
|
56
|
+
r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])', # C-term mod + mod
|
|
57
|
+
r'([A-Z]-\[[A-Za-z0-9]+\])', # C-term mod
|
|
58
|
+
r'(\[[A-Za-z0-9]+\]-[A-Z])', # N-term mod
|
|
59
|
+
r'([A-Z]\[[A-Za-z0-9]+\])', # Mod
|
|
60
|
+
r'([A-Z])', # No mod
|
|
61
|
+
])
|
|
62
|
+
|
|
63
|
+
def call(self, inputs):
|
|
64
|
+
inputs = tf_text.regex_split(inputs, self._pattern, self._pattern)
|
|
65
|
+
inputs = keras.ops.concatenate([
|
|
66
|
+
tf.strings.join([inputs[:, :-1], '-[X]']),
|
|
67
|
+
inputs[:, -1:]
|
|
68
|
+
], axis=1)
|
|
69
|
+
return inputs.to_tensor()
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
|
|
51
73
|
|
|
52
|
-
|
|
53
|
-
|
|
74
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
75
|
+
class LookupLayer(keras.layers.Layer):
|
|
76
|
+
|
|
77
|
+
def __init__(self, **kwargs):
|
|
78
|
+
super().__init__(**kwargs)
|
|
79
|
+
self._sequence_splitter = SequenceSplit()
|
|
80
|
+
|
|
81
|
+
def call(self, sequence: tf.Tensor) -> tensors.GraphTensor:
|
|
82
|
+
sequence = self._sequence_splitter(sequence)
|
|
83
|
+
indices = self._tag_to_index.lookup(sequence)
|
|
84
|
+
indices = tf.sort(tf.unique(tf.reshape(indices, [-1]))[0])[1:]
|
|
54
85
|
graph = self.graph[indices]
|
|
55
|
-
sizes = graph.context['size']
|
|
56
|
-
max_index = keras.ops.max(indices)
|
|
57
|
-
sizes = tf.tensor_scatter_nd_update(
|
|
58
|
-
tensor=tf.zeros([max_index + 1], dtype=indices.dtype),
|
|
59
|
-
indices=indices[:, None],
|
|
60
|
-
updates=sizes
|
|
61
|
-
)
|
|
62
|
-
graph = graph.update(
|
|
63
|
-
{
|
|
64
|
-
'context': {
|
|
65
|
-
'size': sizes
|
|
66
|
-
}
|
|
67
|
-
},
|
|
68
|
-
)
|
|
69
86
|
return tensors.to_dict(graph)
|
|
70
|
-
|
|
87
|
+
|
|
71
88
|
def _build(self, x):
|
|
72
89
|
|
|
73
90
|
if isinstance(x, tensors.GraphTensor):
|
|
@@ -98,7 +115,17 @@ class GraphLookupLayer(keras.layers.Layer):
|
|
|
98
115
|
keras.ops.convert_to_tensor, self._graph
|
|
99
116
|
)
|
|
100
117
|
self._graph_tensor = tensors.from_dict(graph)
|
|
101
|
-
|
|
118
|
+
|
|
119
|
+
tags = self._graph_tensor.context['tag']
|
|
120
|
+
|
|
121
|
+
self._tag_to_index = tf.lookup.StaticHashTable(
|
|
122
|
+
tf.lookup.KeyValueTensorInitializer(
|
|
123
|
+
keys=tags,
|
|
124
|
+
values=range(len(tags)),
|
|
125
|
+
),
|
|
126
|
+
default_value=-1,
|
|
127
|
+
)
|
|
128
|
+
|
|
102
129
|
def get_config(self):
|
|
103
130
|
config = super().get_config()
|
|
104
131
|
spec = keras.saving.serialize_keras_object(self._spec)
|
|
@@ -106,7 +133,7 @@ class GraphLookupLayer(keras.layers.Layer):
|
|
|
106
133
|
return config
|
|
107
134
|
|
|
108
135
|
@classmethod
|
|
109
|
-
def from_config(cls, config: dict) -> '
|
|
136
|
+
def from_config(cls, config: dict) -> 'LookupLayer':
|
|
110
137
|
spec = config.pop('spec')
|
|
111
138
|
spec = keras.saving.deserialize_keras_object(spec)
|
|
112
139
|
layer = cls(**config)
|
|
@@ -130,66 +157,48 @@ class Gather(keras.layers.Layer):
|
|
|
130
157
|
super().__init__(**kwargs)
|
|
131
158
|
self.padding = padding
|
|
132
159
|
self.mask_value = mask_value
|
|
160
|
+
self._readout_layer = layers.Readout(mode='mean')
|
|
133
161
|
self.supports_masking = True
|
|
162
|
+
self._sequence_splitter = SequenceSplit()
|
|
134
163
|
|
|
135
164
|
def get_config(self):
|
|
136
165
|
config = super().get_config()
|
|
137
166
|
config['mask_value'] = self.mask_value
|
|
138
167
|
config['padding'] = self.padding
|
|
139
168
|
return config
|
|
140
|
-
|
|
141
|
-
def call(self, inputs
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
169
|
+
|
|
170
|
+
def call(self, inputs) -> tf.Tensor:
|
|
171
|
+
|
|
172
|
+
graph, sequence = inputs
|
|
173
|
+
|
|
174
|
+
tag = graph['context']['tag']
|
|
175
|
+
data = self._readout_layer(graph)
|
|
176
|
+
|
|
177
|
+
table = tf.lookup.experimental.MutableHashTable(
|
|
178
|
+
key_dtype=tf.string,
|
|
179
|
+
value_dtype=tf.int32,
|
|
180
|
+
default_value=-1
|
|
181
|
+
)
|
|
182
|
+
|
|
183
|
+
table.insert(tag, tf.range(tf.shape(tag)[0]))
|
|
184
|
+
sequence = self._sequence_splitter(sequence)
|
|
185
|
+
sequence = table.lookup(sequence)
|
|
186
|
+
|
|
187
|
+
readout = ops.gather(data, keras.ops.where(sequence == -1, 0, sequence))
|
|
188
|
+
readout = keras.ops.where(sequence[..., None] == -1, 0.0, readout)
|
|
189
|
+
return readout
|
|
190
|
+
|
|
155
191
|
def compute_mask(
|
|
156
192
|
self,
|
|
157
|
-
inputs:
|
|
193
|
+
inputs: tensors.GraphTensor,
|
|
158
194
|
mask: bool | None = None
|
|
159
195
|
) -> tf.Tensor | None:
|
|
160
196
|
# if self.mask_value is None:
|
|
161
197
|
# return None
|
|
162
|
-
_,
|
|
163
|
-
|
|
164
|
-
|
|
198
|
+
_, sequence = inputs
|
|
199
|
+
sequence = self._sequence_splitter(sequence)
|
|
200
|
+
return keras.ops.not_equal(sequence, '')
|
|
165
201
|
|
|
166
|
-
@keras.saving.register_keras_serializable(package='molcraft')
|
|
167
|
-
class AminoAcidType(descriptors.Descriptor):
|
|
168
|
-
|
|
169
|
-
def __init__(self, vocab=None, **kwargs):
|
|
170
|
-
vocab = [
|
|
171
|
-
"A", "C", "D", "E", "F", "G", "H", "I", "K", "L",
|
|
172
|
-
"M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y",
|
|
173
|
-
]
|
|
174
|
-
super().__init__(vocab=vocab, **kwargs)
|
|
175
|
-
|
|
176
|
-
def call(self, mol: chem.Mol) -> list[str]:
|
|
177
|
-
residue = residues_reverse.get(mol.canonical_smiles)
|
|
178
|
-
if not residue:
|
|
179
|
-
raise KeyError(f'Could not find {mol.canonical_smiles} in `residues_reverse`.')
|
|
180
|
-
mol = chem.remove_hs(mol)
|
|
181
|
-
return _extract_residue_type(residues_reverse[mol.canonical_smiles])
|
|
182
|
-
|
|
183
|
-
def sequence_split(sequence: str):
|
|
184
|
-
patterns = [
|
|
185
|
-
r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])', # N-term mod + mod
|
|
186
|
-
r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])', # C-term mod + mod
|
|
187
|
-
r'([A-Z]-\[[A-Za-z0-9]+\])', # C-term mod
|
|
188
|
-
r'(\[[A-Za-z0-9]+\]-[A-Z])', # N-term mod
|
|
189
|
-
r'([A-Z]\[[A-Za-z0-9]+\])', # Mod
|
|
190
|
-
r'([A-Z])', # No mod
|
|
191
|
-
]
|
|
192
|
-
return [match.group(0) for match in re.finditer("|".join(patterns), sequence)]
|
|
193
202
|
|
|
194
203
|
residues = {
|
|
195
204
|
"A": "N[C@@H](C)C(=O)O",
|
|
@@ -252,13 +261,21 @@ residues = {
|
|
|
252
261
|
}
|
|
253
262
|
|
|
254
263
|
residues_reverse = {}
|
|
255
|
-
def register_peptide_residues(
|
|
256
|
-
for residue, smiles in
|
|
257
|
-
|
|
264
|
+
def register_peptide_residues(residues_: dict[str, str], canonicalize=True):
|
|
265
|
+
for residue, smiles in residues_.items():
|
|
266
|
+
if canonicalize:
|
|
267
|
+
smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles))
|
|
268
|
+
residues[residue] = smiles
|
|
258
269
|
residues_reverse[residues[residue]] = residue
|
|
259
270
|
|
|
260
|
-
register_peptide_residues(residues)
|
|
271
|
+
register_peptide_residues(residues, canonicalize=False)
|
|
261
272
|
|
|
262
273
|
def _extract_residue_type(residue_tag: str) -> str:
|
|
263
|
-
pattern = r"(?<!\[)[A-Z](?![
|
|
264
|
-
return [match.group(0) for match in re.finditer(pattern, residue_tag)][0]
|
|
274
|
+
pattern = r"(?<!\[)[A-Z](?![^\[]*\])"
|
|
275
|
+
return [match.group(0) for match in re.finditer(pattern, residue_tag)][0]
|
|
276
|
+
|
|
277
|
+
special_residues = {}
|
|
278
|
+
for key, value in residues.items():
|
|
279
|
+
special_residues[key + '-[X]'] = value.rstrip('O')
|
|
280
|
+
|
|
281
|
+
register_peptide_residues(special_residues, canonicalize=False)
|
|
@@ -155,9 +155,11 @@ class Distance(EdgeFeature):
|
|
|
155
155
|
encode_oov: bool = True,
|
|
156
156
|
**kwargs,
|
|
157
157
|
) -> None:
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
158
|
+
vocab = kwargs.pop('vocab', None)
|
|
159
|
+
if not vocab:
|
|
160
|
+
if max_distance is None:
|
|
161
|
+
max_distance = 20
|
|
162
|
+
vocab = list(range(max_distance + 1))
|
|
161
163
|
super().__init__(
|
|
162
164
|
vocab=vocab,
|
|
163
165
|
allow_oov=allow_oov,
|
|
@@ -186,9 +186,11 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
186
186
|
if default_molecule_features:
|
|
187
187
|
molecule_features = [
|
|
188
188
|
descriptors.MolWeight(),
|
|
189
|
-
descriptors.
|
|
190
|
-
descriptors.
|
|
189
|
+
descriptors.TPSA(),
|
|
190
|
+
descriptors.CrippenLogP(),
|
|
191
|
+
descriptors.CrippenMolarRefractivity(),
|
|
191
192
|
descriptors.NumHeavyAtoms(),
|
|
193
|
+
descriptors.NumHeteroAtoms(),
|
|
192
194
|
descriptors.NumHydrogenDonors(),
|
|
193
195
|
descriptors.NumHydrogenAcceptors(),
|
|
194
196
|
descriptors.NumRotatableBonds(),
|
|
@@ -219,7 +221,7 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
219
221
|
|
|
220
222
|
atom_feature = self.atom_features(mol)
|
|
221
223
|
bond_feature = self.bond_features(mol)
|
|
222
|
-
|
|
224
|
+
molecule_feature = self.molecule_feature(mol)
|
|
223
225
|
molecule_size = self.num_atoms(mol)
|
|
224
226
|
|
|
225
227
|
if isinstance(context, dict):
|
|
@@ -241,8 +243,14 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
241
243
|
else:
|
|
242
244
|
context = {'size': molecule_size}
|
|
243
245
|
|
|
244
|
-
if
|
|
245
|
-
|
|
246
|
+
if molecule_feature is not None:
|
|
247
|
+
if 'feature' in context:
|
|
248
|
+
warn(
|
|
249
|
+
'Found both inputted and computed context feature. '
|
|
250
|
+
'Overwriting inputted context feature with computed '
|
|
251
|
+
'context feature (based on `molecule_features`).'
|
|
252
|
+
)
|
|
253
|
+
context['feature'] = molecule_feature
|
|
246
254
|
|
|
247
255
|
node = {}
|
|
248
256
|
node['feature'] = atom_feature
|
|
@@ -288,7 +296,7 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
288
296
|
edge['feature'] = self._expand_bond_features(
|
|
289
297
|
mol, paths, bond_feature,
|
|
290
298
|
)
|
|
291
|
-
edge['length'] = np.eye(self.radius + 1)[edge['length']]
|
|
299
|
+
edge['length'] = np.eye(self.radius + 1, dtype=self.feature_dtype)[edge['length']]
|
|
292
300
|
|
|
293
301
|
if self.super_atom:
|
|
294
302
|
node, edge = self._add_super_atom(node, edge)
|
|
@@ -315,13 +323,13 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
315
323
|
)
|
|
316
324
|
return bond_feature.astype(self.feature_dtype)
|
|
317
325
|
|
|
318
|
-
def
|
|
326
|
+
def molecule_feature(self, mol: chem.Mol) -> np.ndarray:
|
|
319
327
|
if self._molecule_features is None:
|
|
320
328
|
return None
|
|
321
|
-
|
|
329
|
+
molecule_feature: np.ndarray = np.concatenate(
|
|
322
330
|
[f(mol) for f in self._molecule_features], axis=-1
|
|
323
331
|
)
|
|
324
|
-
return
|
|
332
|
+
return molecule_feature.astype(self.feature_dtype)
|
|
325
333
|
|
|
326
334
|
def num_atoms(self, mol: chem.Mol) -> np.ndarray:
|
|
327
335
|
return np.asarray(mol.num_atoms, dtype=self.index_dtype)
|
|
@@ -361,9 +369,7 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
361
369
|
) -> tuple[dict[str, np.ndarray]]:
|
|
362
370
|
num_super_nodes = 1
|
|
363
371
|
num_nodes = node['feature'].shape[0]
|
|
364
|
-
node = _add_super_nodes(
|
|
365
|
-
node, num_super_nodes, self.feature_dtype
|
|
366
|
-
)
|
|
372
|
+
node = _add_super_nodes(node, num_super_nodes)
|
|
367
373
|
edge = _add_super_edges(
|
|
368
374
|
edge, num_nodes, num_super_nodes, self.feature_dtype, self.index_dtype
|
|
369
375
|
)
|
|
@@ -402,6 +408,7 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
402
408
|
return cls(**config)
|
|
403
409
|
|
|
404
410
|
|
|
411
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
405
412
|
class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
406
413
|
|
|
407
414
|
"""Molecular 3d-graph featurizer.
|
|
@@ -543,7 +550,7 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
543
550
|
'of the `Featurizer` or input a 3D representation of the molecule. '
|
|
544
551
|
)
|
|
545
552
|
|
|
546
|
-
|
|
553
|
+
molecule_feature = self.molecule_feature(mol)
|
|
547
554
|
molecule_size = self.num_atoms(mol) + int(self.super_atom)
|
|
548
555
|
|
|
549
556
|
if isinstance(context, dict):
|
|
@@ -565,8 +572,14 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
565
572
|
else:
|
|
566
573
|
context = {'size': molecule_size}
|
|
567
574
|
|
|
568
|
-
if
|
|
569
|
-
|
|
575
|
+
if molecule_feature is not None:
|
|
576
|
+
if 'feature' in context:
|
|
577
|
+
warn(
|
|
578
|
+
'Found both inputted and computed context feature. '
|
|
579
|
+
'Overwriting inputted context feature with computed '
|
|
580
|
+
'context feature (based on `molecule_features`).'
|
|
581
|
+
)
|
|
582
|
+
context['feature'] = molecule_feature
|
|
570
583
|
|
|
571
584
|
node = {}
|
|
572
585
|
node['feature'] = self.atom_features(mol)
|
|
@@ -586,7 +599,7 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
586
599
|
radius=self.radius,
|
|
587
600
|
sparse=False,
|
|
588
601
|
self_loops=self.self_loops,
|
|
589
|
-
dtype=
|
|
602
|
+
dtype=bool
|
|
590
603
|
)
|
|
591
604
|
edge_conformer['source'], edge_conformer['target'] = np.where(adjacency_matrix)
|
|
592
605
|
edge_conformer['source'] = edge_conformer['source'].astype(self.index_dtype)
|
|
@@ -603,7 +616,7 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
603
616
|
)
|
|
604
617
|
node_conformer['coordinate'] = np.concatenate(
|
|
605
618
|
[node_conformer['coordinate'], conformer.centroid[None]], axis=0
|
|
606
|
-
)
|
|
619
|
+
).astype(self.feature_dtype)
|
|
607
620
|
tensor_list.append(
|
|
608
621
|
tensors.GraphTensor(context, node_conformer, edge_conformer)
|
|
609
622
|
)
|
|
@@ -627,7 +640,7 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
627
640
|
config['conformer_generator'] = keras.saving.deserialize_keras_object(
|
|
628
641
|
config['conformer_generator']
|
|
629
642
|
)
|
|
630
|
-
return super().from_config(
|
|
643
|
+
return super().from_config(config)
|
|
631
644
|
|
|
632
645
|
|
|
633
646
|
def save_featurizer(
|
|
@@ -669,12 +682,15 @@ def load_featurizer(
|
|
|
669
682
|
def _add_super_nodes(
|
|
670
683
|
node: dict[str, np.ndarray],
|
|
671
684
|
num_super_nodes: int = 1,
|
|
672
|
-
feature_dtype: str = 'float32',
|
|
673
685
|
) -> dict[str, np.ndarray]:
|
|
674
686
|
node = copy.deepcopy(node)
|
|
675
|
-
node['super'] = np.array(
|
|
687
|
+
node['super'] = np.array(
|
|
688
|
+
[False] * len(node['feature']) + [True] * num_super_nodes,
|
|
689
|
+
dtype=bool
|
|
690
|
+
)
|
|
676
691
|
super_node_feature = np.zeros(
|
|
677
|
-
[num_super_nodes, node['feature'].shape[-1]],
|
|
692
|
+
[num_super_nodes, node['feature'].shape[-1]],
|
|
693
|
+
dtype=node['feature'].dtype
|
|
678
694
|
)
|
|
679
695
|
node['feature'] = np.concatenate([node['feature'], super_node_feature])
|
|
680
696
|
return node
|
|
@@ -694,31 +710,38 @@ def _add_super_edges(
|
|
|
694
710
|
np.tile(np.arange(num_nodes), [num_super_nodes])
|
|
695
711
|
)
|
|
696
712
|
edge['source'] = np.concatenate(
|
|
697
|
-
[
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
super_node_indices,
|
|
701
|
-
]
|
|
702
|
-
)
|
|
703
|
-
edge['source'] = edge['source'].astype(index_dtype)
|
|
713
|
+
[edge['source'], node_indices, super_node_indices]
|
|
714
|
+
).astype(index_dtype)
|
|
715
|
+
|
|
704
716
|
edge['target'] = np.concatenate(
|
|
705
|
-
[
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
node_indices
|
|
709
|
-
]
|
|
710
|
-
)
|
|
711
|
-
edge['target'] = edge['target'].astype(index_dtype)
|
|
717
|
+
[edge['target'], super_node_indices, node_indices]
|
|
718
|
+
).astype(index_dtype)
|
|
719
|
+
|
|
712
720
|
if 'feature' in edge:
|
|
713
|
-
|
|
714
|
-
|
|
721
|
+
num_edges = int(edge['feature'].shape[0])
|
|
722
|
+
num_super_edges = int(num_super_nodes * num_nodes * 2)
|
|
723
|
+
edge['super'] = np.asarray(
|
|
724
|
+
([False] * num_edges + [True] * num_super_edges),
|
|
725
|
+
dtype=bool
|
|
726
|
+
)
|
|
727
|
+
edge['feature'] = np.concatenate(
|
|
728
|
+
[
|
|
729
|
+
edge['feature'],
|
|
730
|
+
np.zeros(
|
|
731
|
+
shape=(num_super_edges, edge['feature'].shape[-1]),
|
|
732
|
+
dtype=edge['feature'].dtype
|
|
733
|
+
)
|
|
734
|
+
]
|
|
735
|
+
)
|
|
736
|
+
|
|
715
737
|
if 'length' in edge:
|
|
716
738
|
edge['length'] = np.pad(edge['length'], [(0, 0), (1, 0)])
|
|
717
|
-
zero_array = np.zeros(
|
|
739
|
+
zero_array = np.zeros([num_nodes * num_super_nodes * 2], dtype='int32')
|
|
718
740
|
edge_length_dim = edge['length'].shape[1]
|
|
719
741
|
virtual_edge_length = np.eye(edge_length_dim)[zero_array]
|
|
720
742
|
edge['length'] = np.concatenate([edge['length'], virtual_edge_length])
|
|
721
743
|
edge['length'] = edge['length'].astype(feature_dtype)
|
|
744
|
+
|
|
722
745
|
return edge
|
|
723
746
|
|
|
724
747
|
|