molcraft 0.1.0a15__py3-none-any.whl → 0.1.0a16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic. Click here for more details.

molcraft/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = '0.1.0a15'
1
+ __version__ = '0.1.0a16'
2
2
 
3
3
  import os
4
4
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -0,0 +1,239 @@
1
+ import re
2
+ import keras
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ import tensorflow_text as tf_text
6
+ import json
7
+
8
+ from molcraft import featurizers
9
+ from molcraft import tensors
10
+ from molcraft import layers
11
+ from molcraft import models
12
+ from molcraft import chem
13
+
14
+
15
# TODO: Add regex pattern for residue (C-term mod + N-term mod)?
# TODO: Add regex pattern for residue (C-term mod + N-term mod + mod)?

# Building blocks of a residue token: a bracketed modification name such as
# "[Acetyl]" and a single upper-case residue letter.
_MOD = r'\[[A-Za-z0-9]+\]'
_AA = r'[A-Z]'

# Alternatives are tried left to right, so the most specific forms (those
# carrying the most modifications) must come before the bare residue letter.
residue_pattern: str = "|".join(
    f"({alternative})"
    for alternative in (
        rf'{_MOD}-{_AA}{_MOD}',   # residue (N-term mod + mod)
        rf'{_AA}{_MOD}-{_MOD}',   # residue (C-term mod + mod)
        rf'{_AA}-{_MOD}',         # residue (C-term mod)
        rf'{_MOD}-{_AA}',         # residue (N-term mod)
        rf'{_AA}{_MOD}',          # residue (mod)
        _AA,                      # residue (no mod)
    )
)
25
+
26
# Built-in residue table: single-letter residue code -> SMILES of the free
# amino acid. Every SMILES is written amine-first and ends with the carboxyl
# group 'C(=O)O'; `register_residues` relies on that trailing 'O' when it
# derives the non-terminal ('*'-suffixed) variant of each residue.
default_residues: dict[str, str] = {
    "A": "N[C@@H](C)C(=O)O",
    "C": "N[C@@H](CS)C(=O)O",
    "D": "N[C@@H](CC(=O)O)C(=O)O",
    "E": "N[C@@H](CCC(=O)O)C(=O)O",
    "F": "N[C@@H](Cc1ccccc1)C(=O)O",
    "G": "NCC(=O)O",
    "H": "N[C@@H](CC1=CN=C-N1)C(=O)O",
    "I": "N[C@@H](C(CC)C)C(=O)O",
    "K": "N[C@@H](CCCCN)C(=O)O",
    "L": "N[C@@H](CC(C)C)C(=O)O",
    "M": "N[C@@H](CCSC)C(=O)O",
    "N": "N[C@@H](CC(=O)N)C(=O)O",
    "P": "N1[C@@H](CCC1)C(=O)O",
    "Q": "N[C@@H](CCC(=O)N)C(=O)O",
    "R": "N[C@@H](CCCNC(=N)N)C(=O)O",
    "S": "N[C@@H](CO)C(=O)O",
    "T": "N[C@@H](C(O)C)C(=O)O",
    "V": "N[C@@H](C(C)C)C(=O)O",
    "W": "N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O",
    "Y": "N[C@@H](Cc1ccc(O)cc1)C(=O)O",
}
48
+
49
+
50
class Peptide(chem.Mol):

    """Molecule built from a peptide sequence string."""

    @classmethod
    def from_sequence(cls, sequence: str, **kwargs) -> 'Peptide':
        """Assemble a peptide from its residue sequence.

        The sequence is tokenized with ``residue_pattern``; characters that
        match no alternative are skipped by the tokenizer. Every residue
        except the last is looked up under its ``'*'``-suffixed key in
        ``registered_residues``, the last one under its plain key, and the
        SMILES fragments are concatenated in order.

        Args:
            sequence: Peptide sequence, e.g. ``"[Acetyl]-AK[Methyl]C"``.
            **kwargs: Forwarded unchanged to ``from_encoding``.

        Returns:
            Whatever ``from_encoding`` of the parent class returns for the
            concatenated SMILES (presumably a ``Peptide`` -- confirm against
            ``chem.Mol``).

        Raises:
            KeyError: If a tokenized residue has not been registered.
        """
        tokens = [found.group(0) for found in re.finditer(residue_pattern, sequence)]
        last_position = len(tokens) - 1
        peptide_smiles = ''.join(
            registered_residues[token if position == last_position else token + '*']
            for position, token in enumerate(tokens)
        )
        return super().from_encoding(peptide_smiles, **kwargs)
