molcraft 0.1.0a15__tar.gz → 0.1.0a16__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/PKG-INFO +2 -1
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/molcraft/__init__.py +1 -1
- molcraft-0.1.0a16/molcraft/applications/proteomics.py +239 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/molcraft/layers.py +50 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/molcraft.egg-info/PKG-INFO +2 -1
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/molcraft.egg-info/SOURCES.txt +2 -3
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/molcraft.egg-info/requires.txt +1 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/pyproject.toml +1 -0
- molcraft-0.1.0a15/molcraft/apps/peptides.py +0 -429
- molcraft-0.1.0a15/molcraft/apps/qsrr.py +0 -47
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/LICENSE +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/README.md +0 -0
- {molcraft-0.1.0a15/molcraft/apps → molcraft-0.1.0a16/molcraft/applications}/__init__.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/molcraft/callbacks.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/molcraft/chem.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/molcraft/conformers.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/molcraft/datasets.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/molcraft/descriptors.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/molcraft/features.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/molcraft/featurizers.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/molcraft/losses.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/molcraft/models.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/molcraft/ops.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/molcraft/records.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/molcraft/tensors.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/molcraft.egg-info/dependency_links.txt +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/molcraft.egg-info/top_level.txt +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/setup.cfg +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/tests/test_chem.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/tests/test_featurizers.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/tests/test_layers.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/tests/test_losses.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/tests/test_models.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a16}/tests/test_tensors.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: molcraft
|
|
3
|
-
Version: 0.1.0a15
|
|
3
|
+
Version: 0.1.0a16
|
|
4
4
|
Summary: Graph Neural Networks for Molecular Machine Learning
|
|
5
5
|
Author-email: Alexander Kensert <alexander.kensert@gmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -35,6 +35,7 @@ Requires-Python: >=3.10
|
|
|
35
35
|
Description-Content-Type: text/markdown
|
|
36
36
|
License-File: LICENSE
|
|
37
37
|
Requires-Dist: tensorflow>=2.16
|
|
38
|
+
Requires-Dist: tensorflow-text>=2.16
|
|
38
39
|
Requires-Dist: rdkit>=2023.9.5
|
|
39
40
|
Requires-Dist: pandas>=1.0.3
|
|
40
41
|
Requires-Dist: ipython>=8.12.0
|
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import keras
|
|
3
|
+
import numpy as np
|
|
4
|
+
import tensorflow as tf
|
|
5
|
+
import tensorflow_text as tf_text
|
|
6
|
+
import json
|
|
7
|
+
|
|
8
|
+
from molcraft import featurizers
|
|
9
|
+
from molcraft import tensors
|
|
10
|
+
from molcraft import layers
|
|
11
|
+
from molcraft import models
|
|
12
|
+
from molcraft import chem
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# TODO: Add regex pattern for residue (C-term mod + N-term mod)?
|
|
16
|
+
# TODO: Add regex pattern for residue (C-term mod + N-term mod + mod)?
|
|
17
|
+
residue_pattern: str = "|".join([
|
|
18
|
+
r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])', # residue (N-term mod + mod)
|
|
19
|
+
r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])', # residue (C-term mod + mod)
|
|
20
|
+
r'([A-Z]-\[[A-Za-z0-9]+\])', # residue (C-term mod)
|
|
21
|
+
r'(\[[A-Za-z0-9]+\]-[A-Z])', # residue (N-term mod)
|
|
22
|
+
r'([A-Z]\[[A-Za-z0-9]+\])', # residue (mod)
|
|
23
|
+
r'([A-Z])', # residue (no mod)
|
|
24
|
+
])
|
|
25
|
+
|
|
26
|
+
default_residues: dict[str, str] = {
|
|
27
|
+
"A": "N[C@@H](C)C(=O)O",
|
|
28
|
+
"C": "N[C@@H](CS)C(=O)O",
|
|
29
|
+
"D": "N[C@@H](CC(=O)O)C(=O)O",
|
|
30
|
+
"E": "N[C@@H](CCC(=O)O)C(=O)O",
|
|
31
|
+
"F": "N[C@@H](Cc1ccccc1)C(=O)O",
|
|
32
|
+
"G": "NCC(=O)O",
|
|
33
|
+
"H": "N[C@@H](CC1=CN=C-N1)C(=O)O",
|
|
34
|
+
"I": "N[C@@H](C(CC)C)C(=O)O",
|
|
35
|
+
"K": "N[C@@H](CCCCN)C(=O)O",
|
|
36
|
+
"L": "N[C@@H](CC(C)C)C(=O)O",
|
|
37
|
+
"M": "N[C@@H](CCSC)C(=O)O",
|
|
38
|
+
"N": "N[C@@H](CC(=O)N)C(=O)O",
|
|
39
|
+
"P": "N1[C@@H](CCC1)C(=O)O",
|
|
40
|
+
"Q": "N[C@@H](CCC(=O)N)C(=O)O",
|
|
41
|
+
"R": "N[C@@H](CCCNC(=N)N)C(=O)O",
|
|
42
|
+
"S": "N[C@@H](CO)C(=O)O",
|
|
43
|
+
"T": "N[C@@H](C(O)C)C(=O)O",
|
|
44
|
+
"V": "N[C@@H](C(C)C)C(=O)O",
|
|
45
|
+
"W": "N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O",
|
|
46
|
+
"Y": "N[C@@H](Cc1ccc(O)cc1)C(=O)O",
|
|
47
|
+
}
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class Peptide(chem.Mol):
|
|
51
|
+
|
|
52
|
+
@classmethod
|
|
53
|
+
def from_sequence(cls, sequence: str, **kwargs) -> 'Peptide':
|
|
54
|
+
sequence = [
|
|
55
|
+
match.group(0) for match in re.finditer(residue_pattern, sequence)
|
|
56
|
+
]
|
|
57
|
+
peptide_smiles = []
|
|
58
|
+
for i, residue in enumerate(sequence):
|
|
59
|
+
if i < len(sequence) - 1:
|
|
60
|
+
residue_smiles = registered_residues[residue + '*']
|
|
61
|
+
else:
|
|
62
|
+
residue_smiles = registered_residues[residue]
|
|
63
|
+
peptide_smiles.append(residue_smiles)
|
|
64
|
+
peptide_smiles = ''.join(peptide_smiles)
|
|
65
|
+
return super().from_encoding(peptide_smiles, **kwargs)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@keras.saving.register_keras_serializable(package='proteomics')
|
|
69
|
+
class ResidueEmbedding(keras.layers.Layer):
|
|
70
|
+
|
|
71
|
+
def __init__(
|
|
72
|
+
self,
|
|
73
|
+
featurizer: featurizers.MolGraphFeaturizer,
|
|
74
|
+
embedder: models.GraphModel,
|
|
75
|
+
**kwargs
|
|
76
|
+
) -> None:
|
|
77
|
+
residues = kwargs.pop('_residues', None)
|
|
78
|
+
super().__init__(**kwargs)
|
|
79
|
+
if residues is None:
|
|
80
|
+
residues = registered_residues.copy()
|
|
81
|
+
self._residues = residues
|
|
82
|
+
self.embedder = embedder
|
|
83
|
+
self.featurizer = featurizer
|
|
84
|
+
self.ragged_split = SequenceSplitter(pad=False)
|
|
85
|
+
self.split = SequenceSplitter(pad=True)
|
|
86
|
+
self.supports_masking = True
|
|
87
|
+
|
|
88
|
+
def build(self, input_shape) -> None:
|
|
89
|
+
embedding_dim = self.embedder.output.shape[-1]
|
|
90
|
+
residues = sorted(self._residues.keys())
|
|
91
|
+
smiles = [self._residues[residue] for residue in residues]
|
|
92
|
+
num_residues = len(residues)
|
|
93
|
+
self.oov_index = np.where(np.array(residues) == "G")[0][0]
|
|
94
|
+
self.mapping = tf.lookup.StaticHashTable(
|
|
95
|
+
tf.lookup.KeyValueTensorInitializer(
|
|
96
|
+
keys=residues,
|
|
97
|
+
values=range(num_residues)
|
|
98
|
+
),
|
|
99
|
+
default_value=-1,
|
|
100
|
+
)
|
|
101
|
+
self.graph = tf.stack([self.featurizer(s) for s in smiles], axis=0)
|
|
102
|
+
self.cached_embeddings = tf.Variable(
|
|
103
|
+
initial_value=tf.zeros((num_residues, embedding_dim))
|
|
104
|
+
)
|
|
105
|
+
self.use_cached_embeddings = tf.Variable(False)
|
|
106
|
+
super().build(input_shape)
|
|
107
|
+
|
|
108
|
+
def call(self, sequences, training=None) -> tensors.GraphTensor:
|
|
109
|
+
if training is False:
|
|
110
|
+
self.use_cached_embeddings.assign(True)
|
|
111
|
+
else:
|
|
112
|
+
self.use_cached_embeddings.assign(False)
|
|
113
|
+
embeddings = tf.cond(
|
|
114
|
+
pred=self.use_cached_embeddings,
|
|
115
|
+
true_fn=lambda: self.cached_embeddings,
|
|
116
|
+
false_fn=lambda: self.embeddings(),
|
|
117
|
+
)
|
|
118
|
+
sequences = self.ragged_split(sequences)
|
|
119
|
+
sequences = keras.ops.concatenate([
|
|
120
|
+
tf.strings.join([sequences[:, :-1], '*']), sequences[:, -1:]
|
|
121
|
+
], axis=1)
|
|
122
|
+
indices = self.mapping.lookup(sequences)
|
|
123
|
+
indices = keras.ops.where(indices == -1, self.oov_index, indices)
|
|
124
|
+
return tf.gather(embeddings, indices).to_tensor()
|
|
125
|
+
|
|
126
|
+
def embeddings(self) -> tf.Tensor:
|
|
127
|
+
embeddings = self.embedder(self.graph)
|
|
128
|
+
self.cached_embeddings.assign(embeddings)
|
|
129
|
+
return embeddings
|
|
130
|
+
|
|
131
|
+
def compute_mask(
|
|
132
|
+
self,
|
|
133
|
+
inputs: tensors.GraphTensor,
|
|
134
|
+
mask: bool | None = None
|
|
135
|
+
) -> tf.Tensor | None:
|
|
136
|
+
sequences = self.split(inputs)
|
|
137
|
+
return keras.ops.not_equal(sequences, '')
|
|
138
|
+
|
|
139
|
+
def get_config(self) -> dict:
|
|
140
|
+
config = super().get_config()
|
|
141
|
+
config.update({
|
|
142
|
+
'_residues': self._residues,
|
|
143
|
+
'featurizer': keras.saving.serialize_keras_object(self.featurizer),
|
|
144
|
+
'embedder': keras.saving.serialize_keras_object(self.embedder)
|
|
145
|
+
})
|
|
146
|
+
return config
|
|
147
|
+
|
|
148
|
+
@classmethod
|
|
149
|
+
def from_config(cls, config: dict) -> 'ResidueEmbedding':
|
|
150
|
+
config['featurizer'] = keras.saving.deserialize_keras_object(config['featurizer'])
|
|
151
|
+
config['embedder'] = keras.saving.deserialize_keras_object(config['embedder'])
|
|
152
|
+
return super().from_config(config)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
@keras.saving.register_keras_serializable(package='proteomics')
|
|
156
|
+
class SequenceSplitter(keras.layers.Layer):
|
|
157
|
+
|
|
158
|
+
def __init__(self, pad: bool, **kwargs):
|
|
159
|
+
super().__init__(**kwargs)
|
|
160
|
+
self.pad = pad
|
|
161
|
+
|
|
162
|
+
def call(self, inputs):
|
|
163
|
+
inputs = tf_text.regex_split(inputs, residue_pattern, residue_pattern)
|
|
164
|
+
if self.pad:
|
|
165
|
+
inputs = inputs.to_tensor()
|
|
166
|
+
return inputs
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def interpret(model: keras.models.Model, sequence: list[str]) -> tensors.GraphTensor:
|
|
170
|
+
|
|
171
|
+
if not tf.is_tensor(sequence):
|
|
172
|
+
sequence = keras.ops.convert_to_tensor(sequence)
|
|
173
|
+
|
|
174
|
+
# Find embedding layer
|
|
175
|
+
for layer in model.layers:
|
|
176
|
+
if isinstance(layer, ResidueEmbedding):
|
|
177
|
+
break
|
|
178
|
+
|
|
179
|
+
# Use embedding layer to convert the sequence to a graph
|
|
180
|
+
residues = layer.ragged_split(sequence)
|
|
181
|
+
residues = keras.ops.concatenate([
|
|
182
|
+
tf.strings.join([residues[:, :-1], '*']), residues[:, -1:]
|
|
183
|
+
], axis=1)
|
|
184
|
+
indices = layer.mapping.lookup(residues)
|
|
185
|
+
graph = tf.concat([
|
|
186
|
+
layer.graph[residue_ids] for residue_ids in indices
|
|
187
|
+
], axis=0)
|
|
188
|
+
|
|
189
|
+
# Define layer which reshapes data into sequences of residue embeddings
|
|
190
|
+
num_residues = indices.row_lengths()
|
|
191
|
+
to_sequence = (
|
|
192
|
+
lambda x: tf.RaggedTensor.from_row_lengths(x, num_residues).to_tensor()
|
|
193
|
+
)
|
|
194
|
+
reshape = keras.layers.Lambda(to_sequence)
|
|
195
|
+
|
|
196
|
+
# Obtain the embedder part of the original model
|
|
197
|
+
embedder = layer.embedder
|
|
198
|
+
# Obtain the remaining part of the original model
|
|
199
|
+
predictor = keras.models.Model(embedder.output, model.output)
|
|
200
|
+
# Obtain an 'interpretable model', based on the original model
|
|
201
|
+
inputs = layers.Input(graph.spec)
|
|
202
|
+
x = inputs
|
|
203
|
+
for layer in embedder.layers: # Loop over layers to expose them
|
|
204
|
+
x = layer(x)
|
|
205
|
+
x = reshape(x)
|
|
206
|
+
outputs = predictor(x)
|
|
207
|
+
interpretable_model = models.GraphModel(inputs, outputs)
|
|
208
|
+
|
|
209
|
+
# Interpret original model through the 'interpretable model'
|
|
210
|
+
graph = models.interpret(interpretable_model, graph)
|
|
211
|
+
del interpretable_model
|
|
212
|
+
|
|
213
|
+
# Update 'size' field with new sizes corresponding to peptides for convenience
|
|
214
|
+
# Allows the user to obtain n:th peptide graph using indexing: nth_peptide = graph[n]
|
|
215
|
+
peptide_indices = range(len(num_residues))
|
|
216
|
+
peptide_indicator = keras.ops.repeat(peptide_indices, num_residues)
|
|
217
|
+
residue_sizes = graph.context['size']
|
|
218
|
+
peptide_sizes = keras.ops.segment_sum(residue_sizes, peptide_indicator)
|
|
219
|
+
return graph.update({'context': {'size': peptide_sizes, 'sequence': sequence}})
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def register_residues(residues: dict[str, str]) -> None:
|
|
223
|
+
# TODO: Implement functions that check if residue has N- or C-terminal mod
|
|
224
|
+
# if C-terminal mod, no need to enforce concatenatable perm.
|
|
225
|
+
# if N-terminal mod, enforce only 'C(=O)O'
|
|
226
|
+
# if normal mod, enforce concatenateable perm ('N[C@@H]' and 'C(=O)O)).
|
|
227
|
+
for residue, smiles in residues.items():
|
|
228
|
+
if residue.startswith('P'):
|
|
229
|
+
smiles.startswith('N'), f'Incorrect SMILES permutation for {residue}.'
|
|
230
|
+
elif not residue.startswith('['):
|
|
231
|
+
smiles.startswith('N[C@@H]'), f'Incorrect SMILES permutation for {residue}.'
|
|
232
|
+
if len(residue) > 1 and not residue[1] == "-":
|
|
233
|
+
assert smiles.endswith('C(=O)O'), f'Incorrect SMILES permutation for {residue}.'
|
|
234
|
+
registered_residues[residue] = smiles
|
|
235
|
+
registered_residues[residue + '*'] = smiles.strip('O')
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
registered_residues: dict[str, str] = {}
|
|
239
|
+
register_residues(default_residues)
|
|
@@ -1430,6 +1430,56 @@ class EdgeEmbedding(GraphLayer):
|
|
|
1430
1430
|
return config
|
|
1431
1431
|
|
|
1432
1432
|
|
|
1433
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
1434
|
+
class AddContext(GraphLayer):
|
|
1435
|
+
|
|
1436
|
+
"""Context adding layer.
|
|
1437
|
+
|
|
1438
|
+
Adds context to super nodes.
|
|
1439
|
+
"""
|
|
1440
|
+
|
|
1441
|
+
def __init__(
|
|
1442
|
+
self,
|
|
1443
|
+
field: str = 'feature',
|
|
1444
|
+
drop: bool = True,
|
|
1445
|
+
normalize: bool = False,
|
|
1446
|
+
**kwargs
|
|
1447
|
+
) -> None:
|
|
1448
|
+
super().__init__(**kwargs)
|
|
1449
|
+
self.field = field
|
|
1450
|
+
self.drop = drop
|
|
1451
|
+
self._normalize = normalize
|
|
1452
|
+
|
|
1453
|
+
def build(self, spec: tensors.GraphTensor.Spec) -> None:
|
|
1454
|
+
feature_dim = spec.node['feature'].shape[-1]
|
|
1455
|
+
self._context_dense = self.get_dense(feature_dim)
|
|
1456
|
+
if not self._normalize:
|
|
1457
|
+
self._norm = keras.layers.Identity()
|
|
1458
|
+
elif str(self._normalize).lower().startswith('layer'):
|
|
1459
|
+
self._norm = keras.layers.LayerNormalization()
|
|
1460
|
+
else:
|
|
1461
|
+
self._norm = keras.layers.BatchNormalization()
|
|
1462
|
+
|
|
1463
|
+
def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
1464
|
+
context = tensor.context[self.field]
|
|
1465
|
+
context = self._context_dense(context)
|
|
1466
|
+
context = self._norm(context)
|
|
1467
|
+
node_feature = ops.scatter_add(
|
|
1468
|
+
tensor.node['feature'], tensor.node['super'], context
|
|
1469
|
+
)
|
|
1470
|
+
data = {'node': {'feature': node_feature}}
|
|
1471
|
+
if self.drop:
|
|
1472
|
+
data['context'] = {self.field: None}
|
|
1473
|
+
return tensor.update(data)
|
|
1474
|
+
|
|
1475
|
+
def get_config(self) -> dict:
|
|
1476
|
+
config = super().get_config()
|
|
1477
|
+
config['field'] = self.field
|
|
1478
|
+
config['drop'] = self.drop
|
|
1479
|
+
config['normalize'] = self._normalize
|
|
1480
|
+
return config
|
|
1481
|
+
|
|
1482
|
+
|
|
1433
1483
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
1434
1484
|
class GraphNetwork(GraphLayer):
|
|
1435
1485
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: molcraft
|
|
3
|
-
Version: 0.1.0a15
|
|
3
|
+
Version: 0.1.0a16
|
|
4
4
|
Summary: Graph Neural Networks for Molecular Machine Learning
|
|
5
5
|
Author-email: Alexander Kensert <alexander.kensert@gmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -35,6 +35,7 @@ Requires-Python: >=3.10
|
|
|
35
35
|
Description-Content-Type: text/markdown
|
|
36
36
|
License-File: LICENSE
|
|
37
37
|
Requires-Dist: tensorflow>=2.16
|
|
38
|
+
Requires-Dist: tensorflow-text>=2.16
|
|
38
39
|
Requires-Dist: rdkit>=2023.9.5
|
|
39
40
|
Requires-Dist: pandas>=1.0.3
|
|
40
41
|
Requires-Dist: ipython>=8.12.0
|
|
@@ -20,9 +20,8 @@ molcraft.egg-info/SOURCES.txt
|
|
|
20
20
|
molcraft.egg-info/dependency_links.txt
|
|
21
21
|
molcraft.egg-info/requires.txt
|
|
22
22
|
molcraft.egg-info/top_level.txt
|
|
23
|
-
molcraft/apps/__init__.py
|
|
24
|
-
molcraft/apps/peptides.py
|
|
25
|
-
molcraft/apps/qsrr.py
|
|
23
|
+
molcraft/applications/__init__.py
|
|
24
|
+
molcraft/applications/proteomics.py
|
|
26
25
|
tests/test_chem.py
|
|
27
26
|
tests/test_featurizers.py
|
|
28
27
|
tests/test_layers.py
|
|
@@ -1,429 +0,0 @@
|
|
|
1
|
-
import re
|
|
2
|
-
import keras
|
|
3
|
-
import numpy as np
|
|
4
|
-
import tensorflow as tf
|
|
5
|
-
import tensorflow_text as tf_text
|
|
6
|
-
from rdkit import Chem
|
|
7
|
-
|
|
8
|
-
from molcraft import ops
|
|
9
|
-
from molcraft import chem
|
|
10
|
-
from molcraft import features
|
|
11
|
-
from molcraft import featurizers
|
|
12
|
-
from molcraft import tensors
|
|
13
|
-
from molcraft import descriptors
|
|
14
|
-
from molcraft import layers
|
|
15
|
-
from molcraft import models
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@keras.saving.register_keras_serializable(package='molcraft')
|
|
20
|
-
class SequenceSplitter(keras.layers.Layer):
|
|
21
|
-
|
|
22
|
-
_pattern = "|".join([
|
|
23
|
-
r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])', # N-term mod + mod
|
|
24
|
-
r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])', # C-term mod + mod
|
|
25
|
-
r'([A-Z]-\[[A-Za-z0-9]+\])', # C-term mod
|
|
26
|
-
r'(\[[A-Za-z0-9]+\]-[A-Z])', # N-term mod
|
|
27
|
-
r'([A-Z]\[[A-Za-z0-9]+\])', # Mod
|
|
28
|
-
r'([A-Z])', # No mod
|
|
29
|
-
])
|
|
30
|
-
|
|
31
|
-
def call(self, inputs):
|
|
32
|
-
inputs = tf_text.regex_split(inputs, self._pattern, self._pattern)
|
|
33
|
-
inputs = keras.ops.concatenate([
|
|
34
|
-
tf.strings.join([inputs[:, :-1], '-[X]']),
|
|
35
|
-
inputs[:, -1:]
|
|
36
|
-
], axis=1)
|
|
37
|
-
return inputs.to_tensor()
|
|
38
|
-
|
|
39
|
-
@keras.saving.register_keras_serializable(package='molcraft')
|
|
40
|
-
class Gather(keras.layers.Layer):
|
|
41
|
-
|
|
42
|
-
def __init__(
|
|
43
|
-
self,
|
|
44
|
-
padding: list[tuple[int]] | tuple[int] | int = 1,
|
|
45
|
-
mask_value: int = 0,
|
|
46
|
-
**kwargs
|
|
47
|
-
) -> None:
|
|
48
|
-
super().__init__(**kwargs)
|
|
49
|
-
self._splitter = SequenceSplitter()
|
|
50
|
-
self.padding = padding
|
|
51
|
-
self.mask_value = mask_value
|
|
52
|
-
self.supports_masking = True
|
|
53
|
-
|
|
54
|
-
self._tags = list(sorted(residues.keys()))
|
|
55
|
-
self._mapping = tf.lookup.StaticHashTable(
|
|
56
|
-
tf.lookup.KeyValueTensorInitializer(
|
|
57
|
-
keys=self._tags,
|
|
58
|
-
values=range(len(self._tags)),
|
|
59
|
-
),
|
|
60
|
-
default_value=-1,
|
|
61
|
-
)
|
|
62
|
-
|
|
63
|
-
def get_config(self):
|
|
64
|
-
config = super().get_config()
|
|
65
|
-
config['mask_value'] = self.mask_value
|
|
66
|
-
config['padding'] = self.padding
|
|
67
|
-
return config
|
|
68
|
-
|
|
69
|
-
def call(self, inputs) -> tf.Tensor:
|
|
70
|
-
embedding, sequence = inputs
|
|
71
|
-
sequence = self._splitter(sequence)
|
|
72
|
-
sequence = self._mapping.lookup(sequence)
|
|
73
|
-
readout = ops.gather(embedding, keras.ops.where(sequence == -1, 0, sequence))
|
|
74
|
-
readout = keras.ops.where(sequence[..., None] == -1, 0.0, readout)
|
|
75
|
-
return readout
|
|
76
|
-
|
|
77
|
-
def compute_mask(
|
|
78
|
-
self,
|
|
79
|
-
inputs: tensors.GraphTensor,
|
|
80
|
-
mask: bool | None = None
|
|
81
|
-
) -> tf.Tensor | None:
|
|
82
|
-
# if self.mask_value is None:
|
|
83
|
-
# return None
|
|
84
|
-
_, sequence = inputs
|
|
85
|
-
sequence = self._splitter(sequence)
|
|
86
|
-
return keras.ops.not_equal(sequence, '')
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
@keras.saving.register_keras_serializable(package='molcraft')
|
|
90
|
-
class Embedding(keras.layers.Layer):
|
|
91
|
-
|
|
92
|
-
def __init__(self, **kwargs):
|
|
93
|
-
super().__init__(**kwargs)
|
|
94
|
-
tags = list(sorted(residues.keys()))
|
|
95
|
-
self.mapping = tf.lookup.StaticHashTable(
|
|
96
|
-
tf.lookup.KeyValueTensorInitializer(
|
|
97
|
-
keys=tags,
|
|
98
|
-
values=range(len(tags)),
|
|
99
|
-
),
|
|
100
|
-
default_value=-1,
|
|
101
|
-
)
|
|
102
|
-
self.splitting = SequenceSplitter()
|
|
103
|
-
featurizer = featurizers.MolGraphFeaturizer(super_atom=True)
|
|
104
|
-
tensor_list = [featurizer(residues[tag]) for tag in tags]
|
|
105
|
-
graph = tf.stack(tensor_list, axis=0)
|
|
106
|
-
self._build_on_init(graph)
|
|
107
|
-
self.embedder = models.GraphModel.from_layers(
|
|
108
|
-
[
|
|
109
|
-
layers.Input(graph.spec),
|
|
110
|
-
layers.NodeEmbedding(128),
|
|
111
|
-
layers.EdgeEmbedding(128),
|
|
112
|
-
layers.GraphTransformer(128),
|
|
113
|
-
layers.Readout()
|
|
114
|
-
]
|
|
115
|
-
)
|
|
116
|
-
self.embedding = tf.Variable(
|
|
117
|
-
initial_value=tf.zeros((114, 128)), trainable=True
|
|
118
|
-
)
|
|
119
|
-
self.new_state = tf.Variable(True, dtype=tf.bool, trainable=False)
|
|
120
|
-
self.gather = Gather()
|
|
121
|
-
self.update_state()
|
|
122
|
-
|
|
123
|
-
# Keep AA as is (most simple?), add positional embedding to distingusih N-, C- and non-terminal
|
|
124
|
-
|
|
125
|
-
def update_state(self, inputs=None):
|
|
126
|
-
graph = self._graph_tensor
|
|
127
|
-
graph = tensors.to_dict(graph)
|
|
128
|
-
embedding = self.embedder(graph)
|
|
129
|
-
self.embedding.assign(embedding)
|
|
130
|
-
tf.print("STATE UPDATED")
|
|
131
|
-
return embedding
|
|
132
|
-
|
|
133
|
-
def call(self, inputs=None, training=None) -> tensors.GraphTensor:
|
|
134
|
-
if training:
|
|
135
|
-
embedding = self.update_state()
|
|
136
|
-
self.new_state.assign(True)
|
|
137
|
-
return self.gather([embedding, inputs])
|
|
138
|
-
else:
|
|
139
|
-
embedding = tf.cond(
|
|
140
|
-
pred=self.new_state,
|
|
141
|
-
true_fn=lambda: self.update_state(),
|
|
142
|
-
false_fn=lambda: self.embedding
|
|
143
|
-
)
|
|
144
|
-
self.new_state.assign(False)
|
|
145
|
-
return self.gather([embedding, inputs])
|
|
146
|
-
|
|
147
|
-
def build(self, input_shape):
|
|
148
|
-
super().build(input_shape)
|
|
149
|
-
|
|
150
|
-
def _build_on_init(self, x):
|
|
151
|
-
|
|
152
|
-
if isinstance(x, tensors.GraphTensor):
|
|
153
|
-
tensor = tensors.to_dict(x)
|
|
154
|
-
self._spec = tf.nest.map_structure(
|
|
155
|
-
tf.type_spec_from_value, tensor
|
|
156
|
-
)
|
|
157
|
-
else:
|
|
158
|
-
self._spec = x
|
|
159
|
-
|
|
160
|
-
self._graph = tf.nest.map_structure(
|
|
161
|
-
lambda s: self.add_weight(
|
|
162
|
-
shape=s.shape,
|
|
163
|
-
dtype=s.dtype,
|
|
164
|
-
trainable=False,
|
|
165
|
-
initializer='zeros'
|
|
166
|
-
),
|
|
167
|
-
self._spec
|
|
168
|
-
)
|
|
169
|
-
|
|
170
|
-
if isinstance(x, tensors.GraphTensor):
|
|
171
|
-
tf.nest.map_structure(
|
|
172
|
-
lambda v, x: v.assign(x),
|
|
173
|
-
self._graph, tensor
|
|
174
|
-
)
|
|
175
|
-
|
|
176
|
-
graph = tf.nest.map_structure(
|
|
177
|
-
keras.ops.convert_to_tensor, self._graph
|
|
178
|
-
)
|
|
179
|
-
self._graph_tensor = tensors.from_dict(graph)
|
|
180
|
-
|
|
181
|
-
# def get_config(self) -> dict:
|
|
182
|
-
# config = super().get_config()
|
|
183
|
-
# spec = keras.saving.serialize_keras_object(self._spec)
|
|
184
|
-
# config['spec'] = spec
|
|
185
|
-
# #config['layers'] = keras.saving.serialize_keras_object(self.embedding.layers)
|
|
186
|
-
# return config
|
|
187
|
-
|
|
188
|
-
# @classmethod
|
|
189
|
-
# def from_config(cls, config: dict) -> 'SequenceToGraph':
|
|
190
|
-
# spec = config.pop('spec')
|
|
191
|
-
# spec = keras.saving.deserialize_keras_object(spec)
|
|
192
|
-
# # config['layers'] = keras.saving.deserialize_keras_object(config['layers'])
|
|
193
|
-
# layer = cls(**config)
|
|
194
|
-
# layer._build_on_init(spec)
|
|
195
|
-
# return layer
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
@keras.saving.register_keras_serializable(package='molcraft')
|
|
199
|
-
class SequenceToGraph(keras.layers.Layer):
|
|
200
|
-
|
|
201
|
-
def __init__(
|
|
202
|
-
self,
|
|
203
|
-
atom_features: list[features.Feature] | str | None = 'auto',
|
|
204
|
-
bond_features: list[features.Feature] | str | None = 'auto',
|
|
205
|
-
molecule_features: list[descriptors.Descriptor] | str | None = 'auto',
|
|
206
|
-
super_atom: bool = True,
|
|
207
|
-
radius: int | float | None = None,
|
|
208
|
-
self_loops: bool = False,
|
|
209
|
-
include_hs: bool = False,
|
|
210
|
-
**kwargs,
|
|
211
|
-
):
|
|
212
|
-
super().__init__(**kwargs)
|
|
213
|
-
self._splitter = SequenceSplitter()
|
|
214
|
-
featurizer = featurizers.MolGraphFeaturizer(
|
|
215
|
-
atom_features=atom_features,
|
|
216
|
-
bond_features=bond_features,
|
|
217
|
-
molecule_features=molecule_features,
|
|
218
|
-
super_atom=super_atom,
|
|
219
|
-
radius=radius,
|
|
220
|
-
self_loops=self_loops,
|
|
221
|
-
include_hs=include_hs,
|
|
222
|
-
**kwargs,
|
|
223
|
-
)
|
|
224
|
-
tensor_list: list[tensors.GraphTensor] = [
|
|
225
|
-
featurizer(residues[tag]).update({'context': {'tag': tag}}) for tag in residues
|
|
226
|
-
]
|
|
227
|
-
graph = tf.stack(tensor_list, axis=0)
|
|
228
|
-
self._build_on_init(graph)
|
|
229
|
-
|
|
230
|
-
def call(self, sequence: tf.Tensor) -> tensors.GraphTensor:
|
|
231
|
-
sequence = self._splitter(sequence)
|
|
232
|
-
indices = self._tag_to_index.lookup(sequence)
|
|
233
|
-
indices = tf.sort(tf.unique(tf.reshape(indices, [-1]))[0])[1:]
|
|
234
|
-
graph = self._graph_tensor[indices]
|
|
235
|
-
return tensors.to_dict(graph)
|
|
236
|
-
|
|
237
|
-
def _build_on_init(self, x):
|
|
238
|
-
|
|
239
|
-
if isinstance(x, tensors.GraphTensor):
|
|
240
|
-
tensor = tensors.to_dict(x)
|
|
241
|
-
self._spec = tf.nest.map_structure(
|
|
242
|
-
tf.type_spec_from_value, tensor
|
|
243
|
-
)
|
|
244
|
-
else:
|
|
245
|
-
self._spec = x
|
|
246
|
-
|
|
247
|
-
self._graph = tf.nest.map_structure(
|
|
248
|
-
lambda s: self.add_weight(
|
|
249
|
-
shape=s.shape,
|
|
250
|
-
dtype=s.dtype,
|
|
251
|
-
trainable=False,
|
|
252
|
-
initializer='zeros'
|
|
253
|
-
),
|
|
254
|
-
self._spec
|
|
255
|
-
)
|
|
256
|
-
|
|
257
|
-
if isinstance(x, tensors.GraphTensor):
|
|
258
|
-
tf.nest.map_structure(
|
|
259
|
-
lambda v, x: v.assign(x),
|
|
260
|
-
self._graph, tensor
|
|
261
|
-
)
|
|
262
|
-
|
|
263
|
-
graph = tf.nest.map_structure(
|
|
264
|
-
keras.ops.convert_to_tensor, self._graph
|
|
265
|
-
)
|
|
266
|
-
self._graph_tensor = tensors.from_dict(graph)
|
|
267
|
-
|
|
268
|
-
tags = self._graph_tensor.context['tag']
|
|
269
|
-
|
|
270
|
-
self._tag_to_index = tf.lookup.StaticHashTable(
|
|
271
|
-
tf.lookup.KeyValueTensorInitializer(
|
|
272
|
-
keys=tags,
|
|
273
|
-
values=range(len(tags)),
|
|
274
|
-
),
|
|
275
|
-
default_value=-1,
|
|
276
|
-
)
|
|
277
|
-
|
|
278
|
-
def get_config(self) -> dict:
|
|
279
|
-
config = super().get_config()
|
|
280
|
-
spec = keras.saving.serialize_keras_object(self._spec)
|
|
281
|
-
config['spec'] = spec
|
|
282
|
-
return config
|
|
283
|
-
|
|
284
|
-
@classmethod
|
|
285
|
-
def from_config(cls, config: dict) -> 'SequenceToGraph':
|
|
286
|
-
spec = config.pop('spec')
|
|
287
|
-
spec = keras.saving.deserialize_keras_object(spec)
|
|
288
|
-
layer = cls(**config)
|
|
289
|
-
layer._build_on_init(spec)
|
|
290
|
-
return layer
|
|
291
|
-
|
|
292
|
-
# @property
|
|
293
|
-
# def graph(self) -> tensors.GraphTensor:
|
|
294
|
-
# return self._graph_tensor
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
@keras.saving.register_keras_serializable(package='molcraft')
|
|
298
|
-
class GraphToSequence(keras.layers.Layer):
|
|
299
|
-
|
|
300
|
-
def __init__(
|
|
301
|
-
self,
|
|
302
|
-
padding: list[tuple[int]] | tuple[int] | int = 1,
|
|
303
|
-
mask_value: int = 0,
|
|
304
|
-
**kwargs
|
|
305
|
-
) -> None:
|
|
306
|
-
super().__init__(**kwargs)
|
|
307
|
-
self._splitter = SequenceSplitter()
|
|
308
|
-
self.padding = padding
|
|
309
|
-
self.mask_value = mask_value
|
|
310
|
-
self._readout_layer = layers.Readout(mode='mean')
|
|
311
|
-
self.supports_masking = True
|
|
312
|
-
|
|
313
|
-
def get_config(self):
|
|
314
|
-
config = super().get_config()
|
|
315
|
-
config['mask_value'] = self.mask_value
|
|
316
|
-
config['padding'] = self.padding
|
|
317
|
-
return config
|
|
318
|
-
|
|
319
|
-
def call(self, inputs) -> tf.Tensor:
|
|
320
|
-
|
|
321
|
-
graph, sequence = inputs
|
|
322
|
-
sequence = self._splitter(sequence)
|
|
323
|
-
tag = graph['context']['tag']
|
|
324
|
-
data = self._readout_layer(graph)
|
|
325
|
-
|
|
326
|
-
table = tf.lookup.experimental.MutableHashTable(
|
|
327
|
-
key_dtype=tf.string,
|
|
328
|
-
value_dtype=tf.int32,
|
|
329
|
-
default_value=-1
|
|
330
|
-
)
|
|
331
|
-
|
|
332
|
-
table.insert(tag, tf.range(tf.shape(tag)[0]))
|
|
333
|
-
sequence = table.lookup(sequence)
|
|
334
|
-
|
|
335
|
-
readout = ops.gather(data, keras.ops.where(sequence == -1, 0, sequence))
|
|
336
|
-
readout = keras.ops.where(sequence[..., None] == -1, 0.0, readout)
|
|
337
|
-
return readout
|
|
338
|
-
|
|
339
|
-
def compute_mask(
|
|
340
|
-
self,
|
|
341
|
-
inputs: tensors.GraphTensor,
|
|
342
|
-
mask: bool | None = None
|
|
343
|
-
) -> tf.Tensor | None:
|
|
344
|
-
# if self.mask_value is None:
|
|
345
|
-
# return None
|
|
346
|
-
_, sequence = inputs
|
|
347
|
-
sequence = self._splitter(sequence)
|
|
348
|
-
return keras.ops.not_equal(sequence, '')
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
residues = {
|
|
352
|
-
"A": "N[C@@H](C)C(=O)O",
|
|
353
|
-
"C": "N[C@@H](CS)C(=O)O",
|
|
354
|
-
"C[Carbamidomethyl]": "N[C@@H](CSCC(=O)N)C(=O)O",
|
|
355
|
-
"D": "N[C@@H](CC(=O)O)C(=O)O",
|
|
356
|
-
"E": "N[C@@H](CCC(=O)O)C(=O)O",
|
|
357
|
-
"F": "N[C@@H](Cc1ccccc1)C(=O)O",
|
|
358
|
-
"G": "NCC(=O)O",
|
|
359
|
-
"H": "N[C@@H](CC1=CN=C-N1)C(=O)O",
|
|
360
|
-
"I": "N[C@@H](C(CC)C)C(=O)O",
|
|
361
|
-
"K": "N[C@@H](CCCCN)C(=O)O",
|
|
362
|
-
"K[Acetyl]": "N[C@@H](CCCCNC(=O)C)C(=O)O",
|
|
363
|
-
"K[Crotonyl]": "N[C@@H](CCCCNC(C=CC)=O)C(=O)O",
|
|
364
|
-
"K[Dimethyl]": "N[C@@H](CCCCN(C)C)C(=O)O",
|
|
365
|
-
"K[Formyl]": "N[C@@H](CCCCNC=O)C(=O)O",
|
|
366
|
-
"K[Malonyl]": "N[C@@H](CCCCNC(=O)CC(O)=O)C(=O)O",
|
|
367
|
-
"K[Methyl]": "N[C@@H](CCCCNC)C(=O)O",
|
|
368
|
-
"K[Propionyl]": "N[C@@H](CCCCNC(=O)CC)C(=O)O",
|
|
369
|
-
"K[Succinyl]": "N[C@@H](CCCCNC(CCC(O)=O)=O)C(=O)O",
|
|
370
|
-
"K[Trimethyl]": "N[C@@H](CCCC[N+](C)(C)C)C(=O)O",
|
|
371
|
-
"L": "N[C@@H](CC(C)C)C(=O)O",
|
|
372
|
-
"M": "N[C@@H](CCSC)C(=O)O",
|
|
373
|
-
"M[Oxidation]": "N[C@@H](CCS(=O)C)C(=O)O",
|
|
374
|
-
"N": "N[C@@H](CC(=O)N)C(=O)O",
|
|
375
|
-
"P": "N1[C@@H](CCC1)C(=O)O",
|
|
376
|
-
"P[Oxidation]": "N1CC(O)C[C@H]1C(=O)O",
|
|
377
|
-
"Q": "N[C@@H](CCC(=O)N)C(=O)O",
|
|
378
|
-
"R": "N[C@@H](CCCNC(=N)N)C(=O)O",
|
|
379
|
-
"R[Deamidated]": "N[C@@H](CCCNC(N)=O)C(=O)O",
|
|
380
|
-
"R[Dimethyl]": "N[C@@H](CCCNC(N(C)C)=N)C(=O)O",
|
|
381
|
-
"R[Methyl]": "N[C@@H](CCCNC(=N)NC)C(=O)O",
|
|
382
|
-
"S": "N[C@@H](CO)C(=O)O",
|
|
383
|
-
"T": "N[C@@H](C(O)C)C(=O)O",
|
|
384
|
-
"V": "N[C@@H](C(C)C)C(=O)O",
|
|
385
|
-
"W": "N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O",
|
|
386
|
-
"Y": "N[C@@H](Cc1ccc(O)cc1)C(=O)O",
|
|
387
|
-
"Y[Nitro]": "N[C@@H](Cc1ccc(O)c(N(=O)=O)c1)C(=O)O",
|
|
388
|
-
"Y[Phospho]": "N[C@@H](Cc1ccc(OP(O)(=O)O)cc1)C(=O)O",
|
|
389
|
-
"[Acetyl]-A": "N(C(C)=O)[C@@H](C)C(=O)O",
|
|
390
|
-
"[Acetyl]-C": "N(C(C)=O)[C@@H](CS)C(=O)O",
|
|
391
|
-
"[Acetyl]-D": "N(C(=O)C)[C@H](C(=O)O)CC(=O)O",
|
|
392
|
-
"[Acetyl]-E": "N(C(=O)C)[C@@H](CCC(O)=O)C(=O)O",
|
|
393
|
-
"[Acetyl]-F": "N(C(C)=O)[C@@H](Cc1ccccc1)C(=O)O",
|
|
394
|
-
"[Acetyl]-G": "N(C(=O)C)CC(=O)O",
|
|
395
|
-
"[Acetyl]-H": "N(C(=O)C)[C@@H](Cc1[nH]cnc1)C(=O)O",
|
|
396
|
-
"[Acetyl]-I": "N(C(=O)C)[C@@H]([C@H](CC)C)C(=O)O",
|
|
397
|
-
"[Acetyl]-K": "N(C(C)=O)[C@@H](CCCCN)C(=O)O",
|
|
398
|
-
"[Acetyl]-L": "N(C(=O)C)[C@@H](CC(C)C)C(=O)O",
|
|
399
|
-
"[Acetyl]-M": "N(C(=O)C)[C@@H](CCSC)C(=O)O",
|
|
400
|
-
"[Acetyl]-N": "N(C(C)=O)[C@@H](CC(=O)N)C(=O)O",
|
|
401
|
-
"[Acetyl]-P": "N1(C(=O)C)CCC[C@H]1C(=O)O",
|
|
402
|
-
"[Acetyl]-Q": "N(C(=O)C)[C@@H](CCC(=O)N)C(=O)O",
|
|
403
|
-
"[Acetyl]-R": "N(C(C)=O)[C@@H](CCCN=C(N)N)C(=O)O",
|
|
404
|
-
"[Acetyl]-S": "N(C(C)=O)[C@@H](CO)C(=O)O",
|
|
405
|
-
"[Acetyl]-T": "N(C(=O)C)[C@@H]([C@H](O)C)C(=O)O",
|
|
406
|
-
"[Acetyl]-V": "N(C(=O)C)[C@@H](C(C)C)C(=O)O",
|
|
407
|
-
"[Acetyl]-W": "N(C(C)=O)[C@@H](Cc1c2ccccc2[nH]c1)C(=O)O",
|
|
408
|
-
"[Acetyl]-Y": "N(C(C)=O)[C@@H](Cc1ccc(O)cc1)C(=O)O"
|
|
409
|
-
}
|
|
410
|
-
|
|
411
|
-
residues_reverse = {}
|
|
412
|
-
def register_peptide_residues(residues_: dict[str, str], canonicalize=True):
    """Add residue definitions to the module-level lookup tables.

    Every ``tag -> SMILES`` pair is written to ``residues`` and the inverse
    ``SMILES -> tag`` pair is written to ``residues_reverse``. Existing
    entries with the same key are overwritten; if two tags share a SMILES
    string, the one registered last wins in ``residues_reverse``.

    Args:
        residues_: Mapping from residue tag to its SMILES string.
        canonicalize: If true, each SMILES string is round-tripped through
            ``Chem.MolFromSmiles`` / ``Chem.MolToSmiles`` before it is stored.

    Raises:
        ValueError: If ``canonicalize`` is true and a SMILES string cannot
            be parsed. Nothing is registered in that case.
    """
    # Resolve everything first so that a bad entry cannot leave the tables
    # half-updated, and so that passing the module-level `residues` dict
    # itself never mutates a dict while it is being iterated.
    prepared = {}
    for residue, smiles in residues_.items():
        if canonicalize:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # MolFromSmiles signals a parse failure by returning None.
                raise ValueError(
                    f"Could not parse SMILES {smiles!r} "
                    f"for residue {residue!r}."
                )
            smiles = Chem.MolToSmiles(mol)
        prepared[residue] = smiles

    for residue, smiles in prepared.items():
        residues[residue] = smiles
        residues_reverse[smiles] = residue
