molcraft 0.1.0a17__tar.gz → 0.1.0a18__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic.

Files changed (32)
  1. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/PKG-INFO +1 -1
  2. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/__init__.py +4 -2
  3. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/applications/proteomics.py +121 -41
  4. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/layers.py +94 -39
  5. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft.egg-info/PKG-INFO +1 -1
  6. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/LICENSE +0 -0
  7. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/README.md +0 -0
  8. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/applications/__init__.py +0 -0
  9. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/applications/chromatography.py +0 -0
  10. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/callbacks.py +0 -0
  11. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/chem.py +0 -0
  12. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/datasets.py +0 -0
  13. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/descriptors.py +0 -0
  14. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/features.py +0 -0
  15. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/featurizers.py +0 -0
  16. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/losses.py +0 -0
  17. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/models.py +0 -0
  18. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/ops.py +0 -0
  19. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/records.py +0 -0
  20. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/tensors.py +0 -0
  21. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft.egg-info/SOURCES.txt +0 -0
  22. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft.egg-info/dependency_links.txt +0 -0
  23. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft.egg-info/requires.txt +0 -0
  24. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft.egg-info/top_level.txt +0 -0
  25. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/pyproject.toml +0 -0
  26. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/setup.cfg +0 -0
  27. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/tests/test_chem.py +0 -0
  28. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/tests/test_featurizers.py +0 -0
  29. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/tests/test_layers.py +0 -0
  30. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/tests/test_losses.py +0 -0
  31. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/tests/test_models.py +0 -0
  32. {molcraft-0.1.0a17 → molcraft-0.1.0a18}/tests/test_tensors.py +0 -0
{molcraft-0.1.0a17 → molcraft-0.1.0a18}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: molcraft
- Version: 0.1.0a17
+ Version: 0.1.0a18
  Summary: Graph Neural Networks for Molecular Machine Learning
  Author-email: Alexander Kensert <alexander.kensert@gmail.com>
  License: MIT License
{molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/__init__.py

@@ -1,4 +1,4 @@
- __version__ = '0.1.0a17'
+ __version__ = '0.1.0a18'

  import os
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -14,4 +14,6 @@ from molcraft import records
  from molcraft import tensors
  from molcraft import callbacks
  from molcraft import datasets
- from molcraft import losses
+ from molcraft import losses
+
+ from molcraft.applications import proteomics
{molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/applications/proteomics.py

@@ -3,6 +3,7 @@ import keras
  import numpy as np
  import tensorflow as tf
  import tensorflow_text as tf_text
+ from rdkit import Chem

  from molcraft import featurizers
  from molcraft import tensors
@@ -10,16 +11,47 @@ from molcraft import layers
  from molcraft import models
  from molcraft import chem

+ """
+
+ No need to correct smiles for modeling, only for interpretation.
+
+ Use added smiles data to rearrange list of saliency values.
+
+ """

  # TODO: Add regex pattern for residue (C-term mod + N-term mod)?
  # TODO: Add regex pattern for residue (C-term mod + N-term mod + mod)?
+
+ no_mod_pattern = r'([A-Z])'
+ side_chain_mod_pattern = r'([A-Z]\[[A-Za-z0-9]+\])'
+ n_term_mod_pattern = r'(\[[A-Za-z0-9]+\]-[A-Z])'
+ c_term_mod_pattern = r'([A-Z]-\[[A-Za-z0-9]+\])'
+ side_chain_and_n_term_mod_pattern = r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])'
+ side_chain_and_c_term_mod_pattern = r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])'
+
  residue_pattern: str = "|".join([
-     r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])',  # residue (N-term mod + mod)
-     r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])',  # residue (C-term mod + mod)
-     r'([A-Z]-\[[A-Za-z0-9]+\])',  # residue (C-term mod)
-     r'(\[[A-Za-z0-9]+\]-[A-Z])',  # residue (N-term mod)
-     r'([A-Z]\[[A-Za-z0-9]+\])',  # residue (mod)
-     r'([A-Z])',  # residue (no mod)
+     side_chain_and_n_term_mod_pattern,
+     side_chain_and_c_term_mod_pattern,
+     n_term_mod_pattern,
+     c_term_mod_pattern,
+     side_chain_mod_pattern,
+     no_mod_pattern
  ])

  default_residues: dict[str, str] = {
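Note on the hunk above: the alternation order matters, since regex matching takes the first alternative that succeeds, so the composite modified-residue patterns must precede the bare `([A-Z])` fallback. A minimal sketch of the intended tokenization, using Python's `re` module for illustration (the module itself tokenizes with `tensorflow_text`); the peptide string is a made-up example:

```python
import re

# Alternatives in the same order as the diff above: most specific first.
residue_pattern = "|".join([
    r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])',  # side-chain + N-term mod
    r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])',  # side-chain + C-term mod
    r'(\[[A-Za-z0-9]+\]-[A-Z])',                  # N-term mod
    r'([A-Z]-\[[A-Za-z0-9]+\])',                  # C-term mod
    r'([A-Z]\[[A-Za-z0-9]+\])',                   # side-chain mod
    r'([A-Z])',                                   # unmodified residue
])

# Hypothetical peptide: N-terminal acetylation plus one phosphoserine.
sequence = '[Acetyl]-PEPS[Phospho]TIDE'
tokens = [m.group(0) for m in re.finditer(residue_pattern, sequence)]
print(tokens)
# ['[Acetyl]-P', 'E', 'P', 'S[Phospho]', 'T', 'I', 'D', 'E']
```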
@@ -45,20 +77,26 @@ default_residues: dict[str, str] = {
      "Y": "N[C@@H](Cc1ccc(O)cc1)C(=O)O",
  }

- def register_residues(residues: dict[str, str]) -> None:
-     # TODO: Implement functions that check if residue has N- or C-terminal mod
-     # if C-terminal mod, no need to enforce concatenatable perm.
-     # if N-terminal mod, enforce only 'C(=O)O'
-     # if normal mod, enforce concatenateable perm ('N[C@@H]' and 'C(=O)O)).
-     for residue, smiles in residues.items():
-         if residue.startswith('P'):
-             smiles.startswith('N'), f'Incorrect SMILES permutation for {residue}.'
-         elif not residue.startswith('['):
-             smiles.startswith('N[C@@H]'), f'Incorrect SMILES permutation for {residue}.'
-         if len(residue) > 1 and not residue[1] == "-":
-             assert smiles.endswith('C(=O)O'), f'Incorrect SMILES permutation for {residue}.'
-         registered_residues[residue] = smiles
-         registered_residues[residue + '*'] = smiles.strip('O')
+ def has_c_terminal_mod(residue: str):
+     if re.search(c_term_mod_pattern, residue):
+         return True
+     return False
+
+ def has_n_terminal_mod(residue: str):
+     if re.search(n_term_mod_pattern, residue):
+         return True
+     return False
+
+ # def register_residues(residues: dict[str, str]) -> None:
+ #     for residue, smiles in residues.items():
+ #         if residue.startswith('P'):
+ #             smiles.startswith('N'), f'Incorrect SMILES permutation for {residue}.'
+ #         elif not residue.startswith('['):
+ #             smiles.startswith('N[C@@H]'), f'Incorrect SMILES permutation for {residue}.'
+ #         if len(residue) > 1 and not residue[1] == "-":
+ #             assert smiles.endswith('C(=O)O'), f'Incorrect SMILES permutation for {residue}.'
+ #         registered_residues[residue] = smiles
+ #         registered_residues[residue + '*'] = smiles.strip('O')


  class Peptide(chem.Mol):
@@ -79,6 +117,22 @@ class Peptide(chem.Mol):
          return super().from_encoding(peptide_smiles, **kwargs)


+ def permute_residue_smiles(smiles: str) -> str:
+     glycine = chem.Mol.from_encoding("NCC(=O)O")
+     mol = chem.Mol.from_encoding(smiles)
+     nitrogen_index = mol.GetSubstructMatch(glycine)[0]
+     permuted_smiles = Chem.MolToSmiles(
+         mol, rootedAtAtom=nitrogen_index
+     )
+     return permuted_smiles
+
+ def check_peptide_residue_smiles(smiles: list[str]) -> bool:
+     backbone = 'NCC(=O)' * (len(smiles) - 1) + 'NC'
+     backbone = chem.Mol.from_encoding(backbone)
+     mol = chem.Mol.from_encoding(''.join(smiles))
+     is_valid = mol.HasSubstructMatch(backbone)
+     return is_valid
+
  @keras.saving.register_keras_serializable(package='proteomics')
  class ResidueEmbedding(keras.layers.Layer):

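`permute_residue_smiles` rewrites each residue SMILES so the string starts at the backbone nitrogen, by matching a glycine substructure and re-rooting the SMILES there; `check_peptide_residue_smiles` then verifies that concatenated residue strings still contain the expected peptide backbone. A standalone sketch of the re-rooting step using plain RDKit (assuming `chem.Mol` wraps an RDKit `Mol`):

```python
from rdkit import Chem

def permute_residue_smiles(smiles: str) -> str:
    """Rewrite a residue SMILES so it begins at the backbone nitrogen."""
    glycine = Chem.MolFromSmiles('NCC(=O)O')  # minimal backbone query
    mol = Chem.MolFromSmiles(smiles)
    # First query atom of the glycine pattern is the nitrogen.
    nitrogen_index = mol.GetSubstructMatch(glycine)[0]
    return Chem.MolToSmiles(mol, rootedAtAtom=nitrogen_index)

# Alanine written carboxyl-first comes back amine-first:
print(permute_residue_smiles('OC(=O)C(C)N'))  # amine-first, e.g. 'NC(C)C(=O)O'
```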
@@ -92,42 +146,60 @@ class ResidueEmbedding(keras.layers.Layer):
          super().__init__(**kwargs)
          if residues is None:
              residues = {}
-         self._residue_dict = {**default_residues, **residues}
          self.embedder = embedder
          self.featurizer = featurizer
-         self.embedding_dim = self.embedder.output.shape[-1]
+         self.embedding_dim = int(self.embedder.output.shape[-1])
          self.ragged_split = SequenceSplitter(pad=False)
          self.split = SequenceSplitter(pad=True)
          self.use_cached_embeddings = tf.Variable(False)
+         self.residues = residues
          self.supports_masking = True

      @property
      def residues(self) -> dict[str, str]:
-         return self._residue_dict
+         return self._residues

      @residues.setter
      def residues(self, residues: dict[str, str]) -> None:
-         self._residue_dict = residues
-         num_residues = len(residues)
-         residue_keys = sorted(residues.keys())
-         oov_value = np.where(np.array(residue_keys) == "G")[0][0]
+
+         residues = {**default_residues, **residues}
+         self._residues = {}
+         for residue, smiles in residues.items():
+             permuted_smiles = permute_residue_smiles(smiles)
+             # Returned smiles should begin with the amino group.
+             # It seems that the returned smiles ends with carboxyl group,
+             # though we do another check just in case.
+             if not has_c_terminal_mod(residue):
+                 carboxyl_group = 'C(=O)O'
+                 if not permuted_smiles.endswith(carboxyl_group):
+                     raise ValueError(
+                         f'Unsupported permutation of {residue!r} smiles: {permuted_smiles!r}.'
+                     )
+             self._residues[residue] = permuted_smiles
+             self._residues[residue + '*'] = permuted_smiles.rstrip('O')
+
+         residue_keys = sorted(self._residues.keys())
+         residue_values = range(len(residue_keys))
+         residue_oov_value = np.where(np.array(residue_keys) == "G")[0][0]
+
          self.mapping = tf.lookup.StaticHashTable(
              tf.lookup.KeyValueTensorInitializer(
                  keys=residue_keys,
-                 values=range(num_residues)
+                 values=residue_values
              ),
-             default_value=oov_value,
+             default_value=residue_oov_value,
          )
+
          self.graph = tf.stack([
-             self.featurizer(residues[residue]) for residue in residue_keys
+             self.featurizer(self._residues[r]) for r in residue_keys
          ], axis=0)
-         self.cached_embeddings = tf.Variable(
-             initial_value=tf.zeros((num_residues, self.embedding_dim))
-         )
+
+         zeros = tf.zeros((residue_values[-1] + 1, self.embedding_dim))
+         self.cached_embeddings = tf.Variable(initial_value=zeros)
          _ = self.cache_and_get_embeddings()

      def build(self, input_shape) -> None:
-         self.residues = self._residue_dict
+         self.residues = self._residues
          super().build(input_shape)

      def call(self, sequences: tf.Tensor, training: bool = None) -> tf.Tensor:
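The rewritten setter now derives the residue vocabulary itself: defaults are merged with user overrides, every SMILES is re-rooted at the backbone nitrogen, the carboxyl ending is validated unless the residue carries a C-terminal modification, and each residue is registered twice, with the starred key holding the condensed (peptide-bonded) form with the terminal hydroxyl stripped. Unknown tokens fall back to glycine via the hash table's default value. A hedged sketch of the registration step alone (the alanine SMILES is assumed, following the same convention as the tyrosine entry shown in `default_residues`):

```python
# Illustration only: the '*' entry is the residue as it appears inside a
# peptide chain, i.e. with the terminal -OH stripped from 'C(=O)O'.
residues = {'G': 'NCC(=O)O', 'A': 'N[C@@H](C)C(=O)O'}
registered = {}
for residue, smiles in residues.items():
    registered[residue] = smiles                    # free residue
    registered[residue + '*'] = smiles.rstrip('O')  # condensed residue
print(registered['G'], '|', registered['G*'])  # NCC(=O)O | NCC(=O)
```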
@@ -163,16 +235,24 @@ class ResidueEmbedding(keras.layers.Layer):
      def get_config(self) -> dict:
          config = super().get_config()
          config.update({
-             'featurizer': keras.saving.serialize_keras_object(self.featurizer),
-             'embedder': keras.saving.serialize_keras_object(self.embedder),
-             'residues': self._residue_dict,
+             'featurizer': keras.saving.serialize_keras_object(
+                 self.featurizer
+             ),
+             'embedder': keras.saving.serialize_keras_object(
+                 self.embedder
+             ),
+             'residues': self._residues,
          })
          return config

      @classmethod
      def from_config(cls, config: dict) -> 'ResidueEmbedding':
-         config['featurizer'] = keras.saving.deserialize_keras_object(config['featurizer'])
-         config['embedder'] = keras.saving.deserialize_keras_object(config['embedder'])
+         config['featurizer'] = keras.saving.deserialize_keras_object(
+             config['featurizer']
+         )
+         config['embedder'] = keras.saving.deserialize_keras_object(
+             config['embedder']
+         )
          return super().from_config(config)


@@ -190,5 +270,5 @@ class SequenceSplitter(keras.layers.Layer):
          return inputs


- registered_residues: dict[str, str] = {}
- register_residues(default_residues)
+ # registered_residues: dict[str, str] = {}
+ # register_residues(default_residues)
{molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft/layers.py

@@ -380,11 +380,14 @@ class GraphConv(GraphLayer):
          self._update_final_dense = self.get_dense(self.units)

          if not self._normalize:
-             self._normalization = keras.layers.Identity()
+             self._message_norm = keras.layers.Identity()
+             self._update_norm = keras.layers.Identity()
          elif str(self._normalize).lower().startswith('layer'):
-             self._normalization = keras.layers.LayerNormalization()
+             self._message_norm = keras.layers.LayerNormalization()
+             self._update_norm = keras.layers.LayerNormalization()
          else:
-             self._normalization = keras.layers.BatchNormalization()
+             self._message_norm = keras.layers.BatchNormalization()
+             self._update_norm = keras.layers.BatchNormalization()

      def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
          """Forward pass.
@@ -430,7 +433,7 @@ class GraphConv(GraphLayer):
          elif add_aggregate:
              update = update.update({'node': {'aggregate': None}})

-         if not self._skip_connect and not self._normalize:
+         if not self._skip_connect:
              return update

          feature = update.node['feature']
@@ -438,8 +441,6 @@ class GraphConv(GraphLayer):
          if self._skip_connect:
              feature += residual

-         feature = self._normalization(feature)
-
          return update.update({'node': {'feature': feature}})

      def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
@@ -480,6 +481,7 @@ class GraphConv(GraphLayer):
              axis=-1
          )
          message = self._message_intermediate_dense(message)
+         message = self._message_norm(message)
          message = self._message_intermediate_activation(message)
          message = self._message_final_dense(message)
          return tensor.update({'edge': {'message': message}})
@@ -519,6 +521,7 @@ class GraphConv(GraphLayer):
          """
          aggregate = tensor.node['aggregate']
          node_feature = self._update_intermediate_dense(aggregate)
+         node_feature = self._update_norm(node_feature)
          node_feature = self._update_intermediate_activation(node_feature)
          node_feature = self._update_final_dense(node_feature)
          return tensor.update(
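With these two insertions, normalization moves inside the message and update blocks, between the intermediate dense layer and its activation, and the single post-skip-connection `self._normalization` is gone. A minimal Keras sketch of the resulting sub-block (names are illustrative, not molcraft's API):

```python
import keras

def dense_norm_block(units: int, intermediate_units: int, normalize='layer'):
    """Dense -> Norm -> Activation -> Dense, mirroring the new message/update flow."""
    if not normalize:
        norm = keras.layers.Identity()
    elif str(normalize).lower().startswith('layer'):
        norm = keras.layers.LayerNormalization()
    else:
        norm = keras.layers.BatchNormalization()
    return keras.Sequential([
        keras.layers.Dense(intermediate_units),
        norm,
        keras.layers.Activation('relu'),
        keras.layers.Dense(units),
    ])
```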
@@ -1312,13 +1315,19 @@ class NodeEmbedding(GraphLayer):

      def __init__(
          self,
-         dim: int = None,
+         dim: int | None = None,
+         intermediate_dim: int | None = None,
+         intermediate_activation: str | keras.layers.Activation | None = 'relu',
          normalize: bool = False,
          embed_context: bool = False,
          **kwargs
      ) -> None:
          super().__init__(**kwargs)
          self.dim = dim
+         self._intermediate_dim = intermediate_dim
+         self._intermediate_activation = keras.activations.get(
+             intermediate_activation
+         )
          self._normalize = normalize
          self._embed_context = embed_context

@@ -1326,30 +1335,38 @@ class NodeEmbedding(GraphLayer):
          feature_dim = spec.node['feature'].shape[-1]
          if not self.dim:
              self.dim = feature_dim
-         self._node_dense = self.get_dense(self.dim)
-
+         if not self._intermediate_dim:
+             self._intermediate_dim = self.dim * 2
+         self._node_dense = self.get_dense(
+             self._intermediate_dim, activation=self._intermediate_activation
+         )
          self._has_super = 'super' in spec.node
          has_context_feature = 'feature' in spec.context
          if not has_context_feature:
              self._embed_context = False
          if self._has_super and not self._embed_context:
-             self._super_feature = self.get_weight(shape=[self.dim], name='super_node_feature')
+             self._super_feature = self.get_weight(
+                 shape=[self._intermediate_dim], name='super_node_feature'
+             )
          if self._embed_context:
-             self._context_dense = self.get_dense(self.dim)
-
+             self._context_dense = self.get_dense(
+                 self._intermediate_dim, activation=self._intermediate_activation
+             )
          if not self._normalize:
              self._norm = keras.layers.Identity()
          elif str(self._normalize).lower().startswith('layer'):
              self._norm = keras.layers.LayerNormalization()
          else:
              self._norm = keras.layers.BatchNormalization()
+         self._dense = self.get_dense(self.dim)

      def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
          feature = self._node_dense(tensor.node['feature'])

          if self._has_super and not self._embed_context:
              super_mask = keras.ops.expand_dims(tensor.node['super'], 1)
-             feature = keras.ops.where(super_mask, self._super_feature, feature)
+             super_feature = self._intermediate_activation(self._super_feature)
+             feature = keras.ops.where(super_mask, super_feature, feature)

          if self._embed_context:
              context_feature = self._context_dense(tensor.context['feature'])
@@ -1357,6 +1374,7 @@ class NodeEmbedding(GraphLayer):
              tensor = tensor.update({'context': {'feature': None}})

          feature = self._norm(feature)
+         feature = self._dense(feature)

          return tensor.update({'node': {'feature': feature}})

@@ -1364,6 +1382,10 @@ class NodeEmbedding(GraphLayer):
          config = super().get_config()
          config.update({
              'dim': self.dim,
+             'intermediate_dim': self._intermediate_dim,
+             'intermediate_activation': keras.activations.serialize(
+                 self._intermediate_activation
+             ),
              'normalize': self._normalize,
              'embed_context': self._embed_context,
          })
@@ -1381,50 +1403,67 @@ class EdgeEmbedding(GraphLayer):
      def __init__(
          self,
          dim: int = None,
+         intermediate_dim: int | None = None,
+         intermediate_activation: str | keras.layers.Activation | None = 'relu',
          normalize: bool = False,
          **kwargs
      ) -> None:
          super().__init__(**kwargs)
          self.dim = dim
+         self._intermediate_dim = intermediate_dim
+         self._intermediate_activation = keras.activations.get(
+             intermediate_activation
+         )
          self._normalize = normalize

      def build(self, spec: tensors.GraphTensor.Spec) -> None:
          feature_dim = spec.edge['feature'].shape[-1]
          if not self.dim:
              self.dim = feature_dim
-         self._edge_dense = self.get_dense(self.dim)
-
-         self._self_loop_feature = self.get_weight(shape=[self.dim], name='self_loop_edge_feature')
-
+         if not self._intermediate_dim:
+             self._intermediate_dim = self.dim * 2
+         self._edge_dense = self.get_dense(
+             self._intermediate_dim, activation=self._intermediate_activation
+         )
+         self._self_loop_feature = self.get_weight(
+             shape=[self._intermediate_dim], name='self_loop_edge_feature'
+         )
          self._has_super = 'super' in spec.edge
          if self._has_super:
-             self._super_feature = self.get_weight(shape=[self.dim], name='super_edge_feature')
-
+             self._super_feature = self.get_weight(
+                 shape=[self._intermediate_dim], name='super_edge_feature'
+             )
          if not self._normalize:
              self._norm = keras.layers.Identity()
          elif str(self._normalize).lower().startswith('layer'):
              self._norm = keras.layers.LayerNormalization()
          else:
              self._norm = keras.layers.BatchNormalization()
+         self._dense = self.get_dense(self.dim)

      def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
          feature = self._edge_dense(tensor.edge['feature'])

          if self._has_super:
              super_mask = keras.ops.expand_dims(tensor.edge['super'], 1)
-             feature = keras.ops.where(super_mask, self._super_feature, feature)
+             super_feature = self._intermediate_activation(self._super_feature)
+             feature = keras.ops.where(super_mask, super_feature, feature)

          self_loop_mask = keras.ops.expand_dims(tensor.edge['source'] == tensor.edge['target'], 1)
-         feature = keras.ops.where(self_loop_mask, self._self_loop_feature, feature)
-
+         self_loop_feature = self._intermediate_activation(self._self_loop_feature)
+         feature = keras.ops.where(self_loop_mask, self_loop_feature, feature)
          feature = self._norm(feature)
-
+         feature = self._dense(feature)
          return tensor.update({'edge': {'feature': feature}})

      def get_config(self) -> dict:
          config = super().get_config()
          config.update({
              'dim': self.dim,
+             'intermediate_dim': self._intermediate_dim,
+             'intermediate_activation': keras.activations.serialize(
+                 self._intermediate_activation
+             ),
              'normalize': self._normalize,
          })
          return config
@@ -1441,42 +1480,60 @@ class AddContext(GraphLayer):
      def __init__(
          self,
          field: str = 'feature',
+         intermediate_dim: int | None = None,
+         intermediate_activation: str | keras.layers.Activation | None = 'relu',
          drop: bool = False,
          normalize: bool = False,
          **kwargs
      ) -> None:
          super().__init__(**kwargs)
-         self.field = field
-         self.drop = drop
+         self._field = field
+         self._drop = drop
+         self._intermediate_dim = intermediate_dim
+         self._intermediate_activation = keras.activations.get(
+             intermediate_activation
+         )
          self._normalize = normalize

      def build(self, spec: tensors.GraphTensor.Spec) -> None:
          feature_dim = spec.node['feature'].shape[-1]
-         self._context_dense = self.get_dense(feature_dim)
+         if self._intermediate_dim is None:
+             self._intermediate_dim = feature_dim * 2
+         self._intermediate_dense = self.get_dense(
+             self._intermediate_dim, activation=self._intermediate_activation
+         )
+         self._final_dense = self.get_dense(feature_dim)
          if not self._normalize:
-             self._norm = keras.layers.Identity()
+             self._intermediate_norm = keras.layers.Identity()
          elif str(self._normalize).lower().startswith('layer'):
-             self._norm = keras.layers.LayerNormalization()
+             self._intermediate_norm = keras.layers.LayerNormalization()
          else:
-             self._norm = keras.layers.BatchNormalization()
+             self._intermediate_norm = keras.layers.BatchNormalization()

      def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-         context = tensor.context[self.field]
-         context = self._context_dense(context)
-         context = self._norm(context)
+         context = tensor.context[self._field]
+         context = self._intermediate_dense(context)
+         context = self._intermediate_norm(context)
+         context = self._final_dense(context)
          node_feature = ops.scatter_add(
              tensor.node['feature'], tensor.node['super'], context
          )
          data = {'node': {'feature': node_feature}}
-         if self.drop:
-             data['context'] = {self.field: None}
+         if self._drop:
+             data['context'] = {self._field: None}
          return tensor.update(data)

      def get_config(self) -> dict:
          config = super().get_config()
-         config['field'] = self.field
-         config['drop'] = self.drop
-         config['normalize'] = self._normalize
+         config.update({
+             'field': self._field,
+             'intermediate_dim': self._intermediate_dim,
+             'intermediate_activation': keras.activations.serialize(
+                 self._intermediate_activation
+             ),
+             'drop': self._drop,
+             'normalize': self._normalize,
+         })
          return config


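AddContext still injects the context vector into each graph's super node via `ops.scatter_add`; only the transform in front of the scatter changed (intermediate dense, normalization, final dense). A rough NumPy illustration of that scatter, assuming `node['super']` is a boolean mask with one super node per graph and that `ops.scatter_add` adds the i-th context row to the i-th masked node:

```python
import numpy as np

node_feature = np.zeros((5, 4), dtype=np.float32)        # 5 nodes, 2 graphs
super_mask = np.array([False, False, True, False, True])  # one super node each
context = np.ones((2, 4), dtype=np.float32)  # one transformed context row per graph

# Assumed equivalent of ops.scatter_add(node_feature, super_mask, context):
node_feature[super_mask] += context
```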
@@ -1738,5 +1795,3 @@ def _spec_from_inputs(inputs):
          return spec
      return tensors.GraphTensor.Spec(**nested_specs)

-
- GraphTransformer = GTConv
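Since the `GraphTransformer = GTConv` alias is removed, downstream imports of the old name will break. A hedged migration sketch (assuming `GTConv` remains importable from `molcraft.layers`; the constructor argument is illustrative):

```python
# Before (0.1.0a17):
# from molcraft.layers import GraphTransformer

# After (0.1.0a18): the alias is gone, so import GTConv directly.
from molcraft.layers import GTConv as GraphTransformer

layer = GraphTransformer(units=128)  # 'units' assumed, mirroring GraphConv
```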
{molcraft-0.1.0a17 → molcraft-0.1.0a18}/molcraft.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: molcraft
- Version: 0.1.0a17
+ Version: 0.1.0a18
  Summary: Graph Neural Networks for Molecular Machine Learning
  Author-email: Alexander Kensert <alexander.kensert@gmail.com>
  License: MIT License