66
+
67
+
68
@keras.saving.register_keras_serializable(package='proteomics')
class ResidueEmbedding(keras.layers.Layer):

    """Layer mapping peptide sequence strings to per-residue embeddings.

    Every residue SMILES held by the layer is featurized into a graph, the
    stacked graphs are passed through ``embedder``, and the resulting table of
    embeddings is gathered according to the residues found in each input
    sequence.
    """

    def __init__(
        self,
        featurizer: featurizers.MolGraphFeaturizer,
        embedder: models.GraphModel,
        **kwargs
    ) -> None:
        # '_residues' is a private keyword: get_config() writes it so that a
        # reloaded layer keeps the residue table it was created with.
        residues = kwargs.pop('_residues', None)
        super().__init__(**kwargs)
        if residues is None:
            # Snapshot of the module-level registry at construction time;
            # residues registered afterwards are not seen by this layer.
            residues = registered_residues.copy()
        self._residues = residues
        self.embedder = embedder
        self.featurizer = featurizer
        # Ragged splitter feeds call(); the padded one feeds compute_mask().
        self.ragged_split = SequenceSplitter(pad=False)
        self.split = SequenceSplitter(pad=True)
        self.supports_masking = True

    def build(self, input_shape) -> None:
        """Create the residue lookup table, residue graphs and embedding cache."""
        embedding_dim = self.embedder.output.shape[-1]
        # Sorted keys fix the row order shared by the lookup table, the
        # stacked graphs and the cached embedding matrix.
        residues = sorted(self._residues.keys())
        smiles = [self._residues[residue] for residue in residues]
        num_residues = len(residues)
        # Unknown residues fall back to the row of "G".
        # NOTE(review): raises IndexError when "G" is not among the residues.
        self.oov_index = np.where(np.array(residues) == "G")[0][0]
        self.mapping = tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(
                keys=residues,
                values=range(num_residues)
            ),
            default_value=-1,
        )
        self.graph = tf.stack([self.featurizer(s) for s in smiles], axis=0)
        # Zero-initialised; only embeddings() writes real values into it.
        self.cached_embeddings = tf.Variable(
            initial_value=tf.zeros((num_residues, embedding_dim))
        )
        self.use_cached_embeddings = tf.Variable(False)
        super().build(input_shape)

    def call(self, sequences, training=None) -> tf.Tensor:
        """Return the padded per-residue embeddings of ``sequences``.

        The cached embedding table is used only when ``training`` is exactly
        ``False``; any other value (including ``None``) recomputes the table
        through ``embeddings()``.
        """
        # NOTE(review): if the first call happens with training=False the
        # cache still holds its initial zeros -- confirm that the cache is
        # populated (or restored on load) before inference.
        if training is False:
            self.use_cached_embeddings.assign(True)
        else:
            self.use_cached_embeddings.assign(False)
        embeddings = tf.cond(
            pred=self.use_cached_embeddings,
            true_fn=lambda: self.cached_embeddings,
            false_fn=lambda: self.embeddings(),
        )
        sequences = self.ragged_split(sequences)
        # Every residue but the last gets the '*' suffix, matching the keys of
        # the non-terminal variants written by register_residues().
        sequences = keras.ops.concatenate([
            tf.strings.join([sequences[:, :-1], '*']), sequences[:, -1:]
        ], axis=1)
        indices = self.mapping.lookup(sequences)
        # Lookup misses (-1) are redirected to the out-of-vocabulary row.
        indices = keras.ops.where(indices == -1, self.oov_index, indices)
        return tf.gather(embeddings, indices).to_tensor()

    def embeddings(self) -> tf.Tensor:
        """Embed all residue graphs and store the result in the cache."""
        embeddings = self.embedder(self.graph)
        self.cached_embeddings.assign(embeddings)
        return embeddings

    def compute_mask(
        self,
        inputs: tensors.GraphTensor,
        mask: bool | None = None
    ) -> tf.Tensor | None:
        """Return a mask that is True where the padded split is non-empty."""
        # NOTE(review): `inputs` is passed straight to the sequence splitter,
        # so it looks like a string tensor rather than a GraphTensor -- confirm.
        sequences = self.split(inputs)
        return keras.ops.not_equal(sequences, '')

    def get_config(self) -> dict:
        """Return the layer config including residues, featurizer and embedder."""
        config = super().get_config()
        config.update({
            '_residues': self._residues,
            'featurizer': keras.saving.serialize_keras_object(self.featurizer),
            'embedder': keras.saving.serialize_keras_object(self.embedder)
        })
        return config

    @classmethod
    def from_config(cls, config: dict) -> 'ResidueEmbedding':
        """Rebuild the layer from ``config``.

        The serialized featurizer and embedder entries are replaced in place,
        i.e. the ``config`` dict passed in is modified.
        """
        config['featurizer'] = keras.saving.deserialize_keras_object(config['featurizer'])
        config['embedder'] = keras.saving.deserialize_keras_object(config['embedder'])
        return super().from_config(config)
153
+
154
+
155
@keras.saving.register_keras_serializable(package='proteomics')
class SequenceSplitter(keras.layers.Layer):

    """Splits peptide sequence strings into residue tokens.

    Args:
        pad: If True, ``call`` converts the ragged result of the split to a
            dense tensor via ``to_tensor()``; otherwise the ragged result is
            returned as is.
    """

    def __init__(self, pad: bool, **kwargs):
        super().__init__(**kwargs)
        self.pad = pad

    def call(self, inputs):
        """Tokenize ``inputs`` with ``residue_pattern``."""
        # The pattern is passed both as delimiter and as "keep" pattern, so
        # the matched residues themselves are the returned tokens.
        residues = tf_text.regex_split(inputs, residue_pattern, residue_pattern)
        return residues.to_tensor() if self.pad else residues
