molcraft 0.1.0a17__tar.gz → 0.1.0a19__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of molcraft might be problematic. See the registry's advisory page for details.
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/PKG-INFO +1 -1
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/__init__.py +4 -2
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/applications/proteomics.py +111 -41
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/featurizers.py +1 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/layers.py +99 -44
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/tensors.py +16 -10
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft.egg-info/PKG-INFO +1 -1
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/LICENSE +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/README.md +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/applications/__init__.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/applications/chromatography.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/callbacks.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/chem.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/datasets.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/descriptors.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/features.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/losses.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/models.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/ops.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/records.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft.egg-info/SOURCES.txt +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft.egg-info/dependency_links.txt +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft.egg-info/requires.txt +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft.egg-info/top_level.txt +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/pyproject.toml +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/setup.cfg +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/tests/test_chem.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/tests/test_featurizers.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/tests/test_layers.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/tests/test_losses.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/tests/test_models.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a19}/tests/test_tensors.py +0 -0
molcraft/__init__.py

```diff
@@ -1,4 +1,4 @@
-__version__ = '0.1.0a17'
+__version__ = '0.1.0a19'
 
 import os
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -14,4 +14,6 @@ from molcraft import records
 from molcraft import tensors
 from molcraft import callbacks
 from molcraft import datasets
-from molcraft import losses
+from molcraft import losses
+
+from molcraft.applications import proteomics
```
molcraft/applications/proteomics.py

```diff
@@ -3,6 +3,7 @@ import keras
 import numpy as np
 import tensorflow as tf
 import tensorflow_text as tf_text
+from rdkit import Chem
 
 from molcraft import featurizers
 from molcraft import tensors
```
```diff
@@ -10,16 +11,47 @@ from molcraft import layers
 from molcraft import models
 from molcraft import chem
 
+"""
+No need to correct smiles for modeling, only for interpretation.
+
+Use added smiles data to rearrange list of saliency values.
+"""
 
 # TODO: Add regex pattern for residue (C-term mod + N-term mod)?
 # TODO: Add regex pattern for residue (C-term mod + N-term mod + mod)?
+
+no_mod_pattern = r'([A-Z])'
+side_chain_mod_pattern = r'([A-Z]\[[A-Za-z0-9]+\])'
+n_term_mod_pattern = r'(\[[A-Za-z0-9]+\]-[A-Z])'
+c_term_mod_pattern = r'([A-Z]-\[[A-Za-z0-9]+\])'
+side_chain_and_n_term_mod_pattern = r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])'
+side_chain_and_c_term_mod_pattern = r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])'
+
 residue_pattern: str = "|".join([
-    … (six removed entries not captured in the source diff)
+    side_chain_and_n_term_mod_pattern,
+    side_chain_and_c_term_mod_pattern,
+    n_term_mod_pattern,
+    c_term_mod_pattern,
+    side_chain_mod_pattern,
+    no_mod_pattern
 ])
 
 default_residues: dict[str, str] = {
```
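The alternation order here is load-bearing: the combined-modification patterns come first so the regex engine prefers the longest possible residue token. A minimal standalone sketch (not package code) of how `residue_pattern` tokenizes a modified peptide string:

```python
import re

# Patterns as introduced in proteomics.py above.
no_mod_pattern = r'([A-Z])'
side_chain_mod_pattern = r'([A-Z]\[[A-Za-z0-9]+\])'
n_term_mod_pattern = r'(\[[A-Za-z0-9]+\]-[A-Z])'
c_term_mod_pattern = r'([A-Z]-\[[A-Za-z0-9]+\])'
side_chain_and_n_term_mod_pattern = r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])'
side_chain_and_c_term_mod_pattern = r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])'

residue_pattern = "|".join([
    side_chain_and_n_term_mod_pattern,
    side_chain_and_c_term_mod_pattern,
    n_term_mod_pattern,
    c_term_mod_pattern,
    side_chain_mod_pattern,
    no_mod_pattern,
])

# '[Acetyl]-A' carries an N-terminal mod, 'C[Carbamidomethyl]' a side-chain
# mod, and D/E/K are unmodified residues.
sequence = '[Acetyl]-AC[Carbamidomethyl]DEK'
tokens = [m.group(0) for m in re.finditer(residue_pattern, sequence)]
print(tokens)  # ['[Acetyl]-A', 'C[Carbamidomethyl]', 'D', 'E', 'K']
```

Listing `no_mod_pattern` first instead would shred `'[Acetyl]-A'` into bare letters, since `re` alternation takes the first matching branch rather than the longest one.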
```diff
@@ -45,20 +77,26 @@ default_residues: dict[str, str] = {
     "Y": "N[C@@H](Cc1ccc(O)cc1)C(=O)O",
 }
 
-def …
-    … (removed function body not captured in the source diff)
+def has_c_terminal_mod(residue: str):
+    if re.search(c_term_mod_pattern, residue):
+        return True
+    return False
+
+def has_n_terminal_mod(residue: str):
+    if re.search(n_term_mod_pattern, residue):
+        return True
+    return False
+
+# def register_residues(residues: dict[str, str]) -> None:
+#     for residue, smiles in residues.items():
+#         if residue.startswith('P'):
+#             smiles.startswith('N'), f'Incorrect SMILES permutation for {residue}.'
+#         elif not residue.startswith('['):
+#             smiles.startswith('N[C@@H]'), f'Incorrect SMILES permutation for {residue}.'
+#         if len(residue) > 1 and not residue[1] == "-":
+#             assert smiles.endswith('C(=O)O'), f'Incorrect SMILES permutation for {residue}.'
+#         registered_residues[residue] = smiles
+#         registered_residues[residue + '*'] = smiles.strip('O')
 
 
 class Peptide(chem.Mol):
```
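For illustration, a standalone rephrasing of the two new helpers (returning the boolean directly, equivalent to the if/return form above) applied to residue tokens:

```python
import re

n_term_mod_pattern = r'(\[[A-Za-z0-9]+\]-[A-Z])'
c_term_mod_pattern = r'([A-Z]-\[[A-Za-z0-9]+\])'

def has_n_terminal_mod(residue: str) -> bool:
    # True when the token starts with a bracketed N-terminal modification.
    return bool(re.search(n_term_mod_pattern, residue))

def has_c_terminal_mod(residue: str) -> bool:
    # True when the token ends with a bracketed C-terminal modification.
    return bool(re.search(c_term_mod_pattern, residue))

print(has_n_terminal_mod('[Acetyl]-A'))          # True
print(has_c_terminal_mod('K-[Amide]'))           # True
print(has_c_terminal_mod('C[Carbamidomethyl]'))  # False: side-chain mod only
```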
```diff
@@ -79,6 +117,22 @@ class Peptide(chem.Mol):
         return super().from_encoding(peptide_smiles, **kwargs)
 
 
+def permute_residue_smiles(smiles: str) -> str:
+    glycine = chem.Mol.from_encoding("NCC(=O)O")
+    mol = chem.Mol.from_encoding(smiles)
+    nitrogen_index = mol.GetSubstructMatch(glycine)[0]
+    permuted_smiles = Chem.MolToSmiles(
+        mol, rootedAtAtom=nitrogen_index
+    )
+    return permuted_smiles
+
+def check_peptide_residue_smiles(smiles: list[str]) -> bool:
+    backbone = 'NCC(=O)' * (len(smiles) - 1) + 'NC'
+    backbone = chem.Mol.from_encoding(backbone)
+    mol = chem.Mol.from_encoding(''.join(smiles))
+    is_valid = mol.HasSubstructMatch(backbone)
+    return is_valid
+
 @keras.saving.register_keras_serializable(package='proteomics')
 class ResidueEmbedding(keras.layers.Layer):
 
```
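Both new helpers are thin wrappers over RDKit substructure matching. A hedged plain-RDKit sketch of the same two operations, assuming `chem.Mol.from_encoding` parses SMILES much like `Chem.MolFromSmiles`:

```python
from rdkit import Chem

# permute_residue_smiles: rewrite a residue SMILES so it starts at the
# backbone nitrogen, located by matching a glycine substructure.
alanine = Chem.MolFromSmiles('CC(N)C(=O)O')
glycine_query = Chem.MolFromSmiles('NCC(=O)O')
nitrogen_index = alanine.GetSubstructMatch(glycine_query)[0]
print(Chem.MolToSmiles(alanine, rootedAtAtom=nitrogen_index))  # e.g. 'NC(C)C(=O)O'

# check_peptide_residue_smiles: the concatenated residue SMILES must contain
# the expected backbone ('NCC(=O)' per peptide bond, 'NC' for the last residue).
residues = ['N[C@@H](C)C(=O)', 'N[C@@H](CO)C(=O)', 'N[C@@H](C)C(=O)O']  # Ala-Ser-Ala
backbone = Chem.MolFromSmiles('NCC(=O)' * (len(residues) - 1) + 'NC')
peptide = Chem.MolFromSmiles(''.join(residues))
print(peptide.HasSubstructMatch(backbone))  # True
```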
```diff
@@ -92,42 +146,50 @@ class ResidueEmbedding(keras.layers.Layer):
         super().__init__(**kwargs)
         if residues is None:
             residues = {}
-        self._residue_dict = {**default_residues, **residues}
         self.embedder = embedder
         self.featurizer = featurizer
-        self.embedding_dim = self.embedder.output.shape[-1]
+        self.embedding_dim = int(self.embedder.output.shape[-1])
         self.ragged_split = SequenceSplitter(pad=False)
         self.split = SequenceSplitter(pad=True)
         self.use_cached_embeddings = tf.Variable(False)
+        self.residues = residues
         self.supports_masking = True
 
     @property
     def residues(self) -> dict[str, str]:
-        return self.…
+        return self._residues
 
     @residues.setter
     def residues(self, residues: dict[str, str]) -> None:
-        …
-        …
-        …
-        …
+
+        residues = {**default_residues, **residues}
+        self._residues = {}
+        for residue, smiles in residues.items():
+            self._residues[residue] = smiles
+            self._residues[residue + '*'] = smiles.rstrip('O')
+
+        residue_keys = sorted(self._residues.keys())
+        residue_values = range(len(residue_keys))
+        residue_oov_value = np.where(np.array(residue_keys) == "G")[0][0]
+
         self.mapping = tf.lookup.StaticHashTable(
             tf.lookup.KeyValueTensorInitializer(
                 keys=residue_keys,
-                values=…
+                values=residue_values
             ),
-            default_value=…
+            default_value=residue_oov_value,
         )
+
         self.graph = tf.stack([
-            self.featurizer(…
+            self.featurizer(self._residues[r]) for r in residue_keys
         ], axis=0)
-        …
-        …
-        )
+
+        zeros = tf.zeros((residue_values[-1] + 1, self.embedding_dim))
+        self.cached_embeddings = tf.Variable(initial_value=zeros)
         _ = self.cache_and_get_embeddings()
 
     def build(self, input_shape) -> None:
-        self.residues = self.…
+        self.residues = self._residues
         super().build(input_shape)
 
     def call(self, sequences: tf.Tensor, training: bool = None) -> tf.Tensor:
```
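Unknown residue tokens now fall back to glycine: the lookup table's `default_value` is the index of `"G"` among the sorted residue keys. A minimal standalone sketch of that out-of-vocabulary behavior:

```python
import tensorflow as tf

residue_keys = ['A', 'C', 'G', 'K']        # sorted keys, as in the setter above
oov_value = residue_keys.index('G')        # unknown tokens map to glycine
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(residue_keys),
        values=tf.range(len(residue_keys), dtype=tf.int64),
    ),
    default_value=oov_value,
)
print(table.lookup(tf.constant(['K', 'Z'])).numpy())  # [3 2]: 'Z' falls back to 'G'
```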
```diff
@@ -163,16 +225,24 @@ class ResidueEmbedding(keras.layers.Layer):
     def get_config(self) -> dict:
         config = super().get_config()
         config.update({
-            'featurizer': keras.saving.serialize_keras_object(
-                …
-            …
+            'featurizer': keras.saving.serialize_keras_object(
+                self.featurizer
+            ),
+            'embedder': keras.saving.serialize_keras_object(
+                self.embedder
+            ),
+            'residues': self._residues,
         })
         return config
 
     @classmethod
     def from_config(cls, config: dict) -> 'ResidueEmbedding':
-        config['featurizer'] = keras.saving.deserialize_keras_object(
-            …
+        config['featurizer'] = keras.saving.deserialize_keras_object(
+            config['featurizer']
+        )
+        config['embedder'] = keras.saving.deserialize_keras_object(
+            config['embedder']
+        )
         return super().from_config(config)
 
 
```
```diff
@@ -190,5 +260,5 @@ class SequenceSplitter(keras.layers.Layer):
         return inputs
 
 
-registered_residues: dict[str, str] = {}
-register_residues(default_residues)
+# registered_residues: dict[str, str] = {}
+# register_residues(default_residues)
```
molcraft/featurizers.py

```diff
@@ -413,6 +413,7 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
 
     def get_config(self):
         config = super().get_config()
+        config.pop('bond_features', None)
         config['radius'] = self._radius
         config['pair_features'] = keras.saving.serialize_keras_object(
             self._pair_features
```
molcraft/layers.py

```diff
@@ -279,7 +279,7 @@ class GraphConv(GraphLayer):
         use_bias (bool):
             Whether bias should be used in the dense layers. Default to `True`.
         normalize (bool, str):
-            Whether normalization should be …
+            Whether a normalization layer should be obtained by `get_norm()`. Default to `False`.
         skip_connect (bool):
             Whether node feature input should be added to the node feature output. Default to `True`.
         kernel_initializer (keras.initializers.Initializer, str):
@@ -366,6 +366,7 @@ class GraphConv(GraphLayer):
         has_overridden_message = self.__class__.message != GraphConv.message
         if not has_overridden_message:
             self._message_intermediate_dense = self.get_dense(self.units)
+            self._message_norm = self.get_norm()
             self._message_intermediate_activation = self.activation
             self._message_final_dense = self.get_dense(self.units)
 
@@ -376,16 +377,10 @@ class GraphConv(GraphLayer):
         has_overridden_update = self.__class__.update != GraphConv.update
         if not has_overridden_update:
             self._update_intermediate_dense = self.get_dense(self.units)
+            self._update_norm = self.get_norm()
             self._update_intermediate_activation = self.activation
             self._update_final_dense = self.get_dense(self.units)
 
-        if not self._normalize:
-            self._normalization = keras.layers.Identity()
-        elif str(self._normalize).lower().startswith('layer'):
-            self._normalization = keras.layers.LayerNormalization()
-        else:
-            self._normalization = keras.layers.BatchNormalization()
-
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Forward pass.
 
@@ -430,7 +425,7 @@ class GraphConv(GraphLayer):
         elif add_aggregate:
             update = update.update({'node': {'aggregate': None}})
 
-        if not self._skip_connect:
+        if not self._skip_connect:
             return update
 
         feature = update.node['feature']
@@ -438,8 +433,6 @@ class GraphConv(GraphLayer):
         if self._skip_connect:
             feature += residual
 
-        feature = self._normalization(feature)
-
         return update.update({'node': {'feature': feature}})
 
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
@@ -480,6 +473,7 @@ class GraphConv(GraphLayer):
             axis=-1
         )
         message = self._message_intermediate_dense(message)
+        message = self._message_norm(message)
         message = self._message_intermediate_activation(message)
         message = self._message_final_dense(message)
         return tensor.update({'edge': {'message': message}})
@@ -519,6 +513,7 @@ class GraphConv(GraphLayer):
         """
         aggregate = tensor.node['aggregate']
         node_feature = self._update_intermediate_dense(aggregate)
+        node_feature = self._update_norm(node_feature)
         node_feature = self._update_intermediate_activation(node_feature)
         node_feature = self._update_final_dense(node_feature)
         return tensor.update(
@@ -530,6 +525,14 @@ class GraphConv(GraphLayer):
             }
         )
 
+    def get_norm(self, **kwargs):
+        if not self._normalize:
+            return keras.layers.Identity()
+        elif str(self._normalize).lower().startswith('layer'):
+            return keras.layers.LayerNormalization(**kwargs)
+        else:
+            return keras.layers.BatchNormalization(**kwargs)
+
     def get_config(self) -> dict:
         config = super().get_config()
         config.update({
```
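The net effect in `GraphConv`: normalization moves from a single post-skip-connection step into each default message/update block, between the intermediate dense layer and its activation. A hedged Keras sketch of the new block ordering:

```python
import keras

units = 128
# Block ordering after this change. get_norm() returns LayerNormalization
# for normalize='layer', BatchNormalization for normalize=True, and
# Identity for normalize=False.
block = keras.Sequential([
    keras.layers.Dense(units),          # intermediate dense
    keras.layers.LayerNormalization(),  # norm now sits inside the block
    keras.layers.Activation('relu'),    # intermediate activation
    keras.layers.Dense(units),          # final dense
])
```

Normalizing inside the block rather than after the residual addition keeps the skip-connection path free of normalization, a common pre-norm-style arrangement.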
```diff
@@ -1312,13 +1315,19 @@ class NodeEmbedding(GraphLayer):
 
     def __init__(
         self,
-        dim: int = None,
+        dim: int | None = None,
+        intermediate_dim: int | None = None,
+        intermediate_activation: str | keras.layers.Activation | None = 'relu',
         normalize: bool = False,
         embed_context: bool = False,
         **kwargs
     ) -> None:
         super().__init__(**kwargs)
         self.dim = dim
+        self._intermediate_dim = intermediate_dim
+        self._intermediate_activation = keras.activations.get(
+            intermediate_activation
+        )
         self._normalize = normalize
         self._embed_context = embed_context
 
@@ -1326,30 +1335,38 @@ class NodeEmbedding(GraphLayer):
         feature_dim = spec.node['feature'].shape[-1]
         if not self.dim:
             self.dim = feature_dim
-        …
-        …
+        if not self._intermediate_dim:
+            self._intermediate_dim = self.dim * 2
+        self._node_dense = self.get_dense(
+            self._intermediate_dim, activation=self._intermediate_activation
+        )
         self._has_super = 'super' in spec.node
         has_context_feature = 'feature' in spec.context
         if not has_context_feature:
             self._embed_context = False
         if self._has_super and not self._embed_context:
-            self._super_feature = self.get_weight(…
+            self._super_feature = self.get_weight(
+                shape=[self._intermediate_dim], name='super_node_feature'
+            )
         if self._embed_context:
-            self._context_dense = self.get_dense(
-                …
+            self._context_dense = self.get_dense(
+                self._intermediate_dim, activation=self._intermediate_activation
+            )
         if not self._normalize:
             self._norm = keras.layers.Identity()
         elif str(self._normalize).lower().startswith('layer'):
             self._norm = keras.layers.LayerNormalization()
         else:
             self._norm = keras.layers.BatchNormalization()
+        self._dense = self.get_dense(self.dim)
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         feature = self._node_dense(tensor.node['feature'])
 
         if self._has_super and not self._embed_context:
             super_mask = keras.ops.expand_dims(tensor.node['super'], 1)
-            …
+            super_feature = self._intermediate_activation(self._super_feature)
+            feature = keras.ops.where(super_mask, super_feature, feature)
 
         if self._embed_context:
             context_feature = self._context_dense(tensor.context['feature'])
@@ -1357,6 +1374,7 @@ class NodeEmbedding(GraphLayer):
             tensor = tensor.update({'context': {'feature': None}})
 
         feature = self._norm(feature)
+        feature = self._dense(feature)
 
         return tensor.update({'node': {'feature': feature}})
 
@@ -1364,6 +1382,10 @@ class NodeEmbedding(GraphLayer):
         config = super().get_config()
         config.update({
             'dim': self.dim,
+            'intermediate_dim': self._intermediate_dim,
+            'intermediate_activation': keras.activations.serialize(
+                self._intermediate_activation
+            ),
             'normalize': self._normalize,
             'embed_context': self._embed_context,
         })
```
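With these parameters, `NodeEmbedding` (and `EdgeEmbedding` below) acts as a small two-layer MLP per node: `Dense(intermediate_dim, intermediate_activation)`, then the norm, then `Dense(dim)`, with `intermediate_dim` defaulting to `2 * dim`. A hypothetical construction (argument values chosen for illustration, not taken from the source):

```python
from molcraft import layers

# intermediate_dim is left unset here, so build() will pick 2 * dim = 256.
node_embedding = layers.NodeEmbedding(
    dim=128,
    intermediate_activation='relu',
    normalize='layer',
)
```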
```diff
@@ -1381,50 +1403,67 @@ class EdgeEmbedding(GraphLayer):
     def __init__(
         self,
         dim: int = None,
+        intermediate_dim: int | None = None,
+        intermediate_activation: str | keras.layers.Activation | None = 'relu',
         normalize: bool = False,
         **kwargs
     ) -> None:
         super().__init__(**kwargs)
         self.dim = dim
+        self._intermediate_dim = intermediate_dim
+        self._intermediate_activation = keras.activations.get(
+            intermediate_activation
+        )
         self._normalize = normalize
 
     def build(self, spec: tensors.GraphTensor.Spec) -> None:
         feature_dim = spec.edge['feature'].shape[-1]
         if not self.dim:
             self.dim = feature_dim
-        …
-        …
-        self.…
-        …
+        if not self._intermediate_dim:
+            self._intermediate_dim = self.dim * 2
+        self._edge_dense = self.get_dense(
+            self._intermediate_dim, activation=self._intermediate_activation
+        )
+        self._self_loop_feature = self.get_weight(
+            shape=[self._intermediate_dim], name='self_loop_edge_feature'
+        )
         self._has_super = 'super' in spec.edge
         if self._has_super:
-            self._super_feature = self.get_weight(
-                …
+            self._super_feature = self.get_weight(
+                shape=[self._intermediate_dim], name='super_edge_feature'
+            )
         if not self._normalize:
             self._norm = keras.layers.Identity()
         elif str(self._normalize).lower().startswith('layer'):
             self._norm = keras.layers.LayerNormalization()
         else:
             self._norm = keras.layers.BatchNormalization()
+        self._dense = self.get_dense(self.dim)
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         feature = self._edge_dense(tensor.edge['feature'])
 
         if self._has_super:
             super_mask = keras.ops.expand_dims(tensor.edge['super'], 1)
-            …
+            super_feature = self._intermediate_activation(self._super_feature)
+            feature = keras.ops.where(super_mask, super_feature, feature)
 
         self_loop_mask = keras.ops.expand_dims(tensor.edge['source'] == tensor.edge['target'], 1)
-        …
-        …
+        self_loop_feature = self._intermediate_activation(self._self_loop_feature)
+        feature = keras.ops.where(self_loop_mask, self_loop_feature, feature)
         feature = self._norm(feature)
-        …
+        feature = self._dense(feature)
         return tensor.update({'edge': {'feature': feature}})
 
     def get_config(self) -> dict:
         config = super().get_config()
         config.update({
             'dim': self.dim,
+            'intermediate_dim': self._intermediate_dim,
+            'intermediate_activation': keras.activations.serialize(
+                self._intermediate_activation
+            ),
             'normalize': self._normalize,
         })
         return config
```
```diff
@@ -1441,42 +1480,60 @@ class AddContext(GraphLayer):
     def __init__(
         self,
         field: str = 'feature',
+        intermediate_dim: int | None = None,
+        intermediate_activation: str | keras.layers.Activation | None = 'relu',
         drop: bool = False,
         normalize: bool = False,
         **kwargs
     ) -> None:
         super().__init__(**kwargs)
-        self.…
-        self.…
+        self._field = field
+        self._drop = drop
+        self._intermediate_dim = intermediate_dim
+        self._intermediate_activation = keras.activations.get(
+            intermediate_activation
+        )
         self._normalize = normalize
 
     def build(self, spec: tensors.GraphTensor.Spec) -> None:
         feature_dim = spec.node['feature'].shape[-1]
-        self.…
+        if self._intermediate_dim is None:
+            self._intermediate_dim = feature_dim * 2
+        self._intermediate_dense = self.get_dense(
+            self._intermediate_dim, activation=self._intermediate_activation
+        )
+        self._final_dense = self.get_dense(feature_dim)
         if not self._normalize:
-            self.…
+            self._intermediate_norm = keras.layers.Identity()
         elif str(self._normalize).lower().startswith('layer'):
-            self.…
+            self._intermediate_norm = keras.layers.LayerNormalization()
         else:
-            self.…
+            self._intermediate_norm = keras.layers.BatchNormalization()
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        context = tensor.context[self.…
-        context = self.…
-        context = self.…
+        context = tensor.context[self._field]
+        context = self._intermediate_dense(context)
+        context = self._intermediate_norm(context)
+        context = self._final_dense(context)
         node_feature = ops.scatter_add(
             tensor.node['feature'], tensor.node['super'], context
         )
         data = {'node': {'feature': node_feature}}
-        if self.…
-            data['context'] = {self.…
+        if self._drop:
+            data['context'] = {self._field: None}
         return tensor.update(data)
 
     def get_config(self) -> dict:
         config = super().get_config()
-        config…
-        …
-        …
+        config.update({
+            'field': self._field,
+            'intermediate_dim': self._intermediate_dim,
+            'intermediate_activation': keras.activations.serialize(
+                self._intermediate_activation
+            ),
+            'drop': self._drop,
+            'normalize': self._normalize,
+        })
         return config
 
 
```
```diff
@@ -1738,5 +1795,3 @@ def _spec_from_inputs(inputs):
         return spec
     return tensors.GraphTensor.Spec(**nested_specs)
 
-
-GraphTransformer = GTConv
```
molcraft/tensors.py

```diff
@@ -16,13 +16,15 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
         if isinstance(f, tf.TensorSpec):
             return tf.TensorSpec(
                 shape=[None] + f.shape[1:],
-                dtype=f.dtype)
+                dtype=f.dtype
+            )
         elif isinstance(f, tf.RaggedTensorSpec):
             return tf.RaggedTensorSpec(
                 shape=[batch_size, None] + f.shape[1:],
                 dtype=f.dtype,
                 ragged_rank=1,
-                row_splits_dtype=f.row_splits_dtype)
+                row_splits_dtype=f.row_splits_dtype
+            )
         elif isinstance(f, tf.TypeSpec):
             return f.__batch_encoder__.batch(f, batch_size)
         return f
@@ -33,7 +35,8 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
         batched_spec = object.__new__(type(spec))
         batched_context_fields = tf.nest.map_structure(
             lambda spec: tf.TensorSpec([batch_size] + spec.shape, spec.dtype),
-            context_fields)
+            context_fields
+        )
         batched_spec.__dict__.update({'context': batched_context_fields})
         batched_spec.__dict__.update(batched_fields)
         return batched_spec
@@ -46,13 +49,15 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
         if isinstance(f, tf.TensorSpec):
             return tf.TensorSpec(
                 shape=[None] + f.shape[1:],
-                dtype=f.dtype)
+                dtype=f.dtype
+            )
         elif isinstance(f, tf.RaggedTensorSpec):
             return tf.RaggedTensorSpec(
                 shape=[None] + f.shape[2:],
                 dtype=f.dtype,
                 ragged_rank=0,
-                row_splits_dtype=f.row_splits_dtype)
+                row_splits_dtype=f.row_splits_dtype
+            )
         elif isinstance(f, tf.TypeSpec):
             return f.__batch_encoder__.unbatch(f)
         return f
@@ -62,7 +67,8 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
         unbatched_fields = tf.nest.map_structure(unbatch_field, fields)
         unbatched_context_fields = tf.nest.map_structure(
             lambda spec: tf.TensorSpec(spec.shape[1:], spec.dtype),
-            context_fields)
+            context_fields
+        )
         unbatched_spec = object.__new__(type(spec))
         unbatched_spec.__dict__.update({'context': unbatched_context_fields})
         unbatched_spec.__dict__.update(unbatched_fields)
@@ -91,7 +97,8 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
                 shape=([None] if scalar else [None, None]) + f.shape[1:],
                 dtype=f.dtype,
                 ragged_rank=(0 if scalar else 1),
-                row_splits_dtype=spec.context['size'].dtype)
+                row_splits_dtype=spec.context['size'].dtype
+            )
             return f
         fields = dict(spec.__dict__)
         context_fields = fields.pop('context')
@@ -99,7 +106,7 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
         encoded_fields = {**{'context': context_fields}, **encoded_fields}
         spec_components = tuple(encoded_fields.values())
         spec_components = tuple(
-            x for x in tf.nest.flatten(spec_components)
+            x for x in tf.nest.flatten(spec_components)
             if isinstance(x, tf.TypeSpec)
         )
         return spec_components
@@ -117,7 +124,6 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
         fields = dict(zip(spec.__dict__.keys(), value_tuple))
         value = object.__new__(spec.value_type)
         value.__dict__.update(fields)
-
         flatten = is_ragged(value) and not is_ragged(spec)
         if flatten:
             value = value.flatten()
@@ -125,7 +131,7 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
 
 
 class GraphTensor(tf.experimental.BatchableExtensionType):
-    context: typing.Mapping[str, …
+    context: typing.Mapping[str, tf.Tensor]
     node: typing.Mapping[str, typing.Union[tf.Tensor, tf.RaggedTensor]]
     edge: typing.Mapping[str, typing.Union[tf.Tensor, tf.RaggedTensor]]
 
```
All other files listed above are unchanged between 0.1.0a17 and 0.1.0a19.
|