|
|
418
|
-
|
|
419
|
-
# Register the built-in table as written (the Chem canonicalization step is
# skipped); this gives every built-in residue an entry in residues_reverse.
register_peptide_residues(residues, canonicalize=False)
|
|
420
|
-
|
|
421
|
-
def _extract_residue_type(residue_tag: str) -> str:
|
|
422
|
-
pattern = r"(?<!\[)[A-Z](?![^\[]*\])"
|
|
423
|
-
return [match.group(0) for match in re.finditer(pattern, residue_tag)][0]
|
|
424
|
-
|
|
425
|
-
# Derive a '<tag>-[X]' variant of every residue registered so far. Its SMILES
# is the residue's SMILES with one trailing 'O' removed.
# NOTE(review): presumably this drops the C-terminal hydroxyl oxygen so the
# residue ends in an open acyl group -- confirm against the peptide builder.
special_residues = {}
for key, value in residues.items():
    # removesuffix() removes exactly one trailing 'O'; rstrip('O') would keep
    # stripping for as long as the string ends in 'O'.
    special_residues[key + '-[X]'] = value.removesuffix('O')

# The derived SMILES are stored as written (no canonicalization).
register_peptide_residues(special_residues, canonicalize=False)
|
|
@@ -1,47 +0,0 @@
|
|
|
1
|
-
import molcraft
|
|
2
|
-
import keras
|
|
3
|
-
|
|
4
|
-
@keras.saving.register_keras_serializable(package='molcraft')
class AuxiliaryFeatureInjection(molcraft.layers.GraphLayer):
    """Inject a context-level feature into the node features of a graph.

    The context entry named by ``field`` is passed through ``depth`` dense
    layers and the result is combined with the node features via
    ``molcraft.ops.scatter_add``.

    Args:
        field: Key of the context entry to read.
        depth: Number of dense layers applied, in sequence, to the context
            value.
        drop: If true, the context entry is updated to ``None`` on the
            tensor before the node features are updated.
            NOTE(review): presumably this removes the field from the
            returned tensor -- confirm against ``GraphTensor.update``.
        activation: Activation identifier resolved with
            ``keras.activations.get`` and used for every dense layer.
        **kwargs: Forwarded to the base layer constructor.
    """

    def __init__(
        self,
        field: str = 'auxiliary_feature',
        depth: int = 2,
        drop: bool = True,
        activation: str | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.field = field
        self.depth = depth
        self.drop = drop
        self.activation = keras.activations.get(activation)

    def build(self, spec: molcraft.tensors.GraphTensor.Spec) -> None:
        """Create the dense layers, sized from the node feature spec."""
        # Width of every dense layer is dimension 1 of the node feature spec,
        # so the transformed context value can be added to the node features.
        units = spec.node['feature'].shape[1]
        for i in range(self.depth):
            # Layers are stored as individual attributes named 'dense_<i>';
            # propagate() looks them up by the same names.
            setattr(
                self, f'dense_{i}', self.get_dense(units, activation=self.activation)
            )

    def propagate(
        self, tensor: molcraft.tensors.GraphTensor
    ) -> molcraft.tensors.GraphTensor:
        """Return ``tensor`` with the transformed context value added to its
        node features."""
        # Read the context value before it is (optionally) dropped below.
        x = tensor.context[self.field]
        if self.drop:
            tensor = tensor.update({'context': {self.field: None}})
        for i in range(self.depth):
            x = getattr(self, f'dense_{i}')(x)
        # NOTE(review): presumably node['super'] selects the nodes that
        # receive the update -- confirm against molcraft.ops.scatter_add.
        node_feature = molcraft.ops.scatter_add(
            tensor.node['feature'], tensor.node['super'], x
        )
        return tensor.update({'node': {'feature': node_feature}})

    def get_config(self) -> dict:
        """Return the layer configuration, including the constructor
        arguments defined by this class."""
        config = super().get_config()
        config.update({
            'field': self.field,
            'depth': self.depth,
            'drop': self.drop,
            'activation': keras.activations.serialize(self.activation)
        })
        return config
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|