167
+
168
+
169
def interpret(model: keras.models.Model, sequence: list[str]) -> tensors.GraphTensor:
    """Interpret a peptide model on the residue graphs of ``sequence``.

    The ``ResidueEmbedding`` layer of ``model`` is used to turn every peptide
    into the graphs of its residues. A temporary model is then assembled that
    takes those graphs as input (embedder layers, a reshape into padded
    residue sequences, and the remainder of ``model``) and is handed to
    ``models.interpret``.

    Args:
        model: Model containing a built ``ResidueEmbedding`` layer.
        sequence: Peptide sequences (or an equivalent string tensor).

    Returns:
        The graph returned by ``models.interpret``, with ``context['size']``
        summed per peptide and ``context['sequence']`` set to the input
        sequences, so that indexing the result yields one peptide at a time.

    Raises:
        ValueError: If ``model`` has no ``ResidueEmbedding`` layer.
    """
    if not tf.is_tensor(sequence):
        sequence = keras.ops.convert_to_tensor(sequence)

    # Find embedding layer. Fail loudly when there is none, instead of
    # silently continuing with whatever layer happened to be last.
    embedding_layer = next(
        (
            candidate for candidate in model.layers
            if isinstance(candidate, ResidueEmbedding)
        ),
        None,
    )
    if embedding_layer is None:
        raise ValueError(
            'Could not find a `ResidueEmbedding` layer in the model.'
        )

    # Use embedding layer to convert the sequence to a graph
    residues = embedding_layer.ragged_split(sequence)
    residues = keras.ops.concatenate([
        tf.strings.join([residues[:, :-1], '*']), residues[:, -1:]
    ], axis=1)
    indices = embedding_layer.mapping.lookup(residues)
    # Mirror ResidueEmbedding.call: residues missing from the lookup table
    # (-1) are mapped to the out-of-vocabulary residue rather than being used
    # as a negative index into the residue graphs.
    oov_index = embedding_layer.oov_index
    indices = tf.ragged.map_flat_values(
        lambda flat: tf.where(flat == -1, tf.cast(oov_index, flat.dtype), flat),
        indices,
    )
    graph = tf.concat([
        embedding_layer.graph[residue_ids] for residue_ids in indices
    ], axis=0)

    # Define layer which reshapes data into sequences of residue embeddings
    num_residues = indices.row_lengths()
    to_sequence = (
        lambda x: tf.RaggedTensor.from_row_lengths(x, num_residues).to_tensor()
    )
    reshape = keras.layers.Lambda(to_sequence)

    # Obtain the embedder part of the original model
    embedder = embedding_layer.embedder
    # Obtain the remaining part of the original model
    predictor = keras.models.Model(embedder.output, model.output)
    # Obtain an 'interpretable model', based on the original model
    inputs = layers.Input(graph.spec)
    x = inputs
    for embedder_layer in embedder.layers:  # Loop over layers to expose them
        x = embedder_layer(x)
    x = reshape(x)
    outputs = predictor(x)
    interpretable_model = models.GraphModel(inputs, outputs)

    # Interpret original model through the 'interpretable model'
    graph = models.interpret(interpretable_model, graph)
    del interpretable_model

    # Update 'size' field with new sizes corresponding to peptides for convenience
    # Allows the user to obtain n:th peptide graph using indexing: nth_peptide = graph[n]
    peptide_indices = range(len(num_residues))
    peptide_indicator = keras.ops.repeat(peptide_indices, num_residues)
    residue_sizes = graph.context['size']
    peptide_sizes = keras.ops.segment_sum(residue_sizes, peptide_indicator)
    return graph.update({'context': {'size': peptide_sizes, 'sequence': sequence}})
220
+
221
+
222
def register_residues(residues: dict[str, str]) -> None:
    """Validate and register residues for sequence-to-SMILES conversion.

    Each residue is stored twice in ``registered_residues``: under its own key
    (used when the residue is last in a peptide) and under ``key + '*'`` with
    the trailing hydroxyl oxygen removed (used when another residue follows).

    The SMILES must be written so that fragments can be concatenated:

    * no N-terminal modification (key does not start with ``'['``): the
      SMILES must start with ``'N[C@@H]'``; proline and glycine, whose
      backbone does not start that way, only need to start with ``'N'``;
    * no C-terminal modification (key contains no ``'-['``): the SMILES must
      end with ``'C(=O)O'``.

    Entries are registered one at a time, so entries preceding an invalid one
    stay registered.

    Args:
        residues: Mapping from residue key (e.g. ``"K[Acetyl]"``) to SMILES.

    Raises:
        AssertionError: If a SMILES is not written in the required
            permutation. Raised explicitly so that the validation is not
            stripped when Python runs with ``-O``.
    """
    # TODO: Add validation for residues carrying both an N- and a C-terminal
    #       modification once the residue pattern supports them.
    for residue, smiles in residues.items():
        has_n_term_mod = residue.startswith('[')
        # Among the supported key forms only C-terminally modified ones
        # contain '-[' ("X-[mod]" and "X[mod]-[mod]").
        has_c_term_mod = '-[' in residue

        if has_n_term_mod:
            n_term_ok = True
        elif residue.startswith(('P', 'G')):
            # Proline (ring nitrogen) and glycine (achiral) do not start with
            # the generic 'N[C@@H]' backbone prefix.
            n_term_ok = smiles.startswith('N')
        else:
            n_term_ok = smiles.startswith('N[C@@H]')
        c_term_ok = has_c_term_mod or smiles.endswith('C(=O)O')

        if not (n_term_ok and c_term_ok):
            raise AssertionError(f'Incorrect SMILES permutation for {residue}.')

        registered_residues[residue] = smiles
        # Remove exactly one trailing 'O' (the carboxyl hydroxyl); leading
        # characters of the SMILES must never be touched.
        registered_residues[residue + '*'] = smiles.removesuffix('O')


# Registry shared by the whole module: residue key -> SMILES, including the
# '*'-suffixed non-terminal variants written by `register_residues`.
registered_residues: dict[str, str] = {}
239
+ register_residues(default_residues)
molcraft/layers.py CHANGED
@@ -1430,6 +1430,56 @@ class EdgeEmbedding(GraphLayer):
1430
1430
  return config
1431
1431
 
1432
1432
 
