molcraft 0.1.0a15__py3-none-any.whl → 0.1.0a17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- molcraft/__init__.py +1 -2
- molcraft/applications/chromatography.py +0 -0
- molcraft/applications/proteomics.py +194 -0
- molcraft/chem.py +17 -22
- molcraft/datasets.py +6 -6
- molcraft/descriptors.py +14 -0
- molcraft/features.py +50 -58
- molcraft/featurizers.py +257 -487
- molcraft/layers.py +50 -0
- molcraft/models.py +2 -0
- molcraft/records.py +24 -15
- {molcraft-0.1.0a15.dist-info → molcraft-0.1.0a17.dist-info}/METADATA +14 -12
- molcraft-0.1.0a17.dist-info/RECORD +21 -0
- molcraft/apps/peptides.py +0 -429
- molcraft/apps/qsrr.py +0 -47
- molcraft/conformers.py +0 -151
- molcraft-0.1.0a15.dist-info/RECORD +0 -22
- /molcraft/{apps → applications}/__init__.py +0 -0
- {molcraft-0.1.0a15.dist-info → molcraft-0.1.0a17.dist-info}/WHEEL +0 -0
- {molcraft-0.1.0a15.dist-info → molcraft-0.1.0a17.dist-info}/licenses/LICENSE +0 -0
- {molcraft-0.1.0a15.dist-info → molcraft-0.1.0a17.dist-info}/top_level.txt +0 -0
molcraft/layers.py
CHANGED
|
@@ -1430,6 +1430,56 @@ class EdgeEmbedding(GraphLayer):
|
|
|
1430
1430
|
return config
|
|
1431
1431
|
|
|
1432
1432
|
|
|
1433
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
1434
|
+
class AddContext(GraphLayer):
|
|
1435
|
+
|
|
1436
|
+
"""Context adding layer.
|
|
1437
|
+
|
|
1438
|
+
Adds context to super nodes.
|
|
1439
|
+
"""
|
|
1440
|
+
|
|
1441
|
+
def __init__(
|
|
1442
|
+
self,
|
|
1443
|
+
field: str = 'feature',
|
|
1444
|
+
drop: bool = False,
|
|
1445
|
+
normalize: bool = False,
|
|
1446
|
+
**kwargs
|
|
1447
|
+
) -> None:
|
|
1448
|
+
super().__init__(**kwargs)
|
|
1449
|
+
self.field = field
|
|
1450
|
+
self.drop = drop
|
|
1451
|
+
self._normalize = normalize
|
|
1452
|
+
|
|
1453
|
+
def build(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
1454
|
+
feature_dim = spec.node['feature'].shape[-1]
|
|
1455
|
+
self._context_dense = self.get_dense(feature_dim)
|
|
1456
|
+
if not self._normalize:
|
|
1457
|
+
self._norm = keras.layers.Identity()
|
|
1458
|
+
elif str(self._normalize).lower().startswith('layer'):
|
|
1459
|
+
self._norm = keras.layers.LayerNormalization()
|
|
1460
|
+
else:
|
|
1461
|
+
self._norm = keras.layers.BatchNormalization()
|
|
1462
|
+
|
|
1463
|
+
def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
1464
|
+
context = tensor.context[self.field]
|
|
1465
|
+
context = self._context_dense(context)
|
|
1466
|
+
context = self._norm(context)
|
|
1467
|
+
node_feature = ops.scatter_add(
|
|
1468
|
+
tensor.node['feature'], tensor.node['super'], context
|
|
1469
|
+
)
|
|
1470
|
+
data = {'node': {'feature': node_feature}}
|
|
1471
|
+
if self.drop:
|
|
1472
|
+
data['context'] = {self.field: None}
|
|
1473
|
+
return tensor.update(data)
|
|
1474
|
+
|
|
1475
|
+
def get_config(self) -> dict:
|
|
1476
|
+
config = super().get_config()
|
|
1477
|
+
config['field'] = self.field
|
|
1478
|
+
config['drop'] = self.drop
|
|
1479
|
+
config['normalize'] = self._normalize
|
|
1480
|
+
return config
|
|
1481
|
+
|
|
1482
|
+
|
|
1433
1483
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
1434
1484
|
class GraphNetwork(GraphLayer):
|
|
1435
1485
|
|
molcraft/models.py
CHANGED
|
@@ -154,6 +154,7 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
|
|
|
154
154
|
return graph
|
|
155
155
|
|
|
156
156
|
def get_config(self):
|
|
157
|
+
"""Obtain model config."""
|
|
157
158
|
config = super().get_config()
|
|
158
159
|
if hasattr(self, '_model_layers') and self._model_layers is not None:
|
|
159
160
|
config['model_layers'] = [
|
|
@@ -164,6 +165,7 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
|
|
|
164
165
|
|
|
165
166
|
@classmethod
|
|
166
167
|
def from_config(cls, config: dict):
|
|
168
|
+
"""Obtain model from model config."""
|
|
167
169
|
if 'model_layers' in config:
|
|
168
170
|
config['model_layers'] = [
|
|
169
171
|
keras.saving.deserialize_keras_object(l)
|
molcraft/records.py
CHANGED
|
@@ -14,7 +14,7 @@ from molcraft import featurizers
|
|
|
14
14
|
|
|
15
15
|
def write(
|
|
16
16
|
inputs: list[str | tuple],
|
|
17
|
-
featurizer: featurizers.
|
|
17
|
+
featurizer: featurizers.GraphFeaturizer,
|
|
18
18
|
path: str,
|
|
19
19
|
overwrite: bool = True,
|
|
20
20
|
num_files: typing.Optional[int] = None,
|
|
@@ -23,10 +23,13 @@ def write(
|
|
|
23
23
|
device: str = '/cpu:0'
|
|
24
24
|
) -> None:
|
|
25
25
|
|
|
26
|
-
if os.path.isdir(path)
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
26
|
+
if os.path.isdir(path):
|
|
27
|
+
if not overwrite:
|
|
28
|
+
return
|
|
29
|
+
else:
|
|
30
|
+
_remove_files(path)
|
|
31
|
+
else:
|
|
32
|
+
os.makedirs(path)
|
|
30
33
|
|
|
31
34
|
with tf.device(device):
|
|
32
35
|
|
|
@@ -133,7 +136,7 @@ def load_spec(path: str) -> tensors.GraphTensor.Spec:
|
|
|
133
136
|
def _write_tfrecord(
|
|
134
137
|
inputs,
|
|
135
138
|
path: str,
|
|
136
|
-
featurizer: featurizers.
|
|
139
|
+
featurizer: featurizers.GraphFeaturizer,
|
|
137
140
|
) -> None:
|
|
138
141
|
|
|
139
142
|
def _write_example(tensor):
|
|
@@ -149,11 +152,7 @@ def _write_tfrecord(
|
|
|
149
152
|
x = tuple(x)
|
|
150
153
|
tensor = featurizer(x)
|
|
151
154
|
if tensor is not None:
|
|
152
|
-
|
|
153
|
-
_write_example(tensor)
|
|
154
|
-
else:
|
|
155
|
-
for t in tensor:
|
|
156
|
-
_write_example(t)
|
|
155
|
+
_write_example(tensor)
|
|
157
156
|
|
|
158
157
|
def _serialize_example(
|
|
159
158
|
feature: dict[str, tf.train.Feature]
|
|
@@ -168,8 +167,18 @@ def _parse_example(
|
|
|
168
167
|
) -> tf.Tensor:
|
|
169
168
|
out = tf.io.parse_single_example(
|
|
170
169
|
x, features={'feature': tf.io.RaggedFeature(tf.string)})['feature']
|
|
171
|
-
out = [
|
|
172
|
-
tf.
|
|
173
|
-
|
|
170
|
+
out = [
|
|
171
|
+
tf.ensure_shape(tf.io.parse_tensor(x[0], s.dtype), s.shape)
|
|
172
|
+
for (x, s) in zip(
|
|
173
|
+
tf.split(out, len(tf.nest.flatten(spec, expand_composites=True))),
|
|
174
|
+
tf.nest.flatten(spec, expand_composites=True)
|
|
175
|
+
)
|
|
176
|
+
]
|
|
174
177
|
out = tf.nest.pack_sequence_as(spec, tf.nest.flatten(out), expand_composites=True)
|
|
175
|
-
return out
|
|
178
|
+
return out
|
|
179
|
+
|
|
180
|
+
def _remove_files(path):
|
|
181
|
+
for filename in os.listdir(path):
|
|
182
|
+
if filename.endswith('tfrecord') or filename == 'spec.pb':
|
|
183
|
+
filepath = os.path.join(path, filename)
|
|
184
|
+
os.remove(filepath)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: molcraft
|
|
3
|
-
Version: 0.1.0a15
|
|
3
|
+
Version: 0.1.0a17
|
|
4
4
|
Summary: Graph Neural Networks for Molecular Machine Learning
|
|
5
5
|
Author-email: Alexander Kensert <alexander.kensert@gmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -35,6 +35,7 @@ Requires-Python: >=3.10
|
|
|
35
35
|
Description-Content-Type: text/markdown
|
|
36
36
|
License-File: LICENSE
|
|
37
37
|
Requires-Dist: tensorflow>=2.16
|
|
38
|
+
Requires-Dist: tensorflow-text>=2.16
|
|
38
39
|
Requires-Dist: rdkit>=2023.9.5
|
|
39
40
|
Requires-Dist: pandas>=1.0.3
|
|
40
41
|
Requires-Dist: ipython>=8.12.0
|
|
@@ -42,9 +43,9 @@ Provides-Extra: gpu
|
|
|
42
43
|
Requires-Dist: tensorflow[and-cuda]>=2.16; extra == "gpu"
|
|
43
44
|
Dynamic: license-file
|
|
44
45
|
|
|
45
|
-
<img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo">
|
|
46
|
+
<img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo" width="90%">
|
|
46
47
|
|
|
47
|
-
**Deep Learning on Molecules**: A Minimalistic GNN package for Molecular ML.
|
|
48
|
+
**Deep Learning on Molecules**: A Minimalistic GNN package for Molecular ML.
|
|
48
49
|
|
|
49
50
|
> [!NOTE]
|
|
50
51
|
> In progress.
|
|
@@ -82,11 +83,12 @@ featurizer = featurizers.MolGraphFeaturizer(
|
|
|
82
83
|
features.BondType(),
|
|
83
84
|
features.IsRotatable(),
|
|
84
85
|
],
|
|
85
|
-
|
|
86
|
+
super_node=True,
|
|
86
87
|
self_loops=True,
|
|
88
|
+
include_hydrogens=False,
|
|
87
89
|
)
|
|
88
90
|
|
|
89
|
-
graph = featurizer([('N[C@@H](C)C(=O)O', 2.
|
|
91
|
+
graph = featurizer([('N[C@@H](C)C(=O)O', 2.5), ('N[C@@H](CS)C(=O)O', 1.5)])
|
|
90
92
|
print(graph)
|
|
91
93
|
|
|
92
94
|
model = models.GraphModel.from_layers(
|
|
@@ -94,13 +96,13 @@ model = models.GraphModel.from_layers(
|
|
|
94
96
|
layers.Input(graph.spec),
|
|
95
97
|
layers.NodeEmbedding(dim=128),
|
|
96
98
|
layers.EdgeEmbedding(dim=128),
|
|
97
|
-
layers.
|
|
98
|
-
layers.
|
|
99
|
-
layers.
|
|
100
|
-
layers.
|
|
101
|
-
layers.Readout(
|
|
102
|
-
keras.layers.Dense(units=1024, activation='
|
|
103
|
-
keras.layers.Dense(units=1024, activation='
|
|
99
|
+
layers.GraphConv(units=128),
|
|
100
|
+
layers.GraphConv(units=128),
|
|
101
|
+
layers.GraphConv(units=128),
|
|
102
|
+
layers.GraphConv(units=128),
|
|
103
|
+
layers.Readout(),
|
|
104
|
+
keras.layers.Dense(units=1024, activation='elu'),
|
|
105
|
+
keras.layers.Dense(units=1024, activation='elu'),
|
|
104
106
|
keras.layers.Dense(1)
|
|
105
107
|
]
|
|
106
108
|
)
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
molcraft/__init__.py,sha256=vc-z1sgtzPY7Spwzkemu7I_b9ekEN9egnHrLEKbB9bk,431
|
|
2
|
+
molcraft/callbacks.py,sha256=x5HnkZhqcFRrW6xdApt_jZ4X08A-0fxcnFKfdmRKa0c,3571
|
|
3
|
+
molcraft/chem.py,sha256=e56qBDuqh8rq_4-UMyp6LCQNxxSx8hZ7gzuz-87DHgw,21652
|
|
4
|
+
molcraft/datasets.py,sha256=Nd2lw5USUZE52vvAiNr-q-n03Y3--NlZlK0NzqHgp-E,4145
|
|
5
|
+
molcraft/descriptors.py,sha256=Cl3KnBPsTST7XLgRLktkX5LwY9MV0P_lUlrt8iPV5no,3508
|
|
6
|
+
molcraft/features.py,sha256=s0WeV8eZcDEypPgC1m37f4s9QkvWIlVgn-L43Cdsa14,13525
|
|
7
|
+
molcraft/featurizers.py,sha256=bD3RFY9eg89-O-Nxgy6gote1zS4cyjOgzdSiSJZJdJE,17664
|
|
8
|
+
molcraft/layers.py,sha256=Y-TMb4oHh3R7tHgr7f3Y8sEPDnoSTbtwB6NkZIVnmcA,61734
|
|
9
|
+
molcraft/losses.py,sha256=qnS2yC5g-O3n_zVea9MR6TNiFraW2yqRgePOisoUP4A,1065
|
|
10
|
+
molcraft/models.py,sha256=2Pc1htT9fCukGd8ZxrvE0rzEHsPBm0pluHw4FZXaUE4,21963
|
|
11
|
+
molcraft/ops.py,sha256=bQbdFDt9waxVCzF5-dkTB6vlpj9eoSt8I4Qg7ZGXbsU,6178
|
|
12
|
+
molcraft/records.py,sha256=0j4EWP55sfnkoQIH5trdaAIevPfVbAtPLrygTRmLyFw,5686
|
|
13
|
+
molcraft/tensors.py,sha256=EOUKx496KUZsjA1zA2ABc7tU_TW3Jv7AXDsug_QsLbA,22407
|
|
14
|
+
molcraft/applications/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
15
|
+
molcraft/applications/chromatography.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
16
|
+
molcraft/applications/proteomics.py,sha256=Jb7OwHJHc_I7Wk3qnqr40j9P7um2EKtUnB4r-XhrnAc,7180
|
|
17
|
+
molcraft-0.1.0a17.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
|
|
18
|
+
molcraft-0.1.0a17.dist-info/METADATA,sha256=XqNJDwFfY6pWNqQKYLyUOxwyvmfYUkOWTKou-ZQYXL4,3930
|
|
19
|
+
molcraft-0.1.0a17.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
20
|
+
molcraft-0.1.0a17.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
|
|
21
|
+
molcraft-0.1.0a17.dist-info/RECORD,,
|
molcraft/apps/peptides.py
DELETED
|
@@ -1,429 +0,0 @@
|
|
|
1
|
-
import re
|
|
2
|
-
import keras
|
|
3
|
-
import numpy as np
|
|
4
|
-
import tensorflow as tf
|
|
5
|
-
import tensorflow_text as tf_text
|
|
6
|
-
from rdkit import Chem
|
|
7
|
-
|
|
8
|
-
from molcraft import ops
|
|
9
|
-
from molcraft import chem
|
|
10
|
-
from molcraft import features
|
|
11
|
-
from molcraft import featurizers
|
|
12
|
-
from molcraft import tensors
|
|
13
|
-
from molcraft import descriptors
|
|
14
|
-
from molcraft import layers
|
|
15
|
-
from molcraft import models
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@keras.saving.register_keras_serializable(package='molcraft')
|
|
20
|
-
class SequenceSplitter(keras.layers.Layer):
|
|
21
|
-
|
|
22
|
-
_pattern = "|".join([
|
|
23
|
-
r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])', # N-term mod + mod
|
|
24
|
-
r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])', # C-term mod + mod
|
|
25
|
-
r'([A-Z]-\[[A-Za-z0-9]+\])', # C-term mod
|
|
26
|
-
r'(\[[A-Za-z0-9]+\]-[A-Z])', # N-term mod
|
|
27
|
-
r'([A-Z]\[[A-Za-z0-9]+\])', # Mod
|
|
28
|
-
r'([A-Z])', # No mod
|
|
29
|
-
])
|
|
30
|
-
|
|
31
|
-
def call(self, inputs):
|
|
32
|
-
inputs = tf_text.regex_split(inputs, self._pattern, self._pattern)
|
|
33
|
-
inputs = keras.ops.concatenate([
|
|
34
|
-
tf.strings.join([inputs[:, :-1], '-[X]']),
|
|
35
|
-
inputs[:, -1:]
|
|
36
|
-
], axis=1)
|
|
37
|
-
return inputs.to_tensor()
|
|
38
|
-
|
|
39
|
-
@keras.saving.register_keras_serializable(package='molcraft')
|
|
40
|
-
class Gather(keras.layers.Layer):
|
|
41
|
-
|
|
42
|
-
def __init__(
|
|
43
|
-
self,
|
|
44
|
-
padding: list[tuple[int]] | tuple[int] | int = 1,
|
|
45
|
-
mask_value: int = 0,
|
|
46
|
-
**kwargs
|
|
47
|
-
) -> None:
|
|
48
|
-
super().__init__(**kwargs)
|
|
49
|
-
self._splitter = SequenceSplitter()
|
|
50
|
-
self.padding = padding
|
|
51
|
-
self.mask_value = mask_value
|
|
52
|
-
self.supports_masking = True
|
|
53
|
-
|
|
54
|
-
self._tags = list(sorted(residues.keys()))
|
|
55
|
-
self._mapping = tf.lookup.StaticHashTable(
|
|
56
|
-
tf.lookup.KeyValueTensorInitializer(
|
|
57
|
-
keys=self._tags,
|
|
58
|
-
values=range(len(self._tags)),
|
|
59
|
-
),
|
|
60
|
-
default_value=-1,
|
|
61
|
-
)
|
|
62
|
-
|
|
63
|
-
def get_config(self):
|
|
64
|
-
config = super().get_config()
|
|
65
|
-
config['mask_value'] = self.mask_value
|
|
66
|
-
config['padding'] = self.padding
|
|
67
|
-
return config
|
|
68
|
-
|
|
69
|
-
def call(self, inputs) -> tf.Tensor:
|
|
70
|
-
embedding, sequence = inputs
|
|
71
|
-
sequence = self._splitter(sequence)
|
|
72
|
-
sequence = self._mapping.lookup(sequence)
|
|
73
|
-
readout = ops.gather(embedding, keras.ops.where(sequence == -1, 0, sequence))
|
|
74
|
-
readout = keras.ops.where(sequence[..., None] == -1, 0.0, readout)
|
|
75
|
-
return readout
|
|
76
|
-
|
|
77
|
-
def compute_mask(
|
|
78
|
-
self,
|
|
79
|
-
inputs: tensors.GraphTensor,
|
|
80
|
-
mask: bool | None = None
|
|
81
|
-
) -> tf.Tensor | None:
|
|
82
|
-
# if self.mask_value is None:
|
|
83
|
-
# return None
|
|
84
|
-
_, sequence = inputs
|
|
85
|
-
sequence = self._splitter(sequence)
|
|
86
|
-
return keras.ops.not_equal(sequence, '')
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
@keras.saving.register_keras_serializable(package='molcraft')
|
|
90
|
-
class Embedding(keras.layers.Layer):
|
|
91
|
-
|
|
92
|
-
def __init__(self, **kwargs):
|
|
93
|
-
super().__init__(**kwargs)
|
|
94
|
-
tags = list(sorted(residues.keys()))
|
|
95
|
-
self.mapping = tf.lookup.StaticHashTable(
|
|
96
|
-
tf.lookup.KeyValueTensorInitializer(
|
|
97
|
-
keys=tags,
|
|
98
|
-
values=range(len(tags)),
|
|
99
|
-
),
|
|
100
|
-
default_value=-1,
|
|
101
|
-
)
|
|
102
|
-
self.splitting = SequenceSplitter()
|
|
103
|
-
featurizer = featurizers.MolGraphFeaturizer(super_atom=True)
|
|
104
|
-
tensor_list = [featurizer(residues[tag]) for tag in tags]
|
|
105
|
-
graph = tf.stack(tensor_list, axis=0)
|
|
106
|
-
self._build_on_init(graph)
|
|
107
|
-
self.embedder = models.GraphModel.from_layers(
|
|
108
|
-
[
|
|
109
|
-
layers.Input(graph.spec),
|
|
110
|
-
layers.NodeEmbedding(128),
|
|
111
|
-
layers.EdgeEmbedding(128),
|
|
112
|
-
layers.GraphTransformer(128),
|
|
113
|
-
layers.Readout()
|
|
114
|
-
]
|
|
115
|
-
)
|
|
116
|
-
self.embedding = tf.Variable(
|
|
117
|
-
initial_value=tf.zeros((114, 128)), trainable=True
|
|
118
|
-
)
|
|
119
|
-
self.new_state = tf.Variable(True, dtype=tf.bool, trainable=False)
|
|
120
|
-
self.gather = Gather()
|
|
121
|
-
self.update_state()
|
|
122
|
-
|
|
123
|
-
# Keep AA as is (most simple?), add positional embedding to distinguish N-, C- and non-terminal
|
|
124
|
-
|
|
125
|
-
def update_state(self, inputs=None):
|
|
126
|
-
graph = self._graph_tensor
|
|
127
|
-
graph = tensors.to_dict(graph)
|
|
128
|
-
embedding = self.embedder(graph)
|
|
129
|
-
self.embedding.assign(embedding)
|
|
130
|
-
tf.print("STATE UPDATED")
|
|
131
|
-
return embedding
|
|
132
|
-
|
|
133
|
-
def call(self, inputs=None, training=None) -> tensors.GraphTensor:
|
|
134
|
-
if training:
|
|
135
|
-
embedding = self.update_state()
|
|
136
|
-
self.new_state.assign(True)
|
|
137
|
-
return self.gather([embedding, inputs])
|
|
138
|
-
else:
|
|
139
|
-
embedding = tf.cond(
|
|
140
|
-
pred=self.new_state,
|
|
141
|
-
true_fn=lambda: self.update_state(),
|
|
142
|
-
false_fn=lambda: self.embedding
|
|
143
|
-
)
|
|
144
|
-
self.new_state.assign(False)
|
|
145
|
-
return self.gather([embedding, inputs])
|
|
146
|
-
|
|
147
|
-
def build(self, input_shape):
|
|
148
|
-
super().build(input_shape)
|
|
149
|
-
|
|
150
|
-
def _build_on_init(self, x):
|
|
151
|
-
|
|
152
|
-
if isinstance(x, tensors.GraphTensor):
|
|
153
|
-
tensor = tensors.to_dict(x)
|
|
154
|
-
self._spec = tf.nest.map_structure(
|
|
155
|
-
tf.type_spec_from_value, tensor
|
|
156
|
-
)
|
|
157
|
-
else:
|
|
158
|
-
self._spec = x
|
|
159
|
-
|
|
160
|
-
self._graph = tf.nest.map_structure(
|
|
161
|
-
lambda s: self.add_weight(
|
|
162
|
-
shape=s.shape,
|
|
163
|
-
dtype=s.dtype,
|
|
164
|
-
trainable=False,
|
|
165
|
-
initializer='zeros'
|
|
166
|
-
),
|
|
167
|
-
self._spec
|
|
168
|
-
)
|
|
169
|
-
|
|
170
|
-
if isinstance(x, tensors.GraphTensor):
|
|
171
|
-
tf.nest.map_structure(
|
|
172
|
-
lambda v, x: v.assign(x),
|
|
173
|
-
self._graph, tensor
|
|
174
|
-
)
|
|
175
|
-
|
|
176
|
-
graph = tf.nest.map_structure(
|
|
177
|
-
keras.ops.convert_to_tensor, self._graph
|
|
178
|
-
)
|
|
179
|
-
self._graph_tensor = tensors.from_dict(graph)
|
|
180
|
-
|
|
181
|
-
# def get_config(self) -> dict:
|
|
182
|
-
# config = super().get_config()
|
|
183
|
-
# spec = keras.saving.serialize_keras_object(self._spec)
|
|
184
|
-
# config['spec'] = spec
|
|
185
|
-
# #config['layers'] = keras.saving.serialize_keras_object(self.embedding.layers)
|
|
186
|
-
# return config
|
|
187
|
-
|
|
188
|
-
# @classmethod
|
|
189
|
-
# def from_config(cls, config: dict) -> 'SequenceToGraph':
|
|
190
|
-
# spec = config.pop('spec')
|
|
191
|
-
# spec = keras.saving.deserialize_keras_object(spec)
|
|
192
|
-
# # config['layers'] = keras.saving.deserialize_keras_object(config['layers'])
|
|
193
|
-
# layer = cls(**config)
|
|
194
|
-
# layer._build_on_init(spec)
|
|
195
|
-
# return layer
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
@keras.saving.register_keras_serializable(package='molcraft')
|
|
199
|
-
class SequenceToGraph(keras.layers.Layer):
|
|
200
|
-
|
|
201
|
-
def __init__(
|
|
202
|
-
self,
|
|
203
|
-
atom_features: list[features.Feature] | str | None = 'auto',
|
|
204
|
-
bond_features: list[features.Feature] | str | None = 'auto',
|
|
205
|
-
molecule_features: list[descriptors.Descriptor] | str | None = 'auto',
|
|
206
|
-
super_atom: bool = True,
|
|
207
|
-
radius: int | float | None = None,
|
|
208
|
-
self_loops: bool = False,
|
|
209
|
-
include_hs: bool = False,
|
|
210
|
-
**kwargs,
|
|
211
|
-
):
|
|
212
|
-
super().__init__(**kwargs)
|
|
213
|
-
self._splitter = SequenceSplitter()
|
|
214
|
-
featurizer = featurizers.MolGraphFeaturizer(
|
|
215
|
-
atom_features=atom_features,
|
|
216
|
-
bond_features=bond_features,
|
|
217
|
-
molecule_features=molecule_features,
|
|
218
|
-
super_atom=super_atom,
|
|
219
|
-
radius=radius,
|
|
220
|
-
self_loops=self_loops,
|
|
221
|
-
include_hs=include_hs,
|
|
222
|
-
**kwargs,
|
|
223
|
-
)
|
|
224
|
-
tensor_list: list[tensors.GraphTensor] = [
|
|
225
|
-
featurizer(residues[tag]).update({'context': {'tag': tag}}) for tag in residues
|
|
226
|
-
]
|
|
227
|
-
graph = tf.stack(tensor_list, axis=0)
|
|
228
|
-
self._build_on_init(graph)
|
|
229
|
-
|
|
230
|
-
def call(self, sequence: tf.Tensor) -> tensors.GraphTensor:
|
|
231
|
-
sequence = self._splitter(sequence)
|
|
232
|
-
indices = self._tag_to_index.lookup(sequence)
|
|
233
|
-
indices = tf.sort(tf.unique(tf.reshape(indices, [-1]))[0])[1:]
|
|
234
|
-
graph = self._graph_tensor[indices]
|
|
235
|
-
return tensors.to_dict(graph)
|
|
236
|
-
|
|
237
|
-
def _build_on_init(self, x):
|
|
238
|
-
|
|
239
|
-
if isinstance(x, tensors.GraphTensor):
|
|
240
|
-
tensor = tensors.to_dict(x)
|
|
241
|
-
self._spec = tf.nest.map_structure(
|
|
242
|
-
tf.type_spec_from_value, tensor
|
|
243
|
-
)
|
|
244
|
-
else:
|
|
245
|
-
self._spec = x
|
|
246
|
-
|
|
247
|
-
self._graph = tf.nest.map_structure(
|
|
248
|
-
lambda s: self.add_weight(
|
|
249
|
-
shape=s.shape,
|
|
250
|
-
dtype=s.dtype,
|
|
251
|
-
trainable=False,
|
|
252
|
-
initializer='zeros'
|
|
253
|
-
),
|
|
254
|
-
self._spec
|
|
255
|
-
)
|
|
256
|
-
|
|
257
|
-
if isinstance(x, tensors.GraphTensor):
|
|
258
|
-
tf.nest.map_structure(
|
|
259
|
-
lambda v, x: v.assign(x),
|
|
260
|
-
self._graph, tensor
|
|
261
|
-
)
|
|
262
|
-
|
|
263
|
-
graph = tf.nest.map_structure(
|
|
264
|
-
keras.ops.convert_to_tensor, self._graph
|
|
265
|
-
)
|
|
266
|
-
self._graph_tensor = tensors.from_dict(graph)
|
|
267
|
-
|
|
268
|
-
tags = self._graph_tensor.context['tag']
|
|
269
|
-
|
|
270
|
-
self._tag_to_index = tf.lookup.StaticHashTable(
|
|
271
|
-
tf.lookup.KeyValueTensorInitializer(
|
|
272
|
-
keys=tags,
|
|
273
|
-
values=range(len(tags)),
|
|
274
|
-
),
|
|
275
|
-
default_value=-1,
|
|
276
|
-
)
|
|
277
|
-
|
|
278
|
-
def get_config(self) -> dict:
|
|
279
|
-
config = super().get_config()
|
|
280
|
-
spec = keras.saving.serialize_keras_object(self._spec)
|
|
281
|
-
config['spec'] = spec
|
|
282
|
-
return config
|
|
283
|
-
|
|
284
|
-
@classmethod
|
|
285
|
-
def from_config(cls, config: dict) -> 'SequenceToGraph':
|
|
286
|
-
spec = config.pop('spec')
|
|
287
|
-
spec = keras.saving.deserialize_keras_object(spec)
|
|
288
|
-
layer = cls(**config)
|
|
289
|
-
layer._build_on_init(spec)
|
|
290
|
-
return layer
|
|
291
|
-
|
|
292
|
-
# @property
|
|
293
|
-
# def graph(self) -> tensors.GraphTensor:
|
|
294
|
-
# return self._graph_tensor
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
@keras.saving.register_keras_serializable(package='molcraft')
|
|
298
|
-
class GraphToSequence(keras.layers.Layer):
|
|
299
|
-
|
|
300
|
-
def __init__(
|
|
301
|
-
self,
|
|
302
|
-
padding: list[tuple[int]] | tuple[int] | int = 1,
|
|
303
|
-
mask_value: int = 0,
|
|
304
|
-
**kwargs
|
|
305
|
-
) -> None:
|
|
306
|
-
super().__init__(**kwargs)
|
|
307
|
-
self._splitter = SequenceSplitter()
|
|
308
|
-
self.padding = padding
|
|
309
|
-
self.mask_value = mask_value
|
|
310
|
-
self._readout_layer = layers.Readout(mode='mean')
|
|
311
|
-
self.supports_masking = True
|
|
312
|
-
|
|
313
|
-
def get_config(self):
|
|
314
|
-
config = super().get_config()
|
|
315
|
-
config['mask_value'] = self.mask_value
|
|
316
|
-
config['padding'] = self.padding
|
|
317
|
-
return config
|
|
318
|
-
|
|
319
|
-
def call(self, inputs) -> tf.Tensor:
|
|
320
|
-
|
|
321
|
-
graph, sequence = inputs
|
|
322
|
-
sequence = self._splitter(sequence)
|
|
323
|
-
tag = graph['context']['tag']
|
|
324
|
-
data = self._readout_layer(graph)
|
|
325
|
-
|
|
326
|
-
table = tf.lookup.experimental.MutableHashTable(
|
|
327
|
-
key_dtype=tf.string,
|
|
328
|
-
value_dtype=tf.int32,
|
|
329
|
-
default_value=-1
|
|
330
|
-
)
|
|
331
|
-
|
|
332
|
-
table.insert(tag, tf.range(tf.shape(tag)[0]))
|
|
333
|
-
sequence = table.lookup(sequence)
|
|
334
|
-
|
|
335
|
-
readout = ops.gather(data, keras.ops.where(sequence == -1, 0, sequence))
|
|
336
|
-
readout = keras.ops.where(sequence[..., None] == -1, 0.0, readout)
|
|
337
|
-
return readout
|
|
338
|
-
|
|
339
|
-
def compute_mask(
|
|
340
|
-
self,
|
|
341
|
-
inputs: tensors.GraphTensor,
|
|
342
|
-
mask: bool | None = None
|
|
343
|
-
) -> tf.Tensor | None:
|
|
344
|
-
# if self.mask_value is None:
|
|
345
|
-
# return None
|
|
346
|
-
_, sequence = inputs
|
|
347
|
-
sequence = self._splitter(sequence)
|
|
348
|
-
return keras.ops.not_equal(sequence, '')
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
residues = {
|
|
352
|
-
"A": "N[C@@H](C)C(=O)O",
|
|
353
|
-
"C": "N[C@@H](CS)C(=O)O",
|
|
354
|
-
"C[Carbamidomethyl]": "N[C@@H](CSCC(=O)N)C(=O)O",
|
|
355
|
-
"D": "N[C@@H](CC(=O)O)C(=O)O",
|
|
356
|
-
"E": "N[C@@H](CCC(=O)O)C(=O)O",
|
|
357
|
-
"F": "N[C@@H](Cc1ccccc1)C(=O)O",
|
|
358
|
-
"G": "NCC(=O)O",
|
|
359
|
-
"H": "N[C@@H](CC1=CN=C-N1)C(=O)O",
|
|
360
|
-
"I": "N[C@@H](C(CC)C)C(=O)O",
|
|
361
|
-
"K": "N[C@@H](CCCCN)C(=O)O",
|
|
362
|
-
"K[Acetyl]": "N[C@@H](CCCCNC(=O)C)C(=O)O",
|
|
363
|
-
"K[Crotonyl]": "N[C@@H](CCCCNC(C=CC)=O)C(=O)O",
|
|
364
|
-
"K[Dimethyl]": "N[C@@H](CCCCN(C)C)C(=O)O",
|
|
365
|
-
"K[Formyl]": "N[C@@H](CCCCNC=O)C(=O)O",
|
|
366
|
-
"K[Malonyl]": "N[C@@H](CCCCNC(=O)CC(O)=O)C(=O)O",
|
|
367
|
-
"K[Methyl]": "N[C@@H](CCCCNC)C(=O)O",
|
|
368
|
-
"K[Propionyl]": "N[C@@H](CCCCNC(=O)CC)C(=O)O",
|
|
369
|
-
"K[Succinyl]": "N[C@@H](CCCCNC(CCC(O)=O)=O)C(=O)O",
|
|
370
|
-
"K[Trimethyl]": "N[C@@H](CCCC[N+](C)(C)C)C(=O)O",
|
|
371
|
-
"L": "N[C@@H](CC(C)C)C(=O)O",
|
|
372
|
-
"M": "N[C@@H](CCSC)C(=O)O",
|
|
373
|
-
"M[Oxidation]": "N[C@@H](CCS(=O)C)C(=O)O",
|
|
374
|
-
"N": "N[C@@H](CC(=O)N)C(=O)O",
|
|
375
|
-
"P": "N1[C@@H](CCC1)C(=O)O",
|
|
376
|
-
"P[Oxidation]": "N1CC(O)C[C@H]1C(=O)O",
|
|
377
|
-
"Q": "N[C@@H](CCC(=O)N)C(=O)O",
|
|
378
|
-
"R": "N[C@@H](CCCNC(=N)N)C(=O)O",
|
|
379
|
-
"R[Deamidated]": "N[C@@H](CCCNC(N)=O)C(=O)O",
|
|
380
|
-
"R[Dimethyl]": "N[C@@H](CCCNC(N(C)C)=N)C(=O)O",
|
|
381
|
-
"R[Methyl]": "N[C@@H](CCCNC(=N)NC)C(=O)O",
|
|
382
|
-
"S": "N[C@@H](CO)C(=O)O",
|
|
383
|
-
"T": "N[C@@H](C(O)C)C(=O)O",
|
|
384
|
-
"V": "N[C@@H](C(C)C)C(=O)O",
|
|
385
|
-
"W": "N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O",
|
|
386
|
-
"Y": "N[C@@H](Cc1ccc(O)cc1)C(=O)O",
|
|
387
|
-
"Y[Nitro]": "N[C@@H](Cc1ccc(O)c(N(=O)=O)c1)C(=O)O",
|
|
388
|
-
"Y[Phospho]": "N[C@@H](Cc1ccc(OP(O)(=O)O)cc1)C(=O)O",
|
|
389
|
-
"[Acetyl]-A": "N(C(C)=O)[C@@H](C)C(=O)O",
|
|
390
|
-
"[Acetyl]-C": "N(C(C)=O)[C@@H](CS)C(=O)O",
|
|
391
|
-
"[Acetyl]-D": "N(C(=O)C)[C@H](C(=O)O)CC(=O)O",
|
|
392
|
-
"[Acetyl]-E": "N(C(=O)C)[C@@H](CCC(O)=O)C(=O)O",
|
|
393
|
-
"[Acetyl]-F": "N(C(C)=O)[C@@H](Cc1ccccc1)C(=O)O",
|
|
394
|
-
"[Acetyl]-G": "N(C(=O)C)CC(=O)O",
|
|
395
|
-
"[Acetyl]-H": "N(C(=O)C)[C@@H](Cc1[nH]cnc1)C(=O)O",
|
|
396
|
-
"[Acetyl]-I": "N(C(=O)C)[C@@H]([C@H](CC)C)C(=O)O",
|
|
397
|
-
"[Acetyl]-K": "N(C(C)=O)[C@@H](CCCCN)C(=O)O",
|
|
398
|
-
"[Acetyl]-L": "N(C(=O)C)[C@@H](CC(C)C)C(=O)O",
|
|
399
|
-
"[Acetyl]-M": "N(C(=O)C)[C@@H](CCSC)C(=O)O",
|
|
400
|
-
"[Acetyl]-N": "N(C(C)=O)[C@@H](CC(=O)N)C(=O)O",
|
|
401
|
-
"[Acetyl]-P": "N1(C(=O)C)CCC[C@H]1C(=O)O",
|
|
402
|
-
"[Acetyl]-Q": "N(C(=O)C)[C@@H](CCC(=O)N)C(=O)O",
|
|
403
|
-
"[Acetyl]-R": "N(C(C)=O)[C@@H](CCCN=C(N)N)C(=O)O",
|
|
404
|
-
"[Acetyl]-S": "N(C(C)=O)[C@@H](CO)C(=O)O",
|
|
405
|
-
"[Acetyl]-T": "N(C(=O)C)[C@@H]([C@H](O)C)C(=O)O",
|
|
406
|
-
"[Acetyl]-V": "N(C(=O)C)[C@@H](C(C)C)C(=O)O",
|
|
407
|
-
"[Acetyl]-W": "N(C(C)=O)[C@@H](Cc1c2ccccc2[nH]c1)C(=O)O",
|
|
408
|
-
"[Acetyl]-Y": "N(C(C)=O)[C@@H](Cc1ccc(O)cc1)C(=O)O"
|
|
409
|
-
}
|
|
410
|
-
|
|
411
|
-
residues_reverse = {}
|
|
412
|
-
def register_peptide_residues(residues_: dict[str, str], canonicalize=True):
|
|
413
|
-
for residue, smiles in residues_.items():
|
|
414
|
-
if canonicalize:
|
|
415
|
-
smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles))
|
|
416
|
-
residues[residue] = smiles
|
|
417
|
-
residues_reverse[residues[residue]] = residue
|
|
418
|
-
|
|
419
|
-
register_peptide_residues(residues, canonicalize=False)
|
|
420
|
-
|
|
421
|
-
def _extract_residue_type(residue_tag: str) -> str:
|
|
422
|
-
pattern = r"(?<!\[)[A-Z](?![^\[]*\])"
|
|
423
|
-
return [match.group(0) for match in re.finditer(pattern, residue_tag)][0]
|
|
424
|
-
|
|
425
|
-
special_residues = {}
|
|
426
|
-
for key, value in residues.items():
|
|
427
|
-
special_residues[key + '-[X]'] = value.rstrip('O')
|
|
428
|
-
|
|
429
|
-
register_peptide_residues(special_residues, canonicalize=False)
|