molcraft 0.1.0a3__tar.gz → 0.1.0a4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/PKG-INFO +6 -6
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/README.md +5 -5
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/molcraft/__init__.py +1 -1
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/molcraft/callbacks.py +12 -0
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/molcraft/descriptors.py +24 -23
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/molcraft/experimental/peptides.py +96 -79
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/molcraft/featurizers.py +59 -37
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/molcraft/layers.py +692 -565
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/molcraft/models.py +14 -1
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/molcraft/ops.py +14 -3
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/molcraft/tensors.py +3 -3
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/molcraft.egg-info/PKG-INFO +6 -6
- molcraft-0.1.0a4/tests/test_featurizers.py +197 -0
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/tests/test_layers.py +47 -28
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/tests/test_models.py +41 -1
- molcraft-0.1.0a3/tests/test_featurizers.py +0 -111
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/LICENSE +0 -0
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/molcraft/chem.py +0 -0
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/molcraft/conformers.py +0 -0
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/molcraft/datasets.py +0 -0
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/molcraft/experimental/__init__.py +0 -0
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/molcraft/features.py +0 -0
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/molcraft/records.py +0 -0
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/molcraft.egg-info/SOURCES.txt +0 -0
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/molcraft.egg-info/dependency_links.txt +0 -0
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/molcraft.egg-info/requires.txt +0 -0
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/molcraft.egg-info/top_level.txt +0 -0
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/pyproject.toml +0 -0
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/setup.cfg +0 -0
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/tests/test_chem.py +0 -0
- {molcraft-0.1.0a3 → molcraft-0.1.0a4}/tests/test_tensors.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: molcraft
|
|
3
|
-
Version: 0.1.0a3
|
|
3
|
+
Version: 0.1.0a4
|
|
4
4
|
Summary: Graph Neural Networks for Molecular Machine Learning
|
|
5
5
|
Author-email: Alexander Kensert <alexander.kensert@gmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -51,11 +51,11 @@ Dynamic: license-file
|
|
|
51
51
|
|
|
52
52
|
## Highlights
|
|
53
53
|
- Compatible with **Keras 3**
|
|
54
|
-
-
|
|
55
|
-
-
|
|
56
|
-
-
|
|
57
|
-
-
|
|
58
|
-
-
|
|
54
|
+
- Customizable and serializable **featurizers**
|
|
55
|
+
- Customizable and serializable **layers** and **models**
|
|
56
|
+
- Customizable **GraphTensor**
|
|
57
|
+
- Fast and efficient featurization of molecular graphs
|
|
58
|
+
- Efficient and easy-to-use input pipelines using TF **records**
|
|
59
59
|
|
|
60
60
|
## Examples
|
|
61
61
|
|
|
@@ -7,11 +7,11 @@
|
|
|
7
7
|
|
|
8
8
|
## Highlights
|
|
9
9
|
- Compatible with **Keras 3**
|
|
10
|
-
-
|
|
11
|
-
-
|
|
12
|
-
-
|
|
13
|
-
-
|
|
14
|
-
-
|
|
10
|
+
- Customizable and serializable **featurizers**
|
|
11
|
+
- Customizable and serializable **layers** and **models**
|
|
12
|
+
- Customizable **GraphTensor**
|
|
13
|
+
- Fast and efficient featurization of molecular graphs
|
|
14
|
+
- Efficient and easy-to-use input pipelines using TF **records**
|
|
15
15
|
|
|
16
16
|
## Examples
|
|
17
17
|
|
|
@@ -19,3 +19,15 @@ class TensorBoard(keras.callbacks.TensorBoard):
|
|
|
19
19
|
weight, image_weight_name, epoch
|
|
20
20
|
)
|
|
21
21
|
self._train_writer.flush()
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class LearningRateDecay(keras.callbacks.LearningRateScheduler):
|
|
25
|
+
|
|
26
|
+
def __init__(self, rate: float, delay: int = 0, **kwargs):
|
|
27
|
+
|
|
28
|
+
def lr_schedule(epoch: int, lr: float):
|
|
29
|
+
if epoch < delay:
|
|
30
|
+
return float(lr)
|
|
31
|
+
return float(lr * keras.ops.exp(-rate))
|
|
32
|
+
|
|
33
|
+
super().__init__(schedule=lr_schedule, **kwargs)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import keras
|
|
2
2
|
import numpy as np
|
|
3
|
-
from rdkit.Chem import
|
|
3
|
+
from rdkit.Chem import rdMolDescriptors
|
|
4
4
|
|
|
5
5
|
from molcraft import chem
|
|
6
6
|
from molcraft import features
|
|
@@ -8,9 +8,6 @@ from molcraft import features
|
|
|
8
8
|
|
|
9
9
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
10
10
|
class Descriptor(features.Feature):
|
|
11
|
-
def __init__(self, scale: float | None = None, **kwargs):
|
|
12
|
-
super().__init__(**kwargs)
|
|
13
|
-
self.scale = scale
|
|
14
11
|
|
|
15
12
|
def __call__(self, mol: chem.Mol) -> np.ndarray:
|
|
16
13
|
if not isinstance(mol, chem.Mol):
|
|
@@ -24,67 +21,71 @@ class Descriptor(features.Feature):
|
|
|
24
21
|
self._featurize_categorical if self.vocab else
|
|
25
22
|
self._featurize_floating
|
|
26
23
|
)
|
|
27
|
-
scale_value = self.scale and not self.vocab
|
|
28
24
|
if not isinstance(descriptor, (tuple, list, np.ndarray)):
|
|
29
25
|
descriptor = [descriptor]
|
|
30
26
|
|
|
31
27
|
descriptors = []
|
|
32
28
|
for value in descriptor:
|
|
33
|
-
if scale_value:
|
|
34
|
-
value /= self.scale
|
|
35
29
|
descriptors.append(func(value))
|
|
36
30
|
return np.concatenate(descriptors)
|
|
37
31
|
|
|
38
|
-
def get_config(self):
|
|
39
|
-
config = super().get_config()
|
|
40
|
-
config['scale'] = self.scale
|
|
41
|
-
return config
|
|
42
|
-
|
|
43
32
|
|
|
44
33
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
45
34
|
class MolWeight(Descriptor):
|
|
46
35
|
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
47
|
-
return
|
|
36
|
+
return rdMolDescriptors.CalcExactMolWt(mol)
|
|
48
37
|
|
|
49
38
|
|
|
50
39
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
51
|
-
class
|
|
40
|
+
class TPSA(Descriptor):
|
|
52
41
|
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
53
|
-
return
|
|
42
|
+
return rdMolDescriptors.CalcTPSA(mol)
|
|
43
|
+
|
|
54
44
|
|
|
45
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
46
|
+
class CrippenLogP(Descriptor):
|
|
47
|
+
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
48
|
+
return rdMolDescriptors.CalcCrippenDescriptors(mol)[0]
|
|
49
|
+
|
|
55
50
|
|
|
56
51
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
57
|
-
class
|
|
52
|
+
class CrippenMolarRefractivity(Descriptor):
|
|
58
53
|
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
59
|
-
return
|
|
54
|
+
return rdMolDescriptors.CalcCrippenDescriptors(mol)[1]
|
|
60
55
|
|
|
61
56
|
|
|
62
57
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
63
58
|
class NumHeavyAtoms(Descriptor):
|
|
64
59
|
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
65
|
-
return
|
|
60
|
+
return rdMolDescriptors.CalcNumHeavyAtoms(mol)
|
|
66
61
|
|
|
67
62
|
|
|
63
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
64
|
+
class NumHeteroAtoms(Descriptor):
|
|
65
|
+
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
66
|
+
return rdMolDescriptors.CalcNumHeteroatoms(mol)
|
|
67
|
+
|
|
68
|
+
|
|
68
69
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
69
70
|
class NumHydrogenDonors(Descriptor):
|
|
70
71
|
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
71
|
-
return
|
|
72
|
+
return rdMolDescriptors.CalcNumHBD(mol)
|
|
72
73
|
|
|
73
74
|
|
|
74
75
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
75
76
|
class NumHydrogenAcceptors(Descriptor):
|
|
76
77
|
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
77
|
-
return
|
|
78
|
-
|
|
78
|
+
return rdMolDescriptors.CalcNumHBA(mol)
|
|
79
|
+
|
|
79
80
|
|
|
80
81
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
81
82
|
class NumRotatableBonds(Descriptor):
|
|
82
83
|
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
83
|
-
return
|
|
84
|
+
return rdMolDescriptors.CalcNumRotatableBonds(mol)
|
|
84
85
|
|
|
85
86
|
|
|
86
87
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
87
88
|
class NumRings(Descriptor):
|
|
88
89
|
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
89
|
-
return
|
|
90
|
+
return rdMolDescriptors.CalcNumRings(mol)
|
|
90
91
|
|
|
@@ -2,6 +2,7 @@ import re
|
|
|
2
2
|
import keras
|
|
3
3
|
import numpy as np
|
|
4
4
|
import tensorflow as tf
|
|
5
|
+
import tensorflow_text as tf_text
|
|
5
6
|
from rdkit import Chem
|
|
6
7
|
|
|
7
8
|
from molcraft import ops
|
|
@@ -10,12 +11,14 @@ from molcraft import features
|
|
|
10
11
|
from molcraft import featurizers
|
|
11
12
|
from molcraft import tensors
|
|
12
13
|
from molcraft import descriptors
|
|
14
|
+
from molcraft import layers
|
|
13
15
|
|
|
14
16
|
|
|
15
17
|
def Graph(
|
|
16
18
|
inputs,
|
|
17
19
|
atom_features: list[features.Feature] | str | None = 'auto',
|
|
18
20
|
bond_features: list[features.Feature] | str | None = 'auto',
|
|
21
|
+
molecule_features: list[descriptors.Descriptor] | str | None = 'auto',
|
|
19
22
|
super_atom: bool = True,
|
|
20
23
|
radius: int | float | None = None,
|
|
21
24
|
self_loops: bool = False,
|
|
@@ -25,7 +28,7 @@ def Graph(
|
|
|
25
28
|
featurizer = featurizers.MolGraphFeaturizer(
|
|
26
29
|
atom_features=atom_features,
|
|
27
30
|
bond_features=bond_features,
|
|
28
|
-
molecule_features=
|
|
31
|
+
molecule_features=molecule_features,
|
|
29
32
|
super_atom=super_atom,
|
|
30
33
|
radius=radius,
|
|
31
34
|
self_loops=self_loops,
|
|
@@ -33,41 +36,55 @@ def Graph(
|
|
|
33
36
|
**kwargs,
|
|
34
37
|
)
|
|
35
38
|
|
|
36
|
-
|
|
37
|
-
residues[
|
|
39
|
+
tensor_list: list[tensors.GraphTensor] = [
|
|
40
|
+
featurizer(residues[tag]).update({'context': {'tag': tag}}) for tag in inputs
|
|
38
41
|
]
|
|
39
|
-
tensor_list = [featurizer(x) for x in inputs]
|
|
40
42
|
return tf.stack(tensor_list, axis=0)
|
|
43
|
+
|
|
41
44
|
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
lookup = GraphLookupLayer()
|
|
45
|
+
def Lookup(graph: tensors.GraphTensor) -> 'LookupLayer':
|
|
46
|
+
lookup = LookupLayer()
|
|
45
47
|
lookup._build(graph)
|
|
46
48
|
return lookup
|
|
47
49
|
|
|
48
50
|
|
|
49
51
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
50
|
-
class GraphLookupLayer(keras.layers.Layer):
|
|
52
|
+
class SequenceSplit(keras.layers.Layer):
|
|
53
|
+
|
|
54
|
+
_pattern = "|".join([
|
|
55
|
+
r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])', # N-term mod + mod
|
|
56
|
+
r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])', # C-term mod + mod
|
|
57
|
+
r'([A-Z]-\[[A-Za-z0-9]+\])', # C-term mod
|
|
58
|
+
r'(\[[A-Za-z0-9]+\]-[A-Z])', # N-term mod
|
|
59
|
+
r'([A-Z]\[[A-Za-z0-9]+\])', # Mod
|
|
60
|
+
r'([A-Z])', # No mod
|
|
61
|
+
])
|
|
62
|
+
|
|
63
|
+
def call(self, inputs):
|
|
64
|
+
inputs = tf_text.regex_split(inputs, self._pattern, self._pattern)
|
|
65
|
+
inputs = keras.ops.concatenate([
|
|
66
|
+
tf.strings.join([inputs[:, :-1], '-[X]']),
|
|
67
|
+
inputs[:, -1:]
|
|
68
|
+
], axis=1)
|
|
69
|
+
return inputs.to_tensor()
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
|
|
51
73
|
|
|
52
|
-
|
|
53
|
-
|
|
74
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
75
|
+
class LookupLayer(keras.layers.Layer):
|
|
76
|
+
|
|
77
|
+
def __init__(self, **kwargs):
|
|
78
|
+
super().__init__(**kwargs)
|
|
79
|
+
self._sequence_splitter = SequenceSplit()
|
|
80
|
+
|
|
81
|
+
def call(self, sequence: tf.Tensor) -> tensors.GraphTensor:
|
|
82
|
+
sequence = self._sequence_splitter(sequence)
|
|
83
|
+
indices = self._tag_to_index.lookup(sequence)
|
|
84
|
+
indices = tf.sort(tf.unique(tf.reshape(indices, [-1]))[0])[1:]
|
|
54
85
|
graph = self.graph[indices]
|
|
55
|
-
sizes = graph.context['size']
|
|
56
|
-
max_index = keras.ops.max(indices)
|
|
57
|
-
sizes = tf.tensor_scatter_nd_update(
|
|
58
|
-
tensor=tf.zeros([max_index + 1], dtype=indices.dtype),
|
|
59
|
-
indices=indices[:, None],
|
|
60
|
-
updates=sizes
|
|
61
|
-
)
|
|
62
|
-
graph = graph.update(
|
|
63
|
-
{
|
|
64
|
-
'context': {
|
|
65
|
-
'size': sizes
|
|
66
|
-
}
|
|
67
|
-
},
|
|
68
|
-
)
|
|
69
86
|
return tensors.to_dict(graph)
|
|
70
|
-
|
|
87
|
+
|
|
71
88
|
def _build(self, x):
|
|
72
89
|
|
|
73
90
|
if isinstance(x, tensors.GraphTensor):
|
|
@@ -98,7 +115,17 @@ class GraphLookupLayer(keras.layers.Layer):
|
|
|
98
115
|
keras.ops.convert_to_tensor, self._graph
|
|
99
116
|
)
|
|
100
117
|
self._graph_tensor = tensors.from_dict(graph)
|
|
101
|
-
|
|
118
|
+
|
|
119
|
+
tags = self._graph_tensor.context['tag']
|
|
120
|
+
|
|
121
|
+
self._tag_to_index = tf.lookup.StaticHashTable(
|
|
122
|
+
tf.lookup.KeyValueTensorInitializer(
|
|
123
|
+
keys=tags,
|
|
124
|
+
values=range(len(tags)),
|
|
125
|
+
),
|
|
126
|
+
default_value=-1,
|
|
127
|
+
)
|
|
128
|
+
|
|
102
129
|
def get_config(self):
|
|
103
130
|
config = super().get_config()
|
|
104
131
|
spec = keras.saving.serialize_keras_object(self._spec)
|
|
@@ -106,7 +133,7 @@ class GraphLookupLayer(keras.layers.Layer):
|
|
|
106
133
|
return config
|
|
107
134
|
|
|
108
135
|
@classmethod
|
|
109
|
-
def from_config(cls, config: dict) -> 'GraphLookupLayer':
|
|
136
|
+
def from_config(cls, config: dict) -> 'LookupLayer':
|
|
110
137
|
spec = config.pop('spec')
|
|
111
138
|
spec = keras.saving.deserialize_keras_object(spec)
|
|
112
139
|
layer = cls(**config)
|
|
@@ -130,66 +157,48 @@ class Gather(keras.layers.Layer):
|
|
|
130
157
|
super().__init__(**kwargs)
|
|
131
158
|
self.padding = padding
|
|
132
159
|
self.mask_value = mask_value
|
|
160
|
+
self._readout_layer = layers.Readout(mode='mean')
|
|
133
161
|
self.supports_masking = True
|
|
162
|
+
self._sequence_splitter = SequenceSplit()
|
|
134
163
|
|
|
135
164
|
def get_config(self):
|
|
136
165
|
config = super().get_config()
|
|
137
166
|
config['mask_value'] = self.mask_value
|
|
138
167
|
config['padding'] = self.padding
|
|
139
168
|
return config
|
|
140
|
-
|
|
141
|
-
def call(self, inputs
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
169
|
+
|
|
170
|
+
def call(self, inputs) -> tf.Tensor:
|
|
171
|
+
|
|
172
|
+
graph, sequence = inputs
|
|
173
|
+
|
|
174
|
+
tag = graph['context']['tag']
|
|
175
|
+
data = self._readout_layer(graph)
|
|
176
|
+
|
|
177
|
+
table = tf.lookup.experimental.MutableHashTable(
|
|
178
|
+
key_dtype=tf.string,
|
|
179
|
+
value_dtype=tf.int32,
|
|
180
|
+
default_value=-1
|
|
181
|
+
)
|
|
182
|
+
|
|
183
|
+
table.insert(tag, tf.range(tf.shape(tag)[0]))
|
|
184
|
+
sequence = self._sequence_splitter(sequence)
|
|
185
|
+
sequence = table.lookup(sequence)
|
|
186
|
+
|
|
187
|
+
readout = ops.gather(data, keras.ops.where(sequence == -1, 0, sequence))
|
|
188
|
+
readout = keras.ops.where(sequence[..., None] == -1, 0.0, readout)
|
|
189
|
+
return readout
|
|
190
|
+
|
|
155
191
|
def compute_mask(
|
|
156
192
|
self,
|
|
157
|
-
inputs:
|
|
193
|
+
inputs: tensors.GraphTensor,
|
|
158
194
|
mask: bool | None = None
|
|
159
195
|
) -> tf.Tensor | None:
|
|
160
196
|
# if self.mask_value is None:
|
|
161
197
|
# return None
|
|
162
|
-
_,
|
|
163
|
-
|
|
164
|
-
|
|
198
|
+
_, sequence = inputs
|
|
199
|
+
sequence = self._sequence_splitter(sequence)
|
|
200
|
+
return keras.ops.not_equal(sequence, '')
|
|
165
201
|
|
|
166
|
-
@keras.saving.register_keras_serializable(package='molcraft')
|
|
167
|
-
class AminoAcidType(descriptors.Descriptor):
|
|
168
|
-
|
|
169
|
-
def __init__(self, vocab=None, **kwargs):
|
|
170
|
-
vocab = [
|
|
171
|
-
"A", "C", "D", "E", "F", "G", "H", "I", "K", "L",
|
|
172
|
-
"M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y",
|
|
173
|
-
]
|
|
174
|
-
super().__init__(vocab=vocab, **kwargs)
|
|
175
|
-
|
|
176
|
-
def call(self, mol: chem.Mol) -> list[str]:
|
|
177
|
-
residue = residues_reverse.get(mol.canonical_smiles)
|
|
178
|
-
if not residue:
|
|
179
|
-
raise KeyError(f'Could not find {mol.canonical_smiles} in `residues_reverse`.')
|
|
180
|
-
mol = chem.remove_hs(mol)
|
|
181
|
-
return _extract_residue_type(residues_reverse[mol.canonical_smiles])
|
|
182
|
-
|
|
183
|
-
def sequence_split(sequence: str):
|
|
184
|
-
patterns = [
|
|
185
|
-
r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])', # N-term mod + mod
|
|
186
|
-
r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])', # C-term mod + mod
|
|
187
|
-
r'([A-Z]-\[[A-Za-z0-9]+\])', # C-term mod
|
|
188
|
-
r'(\[[A-Za-z0-9]+\]-[A-Z])', # N-term mod
|
|
189
|
-
r'([A-Z]\[[A-Za-z0-9]+\])', # Mod
|
|
190
|
-
r'([A-Z])', # No mod
|
|
191
|
-
]
|
|
192
|
-
return [match.group(0) for match in re.finditer("|".join(patterns), sequence)]
|
|
193
202
|
|
|
194
203
|
residues = {
|
|
195
204
|
"A": "N[C@@H](C)C(=O)O",
|
|
@@ -252,13 +261,21 @@ residues = {
|
|
|
252
261
|
}
|
|
253
262
|
|
|
254
263
|
residues_reverse = {}
|
|
255
|
-
def register_peptide_residues(
|
|
256
|
-
for residue, smiles in
|
|
257
|
-
|
|
264
|
+
def register_peptide_residues(residues_: dict[str, str], canonicalize=True):
|
|
265
|
+
for residue, smiles in residues_.items():
|
|
266
|
+
if canonicalize:
|
|
267
|
+
smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles))
|
|
268
|
+
residues[residue] = smiles
|
|
258
269
|
residues_reverse[residues[residue]] = residue
|
|
259
270
|
|
|
260
|
-
register_peptide_residues(residues)
|
|
271
|
+
register_peptide_residues(residues, canonicalize=False)
|
|
261
272
|
|
|
262
273
|
def _extract_residue_type(residue_tag: str) -> str:
|
|
263
|
-
pattern = r"(?<!\[)[A-Z](?![
|
|
264
|
-
return [match.group(0) for match in re.finditer(pattern, residue_tag)][0]
|
|
274
|
+
pattern = r"(?<!\[)[A-Z](?![^\[]*\])"
|
|
275
|
+
return [match.group(0) for match in re.finditer(pattern, residue_tag)][0]
|
|
276
|
+
|
|
277
|
+
special_residues = {}
|
|
278
|
+
for key, value in residues.items():
|
|
279
|
+
special_residues[key + '-[X]'] = value.rstrip('O')
|
|
280
|
+
|
|
281
|
+
register_peptide_residues(special_residues, canonicalize=False)
|
|
@@ -186,9 +186,11 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
186
186
|
if default_molecule_features:
|
|
187
187
|
molecule_features = [
|
|
188
188
|
descriptors.MolWeight(),
|
|
189
|
-
descriptors.
|
|
190
|
-
descriptors.
|
|
189
|
+
descriptors.TPSA(),
|
|
190
|
+
descriptors.CrippenLogP(),
|
|
191
|
+
descriptors.CrippenMolarRefractivity(),
|
|
191
192
|
descriptors.NumHeavyAtoms(),
|
|
193
|
+
descriptors.NumHeteroAtoms(),
|
|
192
194
|
descriptors.NumHydrogenDonors(),
|
|
193
195
|
descriptors.NumHydrogenAcceptors(),
|
|
194
196
|
descriptors.NumRotatableBonds(),
|
|
@@ -219,7 +221,7 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
219
221
|
|
|
220
222
|
atom_feature = self.atom_features(mol)
|
|
221
223
|
bond_feature = self.bond_features(mol)
|
|
222
|
-
|
|
224
|
+
molecule_feature = self.molecule_feature(mol)
|
|
223
225
|
molecule_size = self.num_atoms(mol)
|
|
224
226
|
|
|
225
227
|
if isinstance(context, dict):
|
|
@@ -241,8 +243,14 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
241
243
|
else:
|
|
242
244
|
context = {'size': molecule_size}
|
|
243
245
|
|
|
244
|
-
if
|
|
245
|
-
|
|
246
|
+
if molecule_feature is not None:
|
|
247
|
+
if 'feature' in context:
|
|
248
|
+
warn(
|
|
249
|
+
'Found both inputted and computed context feature. '
|
|
250
|
+
'Overwriting inputted context feature with computed '
|
|
251
|
+
'context feature (based on `molecule_features`).'
|
|
252
|
+
)
|
|
253
|
+
context['feature'] = molecule_feature
|
|
246
254
|
|
|
247
255
|
node = {}
|
|
248
256
|
node['feature'] = atom_feature
|
|
@@ -288,7 +296,7 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
288
296
|
edge['feature'] = self._expand_bond_features(
|
|
289
297
|
mol, paths, bond_feature,
|
|
290
298
|
)
|
|
291
|
-
edge['length'] = np.eye(self.radius + 1)[edge['length']]
|
|
299
|
+
edge['length'] = np.eye(self.radius + 1, dtype=self.feature_dtype)[edge['length']]
|
|
292
300
|
|
|
293
301
|
if self.super_atom:
|
|
294
302
|
node, edge = self._add_super_atom(node, edge)
|
|
@@ -315,13 +323,13 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
315
323
|
)
|
|
316
324
|
return bond_feature.astype(self.feature_dtype)
|
|
317
325
|
|
|
318
|
-
def
|
|
326
|
+
def molecule_feature(self, mol: chem.Mol) -> np.ndarray:
|
|
319
327
|
if self._molecule_features is None:
|
|
320
328
|
return None
|
|
321
|
-
|
|
329
|
+
molecule_feature: np.ndarray = np.concatenate(
|
|
322
330
|
[f(mol) for f in self._molecule_features], axis=-1
|
|
323
331
|
)
|
|
324
|
-
return
|
|
332
|
+
return molecule_feature.astype(self.feature_dtype)
|
|
325
333
|
|
|
326
334
|
def num_atoms(self, mol: chem.Mol) -> np.ndarray:
|
|
327
335
|
return np.asarray(mol.num_atoms, dtype=self.index_dtype)
|
|
@@ -361,9 +369,7 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
361
369
|
) -> tuple[dict[str, np.ndarray]]:
|
|
362
370
|
num_super_nodes = 1
|
|
363
371
|
num_nodes = node['feature'].shape[0]
|
|
364
|
-
node = _add_super_nodes(
|
|
365
|
-
node, num_super_nodes, self.feature_dtype
|
|
366
|
-
)
|
|
372
|
+
node = _add_super_nodes(node, num_super_nodes)
|
|
367
373
|
edge = _add_super_edges(
|
|
368
374
|
edge, num_nodes, num_super_nodes, self.feature_dtype, self.index_dtype
|
|
369
375
|
)
|
|
@@ -544,7 +550,7 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
544
550
|
'of the `Featurizer` or input a 3D representation of the molecule. '
|
|
545
551
|
)
|
|
546
552
|
|
|
547
|
-
|
|
553
|
+
molecule_feature = self.molecule_feature(mol)
|
|
548
554
|
molecule_size = self.num_atoms(mol) + int(self.super_atom)
|
|
549
555
|
|
|
550
556
|
if isinstance(context, dict):
|
|
@@ -566,8 +572,14 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
566
572
|
else:
|
|
567
573
|
context = {'size': molecule_size}
|
|
568
574
|
|
|
569
|
-
if
|
|
570
|
-
|
|
575
|
+
if molecule_feature is not None:
|
|
576
|
+
if 'feature' in context:
|
|
577
|
+
warn(
|
|
578
|
+
'Found both inputted and computed context feature. '
|
|
579
|
+
'Overwriting inputted context feature with computed '
|
|
580
|
+
'context feature (based on `molecule_features`).'
|
|
581
|
+
)
|
|
582
|
+
context['feature'] = molecule_feature
|
|
571
583
|
|
|
572
584
|
node = {}
|
|
573
585
|
node['feature'] = self.atom_features(mol)
|
|
@@ -587,7 +599,7 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
587
599
|
radius=self.radius,
|
|
588
600
|
sparse=False,
|
|
589
601
|
self_loops=self.self_loops,
|
|
590
|
-
dtype=
|
|
602
|
+
dtype=bool
|
|
591
603
|
)
|
|
592
604
|
edge_conformer['source'], edge_conformer['target'] = np.where(adjacency_matrix)
|
|
593
605
|
edge_conformer['source'] = edge_conformer['source'].astype(self.index_dtype)
|
|
@@ -604,7 +616,7 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
604
616
|
)
|
|
605
617
|
node_conformer['coordinate'] = np.concatenate(
|
|
606
618
|
[node_conformer['coordinate'], conformer.centroid[None]], axis=0
|
|
607
|
-
)
|
|
619
|
+
).astype(self.feature_dtype)
|
|
608
620
|
tensor_list.append(
|
|
609
621
|
tensors.GraphTensor(context, node_conformer, edge_conformer)
|
|
610
622
|
)
|
|
@@ -670,12 +682,15 @@ def load_featurizer(
|
|
|
670
682
|
def _add_super_nodes(
|
|
671
683
|
node: dict[str, np.ndarray],
|
|
672
684
|
num_super_nodes: int = 1,
|
|
673
|
-
feature_dtype: str = 'float32',
|
|
674
685
|
) -> dict[str, np.ndarray]:
|
|
675
686
|
node = copy.deepcopy(node)
|
|
676
|
-
node['super'] = np.array(
|
|
687
|
+
node['super'] = np.array(
|
|
688
|
+
[False] * len(node['feature']) + [True] * num_super_nodes,
|
|
689
|
+
dtype=bool
|
|
690
|
+
)
|
|
677
691
|
super_node_feature = np.zeros(
|
|
678
|
-
[num_super_nodes, node['feature'].shape[-1]],
|
|
692
|
+
[num_super_nodes, node['feature'].shape[-1]],
|
|
693
|
+
dtype=node['feature'].dtype
|
|
679
694
|
)
|
|
680
695
|
node['feature'] = np.concatenate([node['feature'], super_node_feature])
|
|
681
696
|
return node
|
|
@@ -695,31 +710,38 @@ def _add_super_edges(
|
|
|
695
710
|
np.tile(np.arange(num_nodes), [num_super_nodes])
|
|
696
711
|
)
|
|
697
712
|
edge['source'] = np.concatenate(
|
|
698
|
-
[
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
super_node_indices,
|
|
702
|
-
]
|
|
703
|
-
)
|
|
704
|
-
edge['source'] = edge['source'].astype(index_dtype)
|
|
713
|
+
[edge['source'], node_indices, super_node_indices]
|
|
714
|
+
).astype(index_dtype)
|
|
715
|
+
|
|
705
716
|
edge['target'] = np.concatenate(
|
|
706
|
-
[
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
node_indices
|
|
710
|
-
]
|
|
711
|
-
)
|
|
712
|
-
edge['target'] = edge['target'].astype(index_dtype)
|
|
717
|
+
[edge['target'], super_node_indices, node_indices]
|
|
718
|
+
).astype(index_dtype)
|
|
719
|
+
|
|
713
720
|
if 'feature' in edge:
|
|
714
|
-
|
|
715
|
-
|
|
721
|
+
num_edges = int(edge['feature'].shape[0])
|
|
722
|
+
num_super_edges = int(num_super_nodes * num_nodes * 2)
|
|
723
|
+
edge['super'] = np.asarray(
|
|
724
|
+
([False] * num_edges + [True] * num_super_edges),
|
|
725
|
+
dtype=bool
|
|
726
|
+
)
|
|
727
|
+
edge['feature'] = np.concatenate(
|
|
728
|
+
[
|
|
729
|
+
edge['feature'],
|
|
730
|
+
np.zeros(
|
|
731
|
+
shape=(num_super_edges, edge['feature'].shape[-1]),
|
|
732
|
+
dtype=edge['feature'].dtype
|
|
733
|
+
)
|
|
734
|
+
]
|
|
735
|
+
)
|
|
736
|
+
|
|
716
737
|
if 'length' in edge:
|
|
717
738
|
edge['length'] = np.pad(edge['length'], [(0, 0), (1, 0)])
|
|
718
|
-
zero_array = np.zeros(
|
|
739
|
+
zero_array = np.zeros([num_nodes * num_super_nodes * 2], dtype='int32')
|
|
719
740
|
edge_length_dim = edge['length'].shape[1]
|
|
720
741
|
virtual_edge_length = np.eye(edge_length_dim)[zero_array]
|
|
721
742
|
edge['length'] = np.concatenate([edge['length'], virtual_edge_length])
|
|
722
743
|
edge['length'] = edge['length'].astype(feature_dtype)
|
|
744
|
+
|
|
723
745
|
return edge
|
|
724
746
|
|
|
725
747
|
|