1433
@keras.saving.register_keras_serializable(package='molcraft')
class AddContext(GraphLayer):

    """Context adding layer.

    Passes a context field through a dense layer (and optional
    normalization) and adds the result to the features of the super nodes.

    Args:
        field: Name of the context field to add. Defaults to ``'feature'``.
        drop: If True, the context field is set to ``None`` in the returned
            tensor.
        normalize: Falsy for no normalization; a value whose string form
            starts with ``'layer'`` (case-insensitive) selects layer
            normalization; any other truthy value selects batch
            normalization.
    """

    def __init__(
        self,
        field: str = 'feature',
        drop: bool = True,
        normalize: bool = False,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.field = field
        self.drop = drop
        self._normalize = normalize

    def build(self, spec: tensors.GraphTensor.Spec) -> None:
        # The context is projected to the width of the node features so that
        # the two can be added.
        self._context_dense = self.get_dense(spec.node['feature'].shape[-1])
        if self._normalize:
            wants_layer_norm = str(self._normalize).lower().startswith('layer')
            norm_cls = (
                keras.layers.LayerNormalization
                if wants_layer_norm
                else keras.layers.BatchNormalization
            )
        else:
            norm_cls = keras.layers.Identity
        self._norm = norm_cls()

    def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
        projected = self._norm(self._context_dense(tensor.context[self.field]))
        # Presumably adds one context row to each node flagged by
        # node['super'] -- confirm against `ops.scatter_add`.
        updated_feature = ops.scatter_add(
            tensor.node['feature'], tensor.node['super'], projected
        )
        update = {'node': {'feature': updated_feature}}
        if self.drop:
            update['context'] = {self.field: None}
        return tensor.update(update)

    def get_config(self) -> dict:
        config = super().get_config()
        config.update({
            'field': self.field,
            'drop': self.drop,
            'normalize': self._normalize,
        })
        return config
1481
+
1482
+
1433
1483
  @keras.saving.register_keras_serializable(package='molcraft')
1434
1484
  class GraphNetwork(GraphLayer):
1435
1485
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: molcraft
3
- Version: 0.1.0a15
3
+ Version: 0.1.0a16
4
4
  Summary: Graph Neural Networks for Molecular Machine Learning
5
5
  Author-email: Alexander Kensert <alexander.kensert@gmail.com>
6
6
  License: MIT License
@@ -35,6 +35,7 @@ Requires-Python: >=3.10
35
35
  Description-Content-Type: text/markdown
36
36
  License-File: LICENSE
37
37
  Requires-Dist: tensorflow>=2.16
38
+ Requires-Dist: tensorflow-text>=2.16
38
39
  Requires-Dist: rdkit>=2023.9.5
39
40
  Requires-Dist: pandas>=1.0.3
40
41
  Requires-Dist: ipython>=8.12.0
@@ -1,4 +1,4 @@
1
- molcraft/__init__.py,sha256=4yc0HLuOki-T3c1zX4_5III8vUIhGS4T8AfmIVvb0bw,464
1
+ molcraft/__init__.py,sha256=uo2ze7WMv3VhP0JcJedXDS9UkeEFydlXgSw4l7Xi8E0,464
2
2
  molcraft/callbacks.py,sha256=x5HnkZhqcFRrW6xdApt_jZ4X08A-0fxcnFKfdmRKa0c,3571
3
3
  molcraft/chem.py,sha256=--4AdZV0TCj_cf5i-TRidNJGSFyab1ksUEMjmDi7zaM,21837
4
4
  molcraft/conformers.py,sha256=K6ZtiSUNDN_fwqGP9JrPcwALLFFvlMlF_XejEJH3Sr4,4205
@@ -6,17 +6,16 @@ molcraft/datasets.py,sha256=QKHi9SUBKvJvdkRFmRQNowhrnu35pQqtujuLatOK8bE,4151
6
6
  molcraft/descriptors.py,sha256=jJpT0XWu3Tx_bxnwk1rENySRkaM8cMDMaDIjG8KKvtg,3097
7
7
  molcraft/features.py,sha256=GwOecLCNUIuGfbIVzsAJH4LikkzWMKj5IT7zSgGTttU,13846
8
8
  molcraft/featurizers.py,sha256=8Jmd2yguYmVRyh5wkn6sRzzEENkJ0TqHSlR8qgC4zNY,27131
9
- molcraft/layers.py,sha256=cUpo9dqqNEnc7rNf-Dze8adFhOkTV5F9IhHOKs13OUI,60134
9
+ molcraft/layers.py,sha256=200Y4QLOXDyHw1bnjoSQ6hZ-zD0vpforv-KQGESAZi8,61733
10
10
  molcraft/losses.py,sha256=qnS2yC5g-O3n_zVea9MR6TNiFraW2yqRgePOisoUP4A,1065
11
11
  molcraft/models.py,sha256=hKYSV8z65ohRKfPyjjzxZeVjipm064BWeUBGZE0tpyU,21882
12
12
  molcraft/ops.py,sha256=bQbdFDt9waxVCzF5-dkTB6vlpj9eoSt8I4Qg7ZGXbsU,6178
13
13
  molcraft/records.py,sha256=MbvYkcCunbAmpy_MWXmQ9WBGi2WvwxFUlwQSPKPvSSk,5534
14
14
  molcraft/tensors.py,sha256=EOUKx496KUZsjA1zA2ABc7tU_TW3Jv7AXDsug_QsLbA,22407
15
- molcraft/apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
- molcraft/apps/peptides.py,sha256=N5wJDGDIDRbmOmxin_dTY-odLqb0avAX9FU22U6x6c0,14576
17
- molcraft/apps/qsrr.py,sha256=HhsJzTUuSSvHcl5fmPrI7VtzAUP711yesQ_pAc9hNhU,1572
18
- molcraft-0.1.0a15.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
19
- molcraft-0.1.0a15.dist-info/METADATA,sha256=hOAPpbo8vhG-9Jr0WzA8NgugAm0d46Um0R3IhtZpeZU,3893
20
- molcraft-0.1.0a15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
21
- molcraft-0.1.0a15.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
22
- molcraft-0.1.0a15.dist-info/RECORD,,
15
+ molcraft/applications/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
+ molcraft/applications/proteomics.py,sha256=usZkoYtmTi1BtoP8SigyBNPjxR-nLH1yEsuAdpjvF2M,9009
17
+ molcraft-0.1.0a16.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
18
+ molcraft-0.1.0a16.dist-info/METADATA,sha256=JL32RxGZY92s39EnhonY4mZbvgGvXjcybtV9nXukn4s,3930
19
+ molcraft-0.1.0a16.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
20
+ molcraft-0.1.0a16.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
21
+ molcraft-0.1.0a16.dist-info/RECORD,,
molcraft/apps/peptides.py DELETED
@@ -1,429 +0,0 @@
1
- import re
2
- import keras
3
- import numpy as np
4
- import tensorflow as tf
5
- import tensorflow_text as tf_text
6
- from rdkit import Chem
7
-
8
- from molcraft import ops
9
- from molcraft import chem
10
- from molcraft import features
11
- from molcraft import featurizers
12
- from molcraft import tensors
13
- from molcraft import descriptors
14
- from molcraft import layers
15
- from molcraft import models
16
-
17
-
18
-
19
- @keras.saving.register_keras_serializable(package='molcraft')
20
- class SequenceSplitter(keras.layers.Layer):
21
-
22
- _pattern = "|".join([
23
- r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])', # N-term mod + mod
24
- r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])', # C-term mod + mod
25
- r'([A-Z]-\[[A-Za-z0-9]+\])', # C-term mod
26
- r'(\[[A-Za-z0-9]+\]-[A-Z])', # N-term mod
27
- r'([A-Z]\[[A-Za-z0-9]+\])', # Mod
28
- r'([A-Z])', # No mod
29
- ])
30
-
31
- def call(self, inputs):
32
- inputs = tf_text.regex_split(inputs, self._pattern, self._pattern)
33
- inputs = keras.ops.concatenate([
34
- tf.strings.join([inputs[:, :-1], '-[X]']),
35
- inputs[:, -1:]
36
- ], axis=1)
37
- return inputs.to_tensor()
38
-
39
- @keras.saving.register_keras_serializable(package='molcraft')
40
- class Gather(keras.layers.Layer):
41
-
42
- def __init__(
43
- self,
44
- padding: list[tuple[int]] | tuple[int] | int = 1,
45
- mask_value: int = 0,
46
- **kwargs
47
- ) -> None:
48
- super().__init__(**kwargs)
49
- self._splitter = SequenceSplitter()
50
- self.padding = padding
51
- self.mask_value = mask_value
52
- self.supports_masking = True
53
-
54
- self._tags = list(sorted(residues.keys()))
55
- self._mapping = tf.lookup.StaticHashTable(
56
- tf.lookup.KeyValueTensorInitializer(
57
- keys=self._tags,
58
- values=range(len(self._tags)),
59
- ),
60
- default_value=-1,
61
- )
62
-
63
- def get_config(self):
64
- config = super().get_config()
65
- config['mask_value'] = self.mask_value
66
- config['padding'] = self.padding
67
- return config
68
-
69
- def call(self, inputs) -> tf.Tensor:
70
- embedding, sequence = inputs
71
- sequence = self._splitter(sequence)
72
- sequence = self._mapping.lookup(sequence)
73
- readout = ops.gather(embedding, keras.ops.where(sequence == -1, 0, sequence))
74
- readout = keras.ops.where(sequence[..., None] == -1, 0.0, readout)
75
- return readout
76
-
77
- def compute_mask(
78
- self,
79
- inputs: tensors.GraphTensor,
80
- mask: bool | None = None
81
- ) -> tf.Tensor | None:
82
- # if self.mask_value is None:
83
- # return None
84
- _, sequence = inputs
85
- sequence = self._splitter(sequence)
86
- return keras.ops.not_equal(sequence, '')
87
-
88
-
89
- @keras.saving.register_keras_serializable(package='molcraft')
90
- class Embedding(keras.layers.Layer):
91
-
92
- def __init__(self, **kwargs):
93
- super().__init__(**kwargs)
94
- tags = list(sorted(residues.keys()))
95
- self.mapping = tf.lookup.StaticHashTable(
96
- tf.lookup.KeyValueTensorInitializer(
97
- keys=tags,
98
- values=range(len(tags)),
99
- ),
100
- default_value=-1,
101
- )
102
- self.splitting = SequenceSplitter()
103
- featurizer = featurizers.MolGraphFeaturizer(super_atom=True)
104
- tensor_list = [featurizer(residues[tag]) for tag in tags]
105
- graph = tf.stack(tensor_list, axis=0)
106
- self._build_on_init(graph)
107
- self.embedder = models.GraphModel.from_layers(
108
- [
109
- layers.Input(graph.spec),
110
- layers.NodeEmbedding(128),
111
- layers.EdgeEmbedding(128),
112
- layers.GraphTransformer(128),
113
- layers.Readout()
114
- ]
115
- )
116
- self.embedding = tf.Variable(
117
- initial_value=tf.zeros((114, 128)), trainable=True
118
- )
119
- self.new_state = tf.Variable(True, dtype=tf.bool, trainable=False)
120
- self.gather = Gather()
121
- self.update_state()
122
-
123
- # Keep AA as is (most simple?), add positional embedding to distingusih N-, C- and non-terminal
124
-
125
- def update_state(self, inputs=None):
126
- graph = self._graph_tensor
127
- graph = tensors.to_dict(graph)
128
- embedding = self.embedder(graph)
129
- self.embedding.assign(embedding)
130
- tf.print("STATE UPDATED")
131
- return embedding
132
-
133
- def call(self, inputs=None, training=None) -> tensors.GraphTensor:
134
- if training:
135
- embedding = self.update_state()
136
- self.new_state.assign(True)
137
- return self.gather([embedding, inputs])
138
- else:
139
- embedding = tf.cond(
140
- pred=self.new_state,
141
- true_fn=lambda: self.update_state(),
142
- false_fn=lambda: self.embedding
143
- )
144
- self.new_state.assign(False)
145
- return self.gather([embedding, inputs])
146
-
147
- def build(self, input_shape):
148
- super().build(input_shape)
149
-
150
- def _build_on_init(self, x):
151
-
152
- if isinstance(x, tensors.GraphTensor):
153
- tensor = tensors.to_dict(x)
154
- self._spec = tf.nest.map_structure(
155
- tf.type_spec_from_value, tensor
156
- )
157
- else:
158
- self._spec = x
159
-
160
- self._graph = tf.nest.map_structure(
161
- lambda s: self.add_weight(
162
- shape=s.shape,
163
- dtype=s.dtype,
164
- trainable=False,
165
- initializer='zeros'
166
- ),
167
- self._spec
168
- )
169
-
170
- if isinstance(x, tensors.GraphTensor):
171
- tf.nest.map_structure(
172
- lambda v, x: v.assign(x),
173
- self._graph, tensor
174
- )
175
-
176
- graph = tf.nest.map_structure(
177
- keras.ops.convert_to_tensor, self._graph
178
- )
179
- self._graph_tensor = tensors.from_dict(graph)
180
-
181
- # def get_config(self) -> dict:
182
- # config = super().get_config()
183
- # spec = keras.saving.serialize_keras_object(self._spec)
184
- # config['spec'] = spec
185
- # #config['layers'] = keras.saving.serialize_keras_object(self.embedding.layers)
186
- # return config
187
-
188
- # @classmethod
189
- # def from_config(cls, config: dict) -> 'SequenceToGraph':
190
- # spec = config.pop('spec')
191
- # spec = keras.saving.deserialize_keras_object(spec)
192
- # # config['layers'] = keras.saving.deserialize_keras_object(config['layers'])
193
- # layer = cls(**config)
194
- # layer._build_on_init(spec)
195
- # return layer
196
-
197
-
198
- @keras.saving.register_keras_serializable(package='molcraft')
199
- class SequenceToGraph(keras.layers.Layer):
200
-
201
- def __init__(
202
- self,
203
- atom_features: list[features.Feature] | str | None = 'auto',
204
- bond_features: list[features.Feature] | str | None = 'auto',
205
- molecule_features: list[descriptors.Descriptor] | str | None = 'auto',
206
- super_atom: bool = True,
207
- radius: int | float | None = None,
208
- self_loops: bool = False,
209
- include_hs: bool = False,
210
- **kwargs,
211
- ):
212
- super().__init__(**kwargs)
213
- self._splitter = SequenceSplitter()
214
- featurizer = featurizers.MolGraphFeaturizer(
215
- atom_features=atom_features,
216
- bond_features=bond_features,
217
- molecule_features=molecule_features,
218
- super_atom=super_atom,
219
- radius=radius,
220
- self_loops=self_loops,
221
- include_hs=include_hs,
222
- **kwargs,
223
- )
224
- tensor_list: list[tensors.GraphTensor] = [
225
- featurizer(residues[tag]).update({'context': {'tag': tag}}) for tag in residues
226
- ]
227
- graph = tf.stack(tensor_list, axis=0)
228
- self._build_on_init(graph)
229
-
230
- def call(self, sequence: tf.Tensor) -> tensors.GraphTensor:
231
- sequence = self._splitter(sequence)
232
- indices = self._tag_to_index.lookup(sequence)
233
- indices = tf.sort(tf.unique(tf.reshape(indices, [-1]))[0])[1:]
234
- graph = self._graph_tensor[indices]
235
- return tensors.to_dict(graph)
236
-
237
- def _build_on_init(self, x):
238
-
239
- if isinstance(x, tensors.GraphTensor):
240
- tensor = tensors.to_dict(x)
241
- self._spec = tf.nest.map_structure(
242
- tf.type_spec_from_value, tensor
243
- )
244
- else:
245
- self._spec = x
246
-
247
- self._graph = tf.nest.map_structure(
248
- lambda s: self.add_weight(
249
- shape=s.shape,
250
- dtype=s.dtype,
251
- trainable=False,
252
- initializer='zeros'
253
- ),
254
- self._spec
255
- )
256
-
257
- if isinstance(x, tensors.GraphTensor):
258
- tf.nest.map_structure(
259
- lambda v, x: v.assign(x),
260
- self._graph, tensor
261
- )
262
-
263
- graph = tf.nest.map_structure(
264
- keras.ops.convert_to_tensor, self._graph
265
- )
266
- self._graph_tensor = tensors.from_dict(graph)
267
-
268
- tags = self._graph_tensor.context['tag']
269
-
270
- self._tag_to_index = tf.lookup.StaticHashTable(
271
- tf.lookup.KeyValueTensorInitializer(
272
- keys=tags,
273
- values=range(len(tags)),
274
- ),
275
- default_value=-1,
276
- )
277
-
278
- def get_config(self) -> dict:
279
- config = super().get_config()
280
- spec = keras.saving.serialize_keras_object(self._spec)
281
- config['spec'] = spec
282
- return config
283
-
284
- @classmethod
285
- def from_config(cls, config: dict) -> 'SequenceToGraph':
286
- spec = config.pop('spec')
287
- spec = keras.saving.deserialize_keras_object(spec)
288
- layer = cls(**config)
289
- layer._build_on_init(spec)
290
- return layer
291
-
292
- # @property
293
- # def graph(self) -> tensors.GraphTensor:
294
- # return self._graph_tensor
295
-
296
-
297
- @keras.saving.register_keras_serializable(package='molcraft')
298
- class GraphToSequence(keras.layers.Layer):
299
-
300
- def __init__(
301
- self,
302
- padding: list[tuple[int]] | tuple[int] | int = 1,
303
- mask_value: int = 0,
304
- **kwargs
305
- ) -> None:
306
- super().__init__(**kwargs)
307
- self._splitter = SequenceSplitter()
308
- self.padding = padding
309
- self.mask_value = mask_value
310
- self._readout_layer = layers.Readout(mode='mean')
311
- self.supports_masking = True
312
-
313
- def get_config(self):
314
- config = super().get_config()
315
- config['mask_value'] = self.mask_value
316
- config['padding'] = self.padding
317
- return config
318
-
319
- def call(self, inputs) -> tf.Tensor:
320
-
321
- graph, sequence = inputs
322
- sequence = self._splitter(sequence)
323
- tag = graph['context']['tag']
324
- data = self._readout_layer(graph)
325
-
326
- table = tf.lookup.experimental.MutableHashTable(
327
- key_dtype=tf.string,
328
- value_dtype=tf.int32,
329
- default_value=-1
330
- )
331
-
332
- table.insert(tag, tf.range(tf.shape(tag)[0]))
333
- sequence = table.lookup(sequence)
334
-
335
- readout = ops.gather(data, keras.ops.where(sequence == -1, 0, sequence))
336
- readout = keras.ops.where(sequence[..., None] == -1, 0.0, readout)
337
- return readout
338
-
339
- def compute_mask(
340
- self,
341
- inputs: tensors.GraphTensor,
342
- mask: bool | None = None
343
- ) -> tf.Tensor | None:
344
- # if self.mask_value is None:
345
- # return None
346
- _, sequence = inputs
347
- sequence = self._splitter(sequence)
348
- return keras.ops.not_equal(sequence, '')
349
-
350
-
351
- residues = {
352
- "A": "N[C@@H](C)C(=O)O",
353
- "C": "N[C@@H](CS)C(=O)O",
354
- "C[Carbamidomethyl]": "N[C@@H](CSCC(=O)N)C(=O)O",
355
- "D": "N[C@@H](CC(=O)O)C(=O)O",
356
- "E": "N[C@@H](CCC(=O)O)C(=O)O",
357
- "F": "N[C@@H](Cc1ccccc1)C(=O)O",
358
- "G": "NCC(=O)O",
359
- "H": "N[C@@H](CC1=CN=C-N1)C(=O)O",
360
- "I": "N[C@@H](C(CC)C)C(=O)O",
361
- "K": "N[C@@H](CCCCN)C(=O)O",
362
- "K[Acetyl]": "N[C@@H](CCCCNC(=O)C)C(=O)O",
363
- "K[Crotonyl]": "N[C@@H](CCCCNC(C=CC)=O)C(=O)O",
364
- "K[Dimethyl]": "N[C@@H](CCCCN(C)C)C(=O)O",
365
- "K[Formyl]": "N[C@@H](CCCCNC=O)C(=O)O",
366
- "K[Malonyl]": "N[C@@H](CCCCNC(=O)CC(O)=O)C(=O)O",
367
- "K[Methyl]": "N[C@@H](CCCCNC)C(=O)O",
368
- "K[Propionyl]": "N[C@@H](CCCCNC(=O)CC)C(=O)O",
369
- "K[Succinyl]": "N[C@@H](CCCCNC(CCC(O)=O)=O)C(=O)O",
370
- "K[Trimethyl]": "N[C@@H](CCCC[N+](C)(C)C)C(=O)O",
371
- "L": "N[C@@H](CC(C)C)C(=O)O",
372
- "M": "N[C@@H](CCSC)C(=O)O",
373
- "M[Oxidation]": "N[C@@H](CCS(=O)C)C(=O)O",
374
- "N": "N[C@@H](CC(=O)N)C(=O)O",
375
- "P": "N1[C@@H](CCC1)C(=O)O",
376
- "P[Oxidation]": "N1CC(O)C[C@H]1C(=O)O",
377
- "Q": "N[C@@H](CCC(=O)N)C(=O)O",
378
- "R": "N[C@@H](CCCNC(=N)N)C(=O)O",
379
- "R[Deamidated]": "N[C@@H](CCCNC(N)=O)C(=O)O",
380
- "R[Dimethyl]": "N[C@@H](CCCNC(N(C)C)=N)C(=O)O",
381
- "R[Methyl]": "N[C@@H](CCCNC(=N)NC)C(=O)O",
382
- "S": "N[C@@H](CO)C(=O)O",
383
- "T": "N[C@@H](C(O)C)C(=O)O",
384
- "V": "N[C@@H](C(C)C)C(=O)O",
385
- "W": "N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O",
386
- "Y": "N[C@@H](Cc1ccc(O)cc1)C(=O)O",
387
- "Y[Nitro]": "N[C@@H](Cc1ccc(O)c(N(=O)=O)c1)C(=O)O",
388
- "Y[Phospho]": "N[C@@H](Cc1ccc(OP(O)(=O)O)cc1)C(=O)O",
389
- "[Acetyl]-A": "N(C(C)=O)[C@@H](C)C(=O)O",
390
- "[Acetyl]-C": "N(C(C)=O)[C@@H](CS)C(=O)O",
391
- "[Acetyl]-D": "N(C(=O)C)[C@H](C(=O)O)CC(=O)O",
392
- "[Acetyl]-E": "N(C(=O)C)[C@@H](CCC(O)=O)C(=O)O",
393
- "[Acetyl]-F": "N(C(C)=O)[C@@H](Cc1ccccc1)C(=O)O",
394
- "[Acetyl]-G": "N(C(=O)C)CC(=O)O",
395
- "[Acetyl]-H": "N(C(=O)C)[C@@H](Cc1[nH]cnc1)C(=O)O",
396
- "[Acetyl]-I": "N(C(=O)C)[C@@H]([C@H](CC)C)C(=O)O",
397
- "[Acetyl]-K": "N(C(C)=O)[C@@H](CCCCN)C(=O)O",
398
- "[Acetyl]-L": "N(C(=O)C)[C@@H](CC(C)C)C(=O)O",
399
- "[Acetyl]-M": "N(C(=O)C)[C@@H](CCSC)C(=O)O",
400
- "[Acetyl]-N": "N(C(C)=O)[C@@H](CC(=O)N)C(=O)O",
401
- "[Acetyl]-P": "N1(C(=O)C)CCC[C@H]1C(=O)O",
402
- "[Acetyl]-Q": "N(C(=O)C)[C@@H](CCC(=O)N)C(=O)O",
403
- "[Acetyl]-R": "N(C(C)=O)[C@@H](CCCN=C(N)N)C(=O)O",
404
- "[Acetyl]-S": "N(C(C)=O)[C@@H](CO)C(=O)O",
405
- "[Acetyl]-T": "N(C(=O)C)[C@@H]([C@H](O)C)C(=O)O",
406
- "[Acetyl]-V": "N(C(=O)C)[C@@H](C(C)C)C(=O)O",
407
- "[Acetyl]-W": "N(C(C)=O)[C@@H](Cc1c2ccccc2[nH]c1)C(=O)O",
408
- "[Acetyl]-Y": "N(C(C)=O)[C@@H](Cc1ccc(O)cc1)C(=O)O"
409
- }
410
-
411
- residues_reverse = {}
412
- def register_peptide_residues(residues_: dict[str, str], canonicalize=True):
413
- for residue, smiles in residues_.items():
414
- if canonicalize:
415
- smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles))
416
- residues[residue] = smiles
417
- residues_reverse[residues[residue]] = residue
418
-
419
- register_peptide_residues(residues, canonicalize=False)
420
-
421
- def _extract_residue_type(residue_tag: str) -> str:
422
- pattern = r"(?<!\[)[A-Z](?![^\[]*\])"
423
- return [match.group(0) for match in re.finditer(pattern, residue_tag)][0]
424
-
425
- special_residues = {}
426
- for key, value in residues.items():
427
- special_residues[key + '-[X]'] = value.rstrip('O')
428
-
429
- register_peptide_residues(special_residues, canonicalize=False)
molcraft/apps/qsrr.py DELETED
@@ -1,47 +0,0 @@
1
- import molcraft
2
- import keras
3
-
4
- @keras.saving.register_keras_serializable(package='molcraft')
5
- class AuxiliaryFeatureInjection(molcraft.layers.GraphLayer):
6
-
7
- def __init__(
8
- self,
9
- field: str = 'auxiliary_feature',
10
- depth: int = 2,
11
- drop: bool = True,
12
- activation: str | None = None,
13
- **kwargs,
14
- ) -> None:
15
- super().__init__(**kwargs)
16
- self.field = field
17
- self.depth = depth
18
- self.drop = drop
19
- self.activation = keras.activations.get(activation)
20
-
21
- def build(self, spec: molcraft.tensors.GraphTensor.Spec) -> None:
22
- units = spec.node['feature'].shape[1]
23
- for i in range(self.depth):
24
- setattr(
25
- self, f'dense_{i}', self.get_dense(units, activation=self.activation)
26
- )
27
-
28
- def propagate(self, tensor: molcraft.tensors.GraphTensor) -> None:
29
- x = tensor.context[self.field]
30
- if self.drop:
31
- tensor = tensor.update({'context': {self.field: None}})
32
- for i in range(self.depth):
33
- x = getattr(self, f'dense_{i}')(x)
34
- node_feature = molcraft.ops.scatter_add(
35
- tensor.node['feature'], tensor.node['super'], x
36
- )
37
- return tensor.update({'node': {'feature': node_feature}})
38
-
39
- def get_config(self) -> dict:
40
- config = super().get_config()
41
- config.update({
42
- 'field': self.field,
43
- 'depth': self.depth,
44
- 'drop': self.drop,
45
- 'activation': keras.activations.serialize(self.activation)
46
- })
47
- return config
File without changes