molcraft 0.1.0a17__tar.gz → 0.1.0a18__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic.
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/PKG-INFO +1 -1
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/__init__.py +4 -2
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/applications/proteomics.py +121 -41
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/layers.py +94 -39
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft.egg-info/PKG-INFO +1 -1
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/LICENSE +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/README.md +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/applications/__init__.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/applications/chromatography.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/callbacks.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/chem.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/datasets.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/descriptors.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/features.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/featurizers.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/losses.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/models.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/ops.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/records.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/tensors.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft.egg-info/SOURCES.txt +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft.egg-info/dependency_links.txt +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft.egg-info/requires.txt +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft.egg-info/top_level.txt +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/pyproject.toml +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/setup.cfg +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/tests/test_chem.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/tests/test_featurizers.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/tests/test_layers.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/tests/test_losses.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/tests/test_models.py +0 -0
- {molcraft-0.1.0a17 → molcraft-0.1.0a18}/tests/test_tensors.py +0 -0
molcraft/__init__.py

```diff
@@ -1,4 +1,4 @@
-__version__ = '0.1.0a17'
+__version__ = '0.1.0a18'
 
 import os
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -14,4 +14,6 @@ from molcraft import records
 from molcraft import tensors
 from molcraft import callbacks
 from molcraft import datasets
-from molcraft import losses
+from molcraft import losses
+
+from molcraft.applications import proteomics
```
molcraft/applications/proteomics.py

```diff
@@ -3,6 +3,7 @@ import keras
 import numpy as np
 import tensorflow as tf
 import tensorflow_text as tf_text
+from rdkit import Chem
 
 from molcraft import featurizers
 from molcraft import tensors
@@ -10,16 +11,47 @@ from molcraft import layers
 from molcraft import models
 from molcraft import chem
 
+"""
+No need to correct smiles for modeling, only for interpretation.
+
+Use added smiles data to rearrange list of saliency values.
+"""
 
 # TODO: Add regex pattern for residue (C-term mod + N-term mod)?
 # TODO: Add regex pattern for residue (C-term mod + N-term mod + mod)?
+
+no_mod_pattern = r'([A-Z])'
+side_chain_mod_pattern = r'([A-Z]\[[A-Za-z0-9]+\])'
+n_term_mod_pattern = r'(\[[A-Za-z0-9]+\]-[A-Z])'
+c_term_mod_pattern = r'([A-Z]-\[[A-Za-z0-9]+\])'
+side_chain_and_n_term_mod_pattern = r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])'
+side_chain_and_c_term_mod_pattern = r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])'
+
 residue_pattern: str = "|".join([
-
-
-
-
-
-
+    side_chain_and_n_term_mod_pattern,
+    side_chain_and_c_term_mod_pattern,
+    n_term_mod_pattern,
+    c_term_mod_pattern,
+    side_chain_mod_pattern,
+    no_mod_pattern
 ])
 
 default_residues: dict[str, str] = {
```
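The alternation order in `residue_pattern` matters: the most specific token shapes (terminal plus side-chain modifications) must be tried before the bare single-letter fallback, or a token like `[Acetyl]-A` would tokenize as separate pieces. A minimal sketch of how the pattern splits a modified peptide string (the example sequence is made up):

```python
import re

# Same alternation order as proteomics.py: most specific first.
residue_pattern = "|".join([
    r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])',  # N-term mod + side-chain mod
    r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])',  # side-chain mod + C-term mod
    r'(\[[A-Za-z0-9]+\]-[A-Z])',                  # N-term mod only
    r'([A-Z]-\[[A-Za-z0-9]+\])',                  # C-term mod only
    r'([A-Z]\[[A-Za-z0-9]+\])',                   # side-chain mod only
    r'([A-Z])',                                   # unmodified residue
])

tokens = [m.group(0) for m in re.finditer(residue_pattern, '[Acetyl]-AC[Carbamidomethyl]K')]
print(tokens)  # ['[Acetyl]-A', 'C[Carbamidomethyl]', 'K']
```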
```diff
@@ -45,20 +77,26 @@ default_residues: dict[str, str] = {
     "Y": "N[C@@H](Cc1ccc(O)cc1)C(=O)O",
 }
 
-def
-
-
-
-
-
-
-
-
-
-
-
-
-
+def has_c_terminal_mod(residue: str):
+    if re.search(c_term_mod_pattern, residue):
+        return True
+    return False
+
+def has_n_terminal_mod(residue: str):
+    if re.search(n_term_mod_pattern, residue):
+        return True
+    return False
+
+# def register_residues(residues: dict[str, str]) -> None:
+#     for residue, smiles in residues.items():
+#         if residue.startswith('P'):
+#             smiles.startswith('N'), f'Incorrect SMILES permutation for {residue}.'
+#         elif not residue.startswith('['):
+#             smiles.startswith('N[C@@H]'), f'Incorrect SMILES permutation for {residue}.'
+#         if len(residue) > 1 and not residue[1] == "-":
+#             assert smiles.endswith('C(=O)O'), f'Incorrect SMILES permutation for {residue}.'
+#         registered_residues[residue] = smiles
+#         registered_residues[residue + '*'] = smiles.strip('O')
 
 
 class Peptide(chem.Mol):
```
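The two new helpers just probe a residue token with `re.search`, so any token containing the corresponding terminal-modification syntax returns True. A self-contained sketch (patterns copied from above; the example tokens are made up):

```python
import re

n_term_mod_pattern = r'(\[[A-Za-z0-9]+\]-[A-Z])'
c_term_mod_pattern = r'([A-Z]-\[[A-Za-z0-9]+\])'

def has_n_terminal_mod(residue: str) -> bool:
    # True for tokens like '[Acetyl]-A'.
    return re.search(n_term_mod_pattern, residue) is not None

def has_c_terminal_mod(residue: str) -> bool:
    # True for tokens like 'K-[Amidated]'.
    return re.search(c_term_mod_pattern, residue) is not None

assert has_n_terminal_mod('[Acetyl]-A')
assert has_c_terminal_mod('K-[Amidated]')
assert not has_c_terminal_mod('C[Carbamidomethyl]')
```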
```diff
@@ -79,6 +117,22 @@ class Peptide(chem.Mol):
         return super().from_encoding(peptide_smiles, **kwargs)
 
 
+def permute_residue_smiles(smiles: str) -> str:
+    glycine = chem.Mol.from_encoding("NCC(=O)O")
+    mol = chem.Mol.from_encoding(smiles)
+    nitrogen_index = mol.GetSubstructMatch(glycine)[0]
+    permuted_smiles = Chem.MolToSmiles(
+        mol, rootedAtAtom=nitrogen_index
+    )
+    return permuted_smiles
+
+def check_peptide_residue_smiles(smiles: list[str]) -> bool:
+    backbone = 'NCC(=O)' * (len(smiles) - 1) + 'NC'
+    backbone = chem.Mol.from_encoding(backbone)
+    mol = chem.Mol.from_encoding(''.join(smiles))
+    is_valid = mol.HasSubstructMatch(backbone)
+    return is_valid
+
 @keras.saving.register_keras_serializable(package='proteomics')
 class ResidueEmbedding(keras.layers.Layer):
 
```
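`permute_residue_smiles` rewrites each residue SMILES so that it starts at the backbone amino nitrogen, by rooting the canonical SMILES at the first atom of a glycine-substructure match; `check_peptide_residue_smiles` then verifies that the concatenated residue strings still contain the expected peptide backbone. A rough standalone equivalent of the first helper in plain RDKit (`chem.Mol.from_encoding` is molcraft's own constructor; `Chem.MolFromSmiles` stands in for it here):

```python
from rdkit import Chem

def permute_residue_smiles(smiles: str) -> str:
    # Glycine (NCC(=O)O) matches the backbone; the first matched atom
    # of the query is the amino nitrogen.
    glycine = Chem.MolFromSmiles("NCC(=O)O")
    mol = Chem.MolFromSmiles(smiles)
    nitrogen_index = mol.GetSubstructMatch(glycine)[0]
    # Root the canonical SMILES at that nitrogen so the string
    # begins with the amino group.
    return Chem.MolToSmiles(mol, rootedAtAtom=nitrogen_index)

# Alanine written carboxyl-first still comes out amino-first.
print(permute_residue_smiles("OC(=O)[C@@H](C)N"))  # e.g. 'N[C@@H](C)C(=O)O'
```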
```diff
@@ -92,42 +146,60 @@ class ResidueEmbedding(keras.layers.Layer):
         super().__init__(**kwargs)
         if residues is None:
             residues = {}
-        self._residue_dict = {**default_residues, **residues}
         self.embedder = embedder
         self.featurizer = featurizer
-        self.embedding_dim = self.embedder.output.shape[-1]
+        self.embedding_dim = int(self.embedder.output.shape[-1])
         self.ragged_split = SequenceSplitter(pad=False)
         self.split = SequenceSplitter(pad=True)
         self.use_cached_embeddings = tf.Variable(False)
+        self.residues = residues
         self.supports_masking = True
 
     @property
     def residues(self) -> dict[str, str]:
-        return self.
+        return self._residues
 
     @residues.setter
     def residues(self, residues: dict[str, str]) -> None:
-
-
-
-
+
+        residues = {**default_residues, **residues}
+        self._residues = {}
+        for residue, smiles in residues.items():
+            permuted_smiles = permute_residue_smiles(smiles)
+            # Returned smiles should begin with the amino group.
+            # It seems that the returned smiles ends with carboxyl group,
+            # though we do another check just in case.
+            if not has_c_terminal_mod(residue):
+                carboxyl_group = 'C(=O)O'
+                if not permuted_smiles.endswith(carboxyl_group):
+                    raise ValueError(
+                        f'Unsupported permutation of {residue!r} smiles: {permuted_smiles!r}.'
+                    )
+            self._residues[residue] = permuted_smiles
+            self._residues[residue + '*'] = permuted_smiles.rstrip('O')
+
+        residue_keys = sorted(self._residues.keys())
+        residue_values = range(len(residue_keys))
+        residue_oov_value = np.where(np.array(residue_keys) == "G")[0][0]
+
         self.mapping = tf.lookup.StaticHashTable(
             tf.lookup.KeyValueTensorInitializer(
                 keys=residue_keys,
-                values=
+                values=residue_values
             ),
-            default_value=
+            default_value=residue_oov_value,
         )
+
         self.graph = tf.stack([
-            self.featurizer(
+            self.featurizer(self._residues[r]) for r in residue_keys
         ], axis=0)
-
-
-        )
+
+        zeros = tf.zeros((residue_values[-1] + 1, self.embedding_dim))
+        self.cached_embeddings = tf.Variable(initial_value=zeros)
         _ = self.cache_and_get_embeddings()
 
     def build(self, input_shape) -> None:
-        self.residues = self.
+        self.residues = self._residues
         super().build(input_shape)
 
     def call(self, sequences: tf.Tensor, training: bool = None) -> tf.Tensor:
```
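The rewritten setter derives the lookup vocabulary from the sorted residue keys and, notably, maps out-of-vocabulary tokens to glycine's index rather than a sentinel value. That lookup logic in isolation, as a small sketch:

```python
import numpy as np
import tensorflow as tf

residue_keys = sorted(['A', 'C', 'C[Carbamidomethyl]', 'G', 'K'])
residue_values = list(range(len(residue_keys)))
# Unknown residues fall back to glycine's index instead of -1.
residue_oov_value = int(np.where(np.array(residue_keys) == 'G')[0][0])

mapping = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=residue_keys,
        values=residue_values,
    ),
    default_value=residue_oov_value,
)
print(mapping.lookup(tf.constant(['K', 'X[FakeMod]'])).numpy())  # [4 3]
```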
```diff
@@ -163,16 +235,24 @@ class ResidueEmbedding(keras.layers.Layer):
     def get_config(self) -> dict:
         config = super().get_config()
         config.update({
-            'featurizer': keras.saving.serialize_keras_object(
-
-
+            'featurizer': keras.saving.serialize_keras_object(
+                self.featurizer
+            ),
+            'embedder': keras.saving.serialize_keras_object(
+                self.embedder
+            ),
+            'residues': self._residues,
         })
         return config
 
     @classmethod
     def from_config(cls, config: dict) -> 'ResidueEmbedding':
-        config['featurizer'] = keras.saving.deserialize_keras_object(
-
+        config['featurizer'] = keras.saving.deserialize_keras_object(
+            config['featurizer']
+        )
+        config['embedder'] = keras.saving.deserialize_keras_object(
+            config['embedder']
+        )
         return super().from_config(config)
 
 
```
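Both sub-models now go through the standard Keras serialize/deserialize pair, which is what lets a saved model rebuild its `featurizer` and `embedder` from the config dict. The general round trip, sketched with a stand-in layer:

```python
import keras

layer = keras.layers.Dense(8)
config = keras.saving.serialize_keras_object(layer)   # plain, JSON-friendly dict
restored = keras.saving.deserialize_keras_object(config)
print(type(restored).__name__)  # Dense
```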
```diff
@@ -190,5 +270,5 @@ class SequenceSplitter(keras.layers.Layer):
         return inputs
 
 
-registered_residues: dict[str, str] = {}
-register_residues(default_residues)
+# registered_residues: dict[str, str] = {}
+# register_residues(default_residues)
```
molcraft/layers.py

```diff
@@ -380,11 +380,14 @@ class GraphConv(GraphLayer):
         self._update_final_dense = self.get_dense(self.units)
 
         if not self._normalize:
-            self.
+            self._message_norm = keras.layers.Identity()
+            self._update_norm = keras.layers.Identity()
         elif str(self._normalize).lower().startswith('layer'):
-            self.
+            self._message_norm = keras.layers.LayerNormalization()
+            self._update_norm = keras.layers.LayerNormalization()
         else:
-            self.
+            self._message_norm = keras.layers.BatchNormalization()
+            self._update_norm = keras.layers.BatchNormalization()
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         """Forward pass.
@@ -430,7 +433,7 @@ class GraphConv(GraphLayer):
         elif add_aggregate:
             update = update.update({'node': {'aggregate': None}})
 
-        if not self._skip_connect
+        if not self._skip_connect:
             return update
 
         feature = update.node['feature']
@@ -438,8 +441,6 @@ class GraphConv(GraphLayer):
         if self._skip_connect:
             feature += residual
 
-        feature = self._normalization(feature)
-
         return update.update({'node': {'feature': feature}})
 
     def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
@@ -480,6 +481,7 @@ class GraphConv(GraphLayer):
             axis=-1
         )
         message = self._message_intermediate_dense(message)
+        message = self._message_norm(message)
         message = self._message_intermediate_activation(message)
         message = self._message_final_dense(message)
         return tensor.update({'edge': {'message': message}})
@@ -519,6 +521,7 @@ class GraphConv(GraphLayer):
         """
         aggregate = tensor.node['aggregate']
         node_feature = self._update_intermediate_dense(aggregate)
+        node_feature = self._update_norm(node_feature)
         node_feature = self._update_intermediate_activation(node_feature)
         node_feature = self._update_final_dense(node_feature)
         return tensor.update(
```
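Net effect in `GraphConv`: normalization moves from a single post-residual step into each two-layer MLP, giving a Dense → Norm → Activation → Dense ordering in both the message and update paths. A standalone sketch of that block ordering (names and the `normalize` convention mirror the diff; this is not molcraft's API):

```python
import keras

def mlp_block(units: int, normalize='layer') -> keras.Sequential:
    # Mirrors the new ordering: dense -> norm -> activation -> dense.
    if not normalize:
        norm = keras.layers.Identity()
    elif str(normalize).lower().startswith('layer'):
        norm = keras.layers.LayerNormalization()
    else:
        norm = keras.layers.BatchNormalization()
    return keras.Sequential([
        keras.layers.Dense(units * 2),
        norm,
        keras.layers.Activation('relu'),
        keras.layers.Dense(units),
    ])

message_block = mlp_block(32)  # e.g. the message path
```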
```diff
@@ -1312,13 +1315,19 @@ class NodeEmbedding(GraphLayer):
 
     def __init__(
         self,
-        dim: int = None,
+        dim: int | None = None,
+        intermediate_dim: int | None = None,
+        intermediate_activation: str | keras.layers.Activation | None = 'relu',
         normalize: bool = False,
         embed_context: bool = False,
         **kwargs
     ) -> None:
         super().__init__(**kwargs)
         self.dim = dim
+        self._intermediate_dim = intermediate_dim
+        self._intermediate_activation = keras.activations.get(
+            intermediate_activation
+        )
         self._normalize = normalize
         self._embed_context = embed_context
 
@@ -1326,30 +1335,38 @@ class NodeEmbedding(GraphLayer):
         feature_dim = spec.node['feature'].shape[-1]
         if not self.dim:
             self.dim = feature_dim
-
-
+        if not self._intermediate_dim:
+            self._intermediate_dim = self.dim * 2
+        self._node_dense = self.get_dense(
+            self._intermediate_dim, activation=self._intermediate_activation
+        )
         self._has_super = 'super' in spec.node
         has_context_feature = 'feature' in spec.context
         if not has_context_feature:
             self._embed_context = False
         if self._has_super and not self._embed_context:
-            self._super_feature = self.get_weight(
+            self._super_feature = self.get_weight(
+                shape=[self._intermediate_dim], name='super_node_feature'
+            )
         if self._embed_context:
-            self._context_dense = self.get_dense(
-
+            self._context_dense = self.get_dense(
+                self._intermediate_dim, activation=self._intermediate_activation
+            )
         if not self._normalize:
             self._norm = keras.layers.Identity()
         elif str(self._normalize).lower().startswith('layer'):
             self._norm = keras.layers.LayerNormalization()
         else:
             self._norm = keras.layers.BatchNormalization()
+        self._dense = self.get_dense(self.dim)
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         feature = self._node_dense(tensor.node['feature'])
 
         if self._has_super and not self._embed_context:
             super_mask = keras.ops.expand_dims(tensor.node['super'], 1)
-
+            super_feature = self._intermediate_activation(self._super_feature)
+            feature = keras.ops.where(super_mask, super_feature, feature)
 
         if self._embed_context:
             context_feature = self._context_dense(tensor.context['feature'])
@@ -1357,6 +1374,7 @@ class NodeEmbedding(GraphLayer):
             tensor = tensor.update({'context': {'feature': None}})
 
         feature = self._norm(feature)
+        feature = self._dense(feature)
 
         return tensor.update({'node': {'feature': feature}})
 
@@ -1364,6 +1382,10 @@ class NodeEmbedding(GraphLayer):
         config = super().get_config()
         config.update({
             'dim': self.dim,
+            'intermediate_dim': self._intermediate_dim,
+            'intermediate_activation': keras.activations.serialize(
+                self._intermediate_activation
+            ),
             'normalize': self._normalize,
             'embed_context': self._embed_context,
         })
```
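With the super-node path, a learned vector is passed through the intermediate activation and substituted wherever the boolean `super` mask is set; the substitution is just a broadcast `keras.ops.where`. A minimal sketch:

```python
import keras
import numpy as np

num_nodes, dim = 4, 8
feature = keras.ops.convert_to_tensor(np.random.rand(num_nodes, dim).astype('float32'))
super_feature = keras.ops.convert_to_tensor(np.random.rand(dim).astype('float32'))

# One boolean per node; True marks the graph-level super node.
super_mask = keras.ops.expand_dims(
    keras.ops.convert_to_tensor(np.array([False, False, False, True])), 1
)  # shape (num_nodes, 1), broadcast against (num_nodes, dim)

# Super-node rows take the learned vector; all others keep their features.
feature = keras.ops.where(super_mask, super_feature, feature)
```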
```diff
@@ -1381,50 +1403,67 @@ class EdgeEmbedding(GraphLayer):
     def __init__(
         self,
         dim: int = None,
+        intermediate_dim: int | None = None,
+        intermediate_activation: str | keras.layers.Activation | None = 'relu',
         normalize: bool = False,
         **kwargs
     ) -> None:
         super().__init__(**kwargs)
         self.dim = dim
+        self._intermediate_dim = intermediate_dim
+        self._intermediate_activation = keras.activations.get(
+            intermediate_activation
+        )
         self._normalize = normalize
 
     def build(self, spec: tensors.GraphTensor.Spec) -> None:
         feature_dim = spec.edge['feature'].shape[-1]
         if not self.dim:
             self.dim = feature_dim
-
-
-        self.
-
+        if not self._intermediate_dim:
+            self._intermediate_dim = self.dim * 2
+        self._edge_dense = self.get_dense(
+            self._intermediate_dim, activation=self._intermediate_activation
+        )
+        self._self_loop_feature = self.get_weight(
+            shape=[self._intermediate_dim], name='self_loop_edge_feature'
+        )
         self._has_super = 'super' in spec.edge
         if self._has_super:
-            self._super_feature = self.get_weight(
-
+            self._super_feature = self.get_weight(
+                shape=[self._intermediate_dim], name='super_edge_feature'
+            )
         if not self._normalize:
             self._norm = keras.layers.Identity()
         elif str(self._normalize).lower().startswith('layer'):
             self._norm = keras.layers.LayerNormalization()
         else:
             self._norm = keras.layers.BatchNormalization()
+        self._dense = self.get_dense(self.dim)
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
         feature = self._edge_dense(tensor.edge['feature'])
 
         if self._has_super:
             super_mask = keras.ops.expand_dims(tensor.edge['super'], 1)
-
+            super_feature = self._intermediate_activation(self._super_feature)
+            feature = keras.ops.where(super_mask, super_feature, feature)
 
         self_loop_mask = keras.ops.expand_dims(tensor.edge['source'] == tensor.edge['target'], 1)
-
-
+        self_loop_feature = self._intermediate_activation(self._self_loop_feature)
+        feature = keras.ops.where(self_loop_mask, self_loop_feature, feature)
         feature = self._norm(feature)
-
+        feature = self._dense(feature)
         return tensor.update({'edge': {'feature': feature}})
 
     def get_config(self) -> dict:
         config = super().get_config()
         config.update({
             'dim': self.dim,
+            'intermediate_dim': self._intermediate_dim,
+            'intermediate_activation': keras.activations.serialize(
+                self._intermediate_activation
+            ),
             'normalize': self._normalize,
         })
         return config
```
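`EdgeEmbedding` applies the same trick twice: a learned super-edge feature and a learned self-loop feature, where self-loops are detected by comparing each edge's source and target indices. The detection step alone:

```python
import keras
import numpy as np

source = keras.ops.convert_to_tensor(np.array([0, 1, 2, 2]))
target = keras.ops.convert_to_tensor(np.array([1, 1, 0, 2]))

# Edges 1 and 3 connect a node to itself.
self_loop_mask = keras.ops.expand_dims(source == target, 1)
print(keras.ops.convert_to_numpy(self_loop_mask).ravel())  # [False  True False  True]
```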
```diff
@@ -1441,42 +1480,60 @@ class AddContext(GraphLayer):
     def __init__(
         self,
         field: str = 'feature',
+        intermediate_dim: int | None = None,
+        intermediate_activation: str | keras.layers.Activation | None = 'relu',
         drop: bool = False,
         normalize: bool = False,
         **kwargs
     ) -> None:
         super().__init__(**kwargs)
-        self.
-        self.
+        self._field = field
+        self._drop = drop
+        self._intermediate_dim = intermediate_dim
+        self._intermediate_activation = keras.activations.get(
+            intermediate_activation
+        )
         self._normalize = normalize
 
     def build(self, spec: tensors.GraphTensor.Spec) -> None:
         feature_dim = spec.node['feature'].shape[-1]
-        self.
+        if self._intermediate_dim is None:
+            self._intermediate_dim = feature_dim * 2
+        self._intermediate_dense = self.get_dense(
+            self._intermediate_dim, activation=self._intermediate_activation
+        )
+        self._final_dense = self.get_dense(feature_dim)
         if not self._normalize:
-            self.
+            self._intermediate_norm = keras.layers.Identity()
         elif str(self._normalize).lower().startswith('layer'):
-            self.
+            self._intermediate_norm = keras.layers.LayerNormalization()
         else:
-            self.
+            self._intermediate_norm = keras.layers.BatchNormalization()
 
     def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-        context = tensor.context[self.
-        context = self.
-        context = self.
+        context = tensor.context[self._field]
+        context = self._intermediate_dense(context)
+        context = self._intermediate_norm(context)
+        context = self._final_dense(context)
         node_feature = ops.scatter_add(
             tensor.node['feature'], tensor.node['super'], context
         )
         data = {'node': {'feature': node_feature}}
-        if self.
-            data['context'] = {self.
+        if self._drop:
+            data['context'] = {self._field: None}
         return tensor.update(data)
 
     def get_config(self) -> dict:
         config = super().get_config()
-        config
-
-
+        config.update({
+            'field': self._field,
+            'intermediate_dim': self._intermediate_dim,
+            'intermediate_activation': keras.activations.serialize(
+                self._intermediate_activation
+            ),
+            'drop': self._drop,
+            'normalize': self._normalize,
+        })
         return config
```
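`AddContext` now runs the graph-level context feature through an intermediate dense/norm stack before scattering it onto each graph's super node. `ops.scatter_add` is molcraft's own helper; in plain numpy terms the scatter amounts to something like:

```python
import numpy as np

node_feature = np.zeros((5, 3), dtype=np.float32)       # 5 nodes, feature dim 3
is_super = np.array([False, False, True, False, True])  # one super node per graph
context = np.ones((2, 3), dtype=np.float32)             # one context row per graph

# Add the i-th context row onto the i-th super node's feature row.
node_feature[is_super] += context
print(node_feature.sum())  # 6.0
```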
```diff
@@ -1738,5 +1795,3 @@ def _spec_from_inputs(inputs):
         return spec
     return tensors.GraphTensor.Spec(**nested_specs)
 
-
-GraphTransformer = GTConv
```