molcraft 0.1.0a17.tar.gz → 0.1.0a19.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (32)
  1. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/PKG-INFO +1 -1
  2. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/__init__.py +4 -2
  3. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/applications/proteomics.py +111 -41
  4. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/featurizers.py +1 -0
  5. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/layers.py +99 -44
  6. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/tensors.py +16 -10
  7. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft.egg-info/PKG-INFO +1 -1
  8. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/LICENSE +0 -0
  9. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/README.md +0 -0
  10. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/applications/__init__.py +0 -0
  11. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/applications/chromatography.py +0 -0
  12. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/callbacks.py +0 -0
  13. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/chem.py +0 -0
  14. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/datasets.py +0 -0
  15. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/descriptors.py +0 -0
  16. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/features.py +0 -0
  17. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/losses.py +0 -0
  18. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/models.py +0 -0
  19. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/ops.py +0 -0
  20. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/records.py +0 -0
  21. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft.egg-info/SOURCES.txt +0 -0
  22. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft.egg-info/dependency_links.txt +0 -0
  23. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft.egg-info/requires.txt +0 -0
  24. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft.egg-info/top_level.txt +0 -0
  25. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/pyproject.toml +0 -0
  26. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/setup.cfg +0 -0
  27. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/tests/test_chem.py +0 -0
  28. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/tests/test_featurizers.py +0 -0
  29. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/tests/test_layers.py +0 -0
  30. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/tests/test_losses.py +0 -0
  31. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/tests/test_models.py +0 -0
  32. {molcraft-0.1.0a17 → molcraft-0.1.0a19}/tests/test_tensors.py +0 -0
{molcraft-0.1.0a17 → molcraft-0.1.0a19}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: molcraft
- Version: 0.1.0a17
+ Version: 0.1.0a19
  Summary: Graph Neural Networks for Molecular Machine Learning
  Author-email: Alexander Kensert <alexander.kensert@gmail.com>
  License: MIT License
{molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/__init__.py

@@ -1,4 +1,4 @@
- __version__ = '0.1.0a17'
+ __version__ = '0.1.0a19'

  import os
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

@@ -14,4 +14,6 @@ from molcraft import records
  from molcraft import tensors
  from molcraft import callbacks
  from molcraft import datasets
- from molcraft import losses
+ from molcraft import losses
+
+ from molcraft.applications import proteomics
{molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/applications/proteomics.py

@@ -3,6 +3,7 @@ import keras
  import numpy as np
  import tensorflow as tf
  import tensorflow_text as tf_text
+ from rdkit import Chem

  from molcraft import featurizers
  from molcraft import tensors

@@ -10,16 +11,47 @@ from molcraft import layers
  from molcraft import models
  from molcraft import chem

+ """
+
+
+
+
+
+
+ No need to correct smiles for modeling, only for interpretation.
+
+ Use added smiles data to rearrange list of saliency values.
+
+
+
+
+
+
+
+
+
+
+
+
+ """

  # TODO: Add regex pattern for residue (C-term mod + N-term mod)?
  # TODO: Add regex pattern for residue (C-term mod + N-term mod + mod)?
+
+ no_mod_pattern = r'([A-Z])'
+ side_chain_mod_pattern = r'([A-Z]\[[A-Za-z0-9]+\])'
+ n_term_mod_pattern = r'(\[[A-Za-z0-9]+\]-[A-Z])'
+ c_term_mod_pattern = r'([A-Z]-\[[A-Za-z0-9]+\])'
+ side_chain_and_n_term_mod_pattern = r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])'
+ side_chain_and_c_term_mod_pattern = r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])'
+
  residue_pattern: str = "|".join([
-     r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])', # residue (N-term mod + mod)
-     r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])', # residue (C-term mod + mod)
-     r'([A-Z]-\[[A-Za-z0-9]+\])', # residue (C-term mod)
-     r'(\[[A-Za-z0-9]+\]-[A-Z])', # residue (N-term mod)
-     r'([A-Z]\[[A-Za-z0-9]+\])', # residue (mod)
-     r'([A-Z])', # residue (no mod)
+     side_chain_and_n_term_mod_pattern,
+     side_chain_and_c_term_mod_pattern,
+     n_term_mod_pattern,
+     c_term_mod_pattern,
+     side_chain_mod_pattern,
+     no_mod_pattern
  ])

  default_residues: dict[str, str] = {
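
For illustration, a minimal sketch (not from the package) of how the alternation in `residue_pattern` tokenizes a modified peptide string; the most specific alternatives come first so a modified residue is not consumed piecemeal. The sequence and modification names below are made up:

    import re

    # Same alternation order as the release: most specific first, bare residue last.
    residue_pattern = "|".join([
        r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])',  # N-term mod + side-chain mod
        r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])',  # side-chain mod + C-term mod
        r'(\[[A-Za-z0-9]+\]-[A-Z])',                  # N-term mod
        r'([A-Z]-\[[A-Za-z0-9]+\])',                  # C-term mod
        r'([A-Z]\[[A-Za-z0-9]+\])',                   # side-chain mod
        r'([A-Z])',                                   # no mod
    ])

    tokens = [m.group(0) for m in re.finditer(residue_pattern, '[Acetyl]-S[Phospho]GK')]
    print(tokens)  # ['[Acetyl]-S[Phospho]', 'G', 'K']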
@@ -45,20 +77,26 @@ default_residues: dict[str, str] = {
      "Y": "N[C@@H](Cc1ccc(O)cc1)C(=O)O",
  }

- def register_residues(residues: dict[str, str]) -> None:
-     # TODO: Implement functions that check if residue has N- or C-terminal mod
-     # if C-terminal mod, no need to enforce concatenatable perm.
-     # if N-terminal mod, enforce only 'C(=O)O'
-     # if normal mod, enforce concatenateable perm ('N[C@@H]' and 'C(=O)O)).
-     for residue, smiles in residues.items():
-         if residue.startswith('P'):
-             smiles.startswith('N'), f'Incorrect SMILES permutation for {residue}.'
-         elif not residue.startswith('['):
-             smiles.startswith('N[C@@H]'), f'Incorrect SMILES permutation for {residue}.'
-         if len(residue) > 1 and not residue[1] == "-":
-             assert smiles.endswith('C(=O)O'), f'Incorrect SMILES permutation for {residue}.'
-         registered_residues[residue] = smiles
-         registered_residues[residue + '*'] = smiles.strip('O')
+ def has_c_terminal_mod(residue: str):
+     if re.search(c_term_mod_pattern, residue):
+         return True
+     return False
+
+ def has_n_terminal_mod(residue: str):
+     if re.search(n_term_mod_pattern, residue):
+         return True
+     return False
+
+ # def register_residues(residues: dict[str, str]) -> None:
+ #     for residue, smiles in residues.items():
+ #         if residue.startswith('P'):
+ #             smiles.startswith('N'), f'Incorrect SMILES permutation for {residue}.'
+ #         elif not residue.startswith('['):
+ #             smiles.startswith('N[C@@H]'), f'Incorrect SMILES permutation for {residue}.'
+ #         if len(residue) > 1 and not residue[1] == "-":
+ #             assert smiles.endswith('C(=O)O'), f'Incorrect SMILES permutation for {residue}.'
+ #         registered_residues[residue] = smiles
+ #         registered_residues[residue + '*'] = smiles.strip('O')


  class Peptide(chem.Mol):
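
The new helpers just probe a residue token against the terminal-modification patterns (`re` is assumed to be imported elsewhere in the module; the import falls outside this hunk's context). Hypothetical inputs:

    has_n_terminal_mod('[Acetyl]-S')   # True  - matches n_term_mod_pattern
    has_n_terminal_mod('S[Phospho]')   # False - side-chain mod only
    has_c_terminal_mod('K-[Amide]')    # True  - matches c_term_mod_pattern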
@@ -79,6 +117,22 @@ class Peptide(chem.Mol):
          return super().from_encoding(peptide_smiles, **kwargs)


+ def permute_residue_smiles(smiles: str) -> str:
+     glycine = chem.Mol.from_encoding("NCC(=O)O")
+     mol = chem.Mol.from_encoding(smiles)
+     nitrogen_index = mol.GetSubstructMatch(glycine)[0]
+     permuted_smiles = Chem.MolToSmiles(
+         mol, rootedAtAtom=nitrogen_index
+     )
+     return permuted_smiles
+
+ def check_peptide_residue_smiles(smiles: list[str]) -> bool:
+     backbone = 'NCC(=O)' * (len(smiles) - 1) + 'NC'
+     backbone = chem.Mol.from_encoding(backbone)
+     mol = chem.Mol.from_encoding(''.join(smiles))
+     is_valid = mol.HasSubstructMatch(backbone)
+     return is_valid
+
  @keras.saving.register_keras_serializable(package='proteomics')
  class ResidueEmbedding(keras.layers.Layer):

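A hypothetical usage sketch of the helpers added above: `check_peptide_residue_smiles` builds the expected repeating `NCC(=O)` backbone and verifies it occurs as a substructure of the concatenated residue SMILES.

    # Gly (terminal 'O' stripped so it concatenates) followed by Ala:
    residue_smiles = ['NCC(=O)', 'N[C@@H](C)C(=O)O']
    check_peptide_residue_smiles(residue_smiles)
    # True: 'NCC(=O)N[C@@H](C)C(=O)O' contains the backbone 'NCC(=O)NC'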
@@ -92,42 +146,50 @@ class ResidueEmbedding(keras.layers.Layer):
          super().__init__(**kwargs)
          if residues is None:
              residues = {}
-         self._residue_dict = {**default_residues, **residues}
          self.embedder = embedder
          self.featurizer = featurizer
-         self.embedding_dim = self.embedder.output.shape[-1]
+         self.embedding_dim = int(self.embedder.output.shape[-1])
          self.ragged_split = SequenceSplitter(pad=False)
          self.split = SequenceSplitter(pad=True)
          self.use_cached_embeddings = tf.Variable(False)
+         self.residues = residues
          self.supports_masking = True

      @property
      def residues(self) -> dict[str, str]:
-         return self._residue_dict
+         return self._residues

      @residues.setter
      def residues(self, residues: dict[str, str]) -> None:
-         self._residue_dict = residues
-         num_residues = len(residues)
-         residue_keys = sorted(residues.keys())
-         oov_value = np.where(np.array(residue_keys) == "G")[0][0]
+
+         residues = {**default_residues, **residues}
+         self._residues = {}
+         for residue, smiles in residues.items():
+             self._residues[residue] = smiles
+             self._residues[residue + '*'] = smiles.rstrip('O')
+
+         residue_keys = sorted(self._residues.keys())
+         residue_values = range(len(residue_keys))
+         residue_oov_value = np.where(np.array(residue_keys) == "G")[0][0]
+
          self.mapping = tf.lookup.StaticHashTable(
              tf.lookup.KeyValueTensorInitializer(
                  keys=residue_keys,
-                 values=range(num_residues)
+                 values=residue_values
              ),
-             default_value=oov_value,
+             default_value=residue_oov_value,
          )
+
          self.graph = tf.stack([
-             self.featurizer(residues[residue]) for residue in residue_keys
+             self.featurizer(self._residues[r]) for r in residue_keys
          ], axis=0)
-         self.cached_embeddings = tf.Variable(
-             initial_value=tf.zeros((num_residues, self.embedding_dim))
-         )
+
+         zeros = tf.zeros((residue_values[-1] + 1, self.embedding_dim))
+         self.cached_embeddings = tf.Variable(initial_value=zeros)
          _ = self.cache_and_get_embeddings()

      def build(self, input_shape) -> None:
-         self.residues = self._residue_dict
+         self.residues = self._residues
          super().build(input_shape)

      def call(self, sequences: tf.Tensor, training: bool = None) -> tf.Tensor:
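
In effect, the setter now owns what `register_residues` used to do: user residues are merged over `default_residues`, and every SMILES is stored under two keys, with the `'*'`-suffixed key holding the `rstrip('O')` variant used when the residue is concatenated mid-chain. A sketch with a hypothetical embedder, featurizer, and modification:

    layer = ResidueEmbedding(
        embedder=embedder, featurizer=featurizer,
        residues={'C[Carbamidomethyl]': 'N[C@@H](CSCC(=O)N)C(=O)O'},
    )
    layer.residues['C[Carbamidomethyl]']   # 'N[C@@H](CSCC(=O)N)C(=O)O'
    layer.residues['C[Carbamidomethyl]*']  # 'N[C@@H](CSCC(=O)N)C(=O)'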
@@ -163,16 +225,24 @@ class ResidueEmbedding(keras.layers.Layer):
      def get_config(self) -> dict:
          config = super().get_config()
          config.update({
-             'featurizer': keras.saving.serialize_keras_object(self.featurizer),
-             'embedder': keras.saving.serialize_keras_object(self.embedder),
-             'residues': self._residue_dict,
+             'featurizer': keras.saving.serialize_keras_object(
+                 self.featurizer
+             ),
+             'embedder': keras.saving.serialize_keras_object(
+                 self.embedder
+             ),
+             'residues': self._residues,
          })
          return config

      @classmethod
      def from_config(cls, config: dict) -> 'ResidueEmbedding':
-         config['featurizer'] = keras.saving.deserialize_keras_object(config['featurizer'])
-         config['embedder'] = keras.saving.deserialize_keras_object(config['embedder'])
+         config['featurizer'] = keras.saving.deserialize_keras_object(
+             config['featurizer']
+         )
+         config['embedder'] = keras.saving.deserialize_keras_object(
+             config['embedder']
+         )
          return super().from_config(config)

@@ -190,5 +260,5 @@ class SequenceSplitter(keras.layers.Layer):
          return inputs


- registered_residues: dict[str, str] = {}
- register_residues(default_residues)
+ # registered_residues: dict[str, str] = {}
+ # register_residues(default_residues)
{molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/featurizers.py

@@ -413,6 +413,7 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):

      def get_config(self):
          config = super().get_config()
+         config.pop('bond_features', None)
          config['radius'] = self._radius
          config['pair_features'] = keras.saving.serialize_keras_object(
              self._pair_features
{molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/layers.py

@@ -279,7 +279,7 @@ class GraphConv(GraphLayer):
          use_bias (bool):
              Whether bias should be used in the dense layers. Default to `True`.
          normalize (bool, str):
-             Whether normalization should be applied to the final output. Default to `False`.
+             Whether a normalization layer should be obtain by `get_norm()`. Default to `False`.
          skip_connect (bool):
              Whether node feature input should be added to the node feature output. Default to `True`.
          kernel_initializer (keras.initializers.Initializer, str):
@@ -366,6 +366,7 @@ class GraphConv(GraphLayer):
          has_overridden_message = self.__class__.message != GraphConv.message
          if not has_overridden_message:
              self._message_intermediate_dense = self.get_dense(self.units)
+             self._message_norm = self.get_norm()
              self._message_intermediate_activation = self.activation
              self._message_final_dense = self.get_dense(self.units)

@@ -376,16 +377,10 @@ class GraphConv(GraphLayer):
          has_overridden_update = self.__class__.update != GraphConv.update
          if not has_overridden_update:
              self._update_intermediate_dense = self.get_dense(self.units)
+             self._update_norm = self.get_norm()
              self._update_intermediate_activation = self.activation
              self._update_final_dense = self.get_dense(self.units)

-         if not self._normalize:
-             self._normalization = keras.layers.Identity()
-         elif str(self._normalize).lower().startswith('layer'):
-             self._normalization = keras.layers.LayerNormalization()
-         else:
-             self._normalization = keras.layers.BatchNormalization()
-
      def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
          """Forward pass.

@@ -430,7 +425,7 @@ class GraphConv(GraphLayer):
          elif add_aggregate:
              update = update.update({'node': {'aggregate': None}})

-         if not self._skip_connect and not self._normalize:
+         if not self._skip_connect:
              return update

          feature = update.node['feature']

@@ -438,8 +433,6 @@ class GraphConv(GraphLayer):
          if self._skip_connect:
              feature += residual

-         feature = self._normalization(feature)
-
          return update.update({'node': {'feature': feature}})

      def message(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
@@ -480,6 +473,7 @@ class GraphConv(GraphLayer):
              axis=-1
          )
          message = self._message_intermediate_dense(message)
+         message = self._message_norm(message)
          message = self._message_intermediate_activation(message)
          message = self._message_final_dense(message)
          return tensor.update({'edge': {'message': message}})
@@ -519,6 +513,7 @@ class GraphConv(GraphLayer):
          """
          aggregate = tensor.node['aggregate']
          node_feature = self._update_intermediate_dense(aggregate)
+         node_feature = self._update_norm(node_feature)
          node_feature = self._update_intermediate_activation(node_feature)
          node_feature = self._update_final_dense(node_feature)
          return tensor.update(
@@ -530,6 +525,14 @@ class GraphConv(GraphLayer):
              }
          )

+     def get_norm(self, **kwargs):
+         if not self._normalize:
+             return keras.layers.Identity()
+         elif str(self._normalize).lower().startswith('layer'):
+             return keras.layers.LayerNormalization(**kwargs)
+         else:
+             return keras.layers.BatchNormalization(**kwargs)
+
      def get_config(self) -> dict:
          config = super().get_config()
          config.update({
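
The net effect of this refactor: the single normalization applied after the skip connection is replaced by per-sub-block layers obtained from `get_norm()` inside `message` and `update`. The mapping from the `normalize` argument is unchanged (constructor arguments below are illustrative):

    GraphConv(units=128, normalize=False)         # get_norm() -> keras.layers.Identity()
    GraphConv(units=128, normalize='layer_norm')  # get_norm() -> keras.layers.LayerNormalization()
    GraphConv(units=128, normalize=True)          # get_norm() -> keras.layers.BatchNormalization()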
@@ -1312,13 +1315,19 @@ class NodeEmbedding(GraphLayer):

      def __init__(
          self,
-         dim: int = None,
+         dim: int | None = None,
+         intermediate_dim: int | None = None,
+         intermediate_activation: str | keras.layers.Activation | None = 'relu',
          normalize: bool = False,
          embed_context: bool = False,
          **kwargs
      ) -> None:
          super().__init__(**kwargs)
          self.dim = dim
+         self._intermediate_dim = intermediate_dim
+         self._intermediate_activation = keras.activations.get(
+             intermediate_activation
+         )
          self._normalize = normalize
          self._embed_context = embed_context

@@ -1326,30 +1335,38 @@ class NodeEmbedding(GraphLayer):
          feature_dim = spec.node['feature'].shape[-1]
          if not self.dim:
              self.dim = feature_dim
-         self._node_dense = self.get_dense(self.dim)
-
+         if not self._intermediate_dim:
+             self._intermediate_dim = self.dim * 2
+         self._node_dense = self.get_dense(
+             self._intermediate_dim, activation=self._intermediate_activation
+         )
          self._has_super = 'super' in spec.node
          has_context_feature = 'feature' in spec.context
          if not has_context_feature:
              self._embed_context = False
          if self._has_super and not self._embed_context:
-             self._super_feature = self.get_weight(shape=[self.dim], name='super_node_feature')
+             self._super_feature = self.get_weight(
+                 shape=[self._intermediate_dim], name='super_node_feature'
+             )
          if self._embed_context:
-             self._context_dense = self.get_dense(self.dim)
-
+             self._context_dense = self.get_dense(
+                 self._intermediate_dim, activation=self._intermediate_activation
+             )
          if not self._normalize:
              self._norm = keras.layers.Identity()
          elif str(self._normalize).lower().startswith('layer'):
              self._norm = keras.layers.LayerNormalization()
          else:
              self._norm = keras.layers.BatchNormalization()
+         self._dense = self.get_dense(self.dim)

      def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
          feature = self._node_dense(tensor.node['feature'])

          if self._has_super and not self._embed_context:
              super_mask = keras.ops.expand_dims(tensor.node['super'], 1)
-             feature = keras.ops.where(super_mask, self._super_feature, feature)
+             super_feature = self._intermediate_activation(self._super_feature)
+             feature = keras.ops.where(super_mask, super_feature, feature)

          if self._embed_context:
              context_feature = self._context_dense(tensor.context['feature'])

@@ -1357,6 +1374,7 @@ class NodeEmbedding(GraphLayer):
              tensor = tensor.update({'context': {'feature': None}})

          feature = self._norm(feature)
+         feature = self._dense(feature)

          return tensor.update({'node': {'feature': feature}})

@@ -1364,6 +1382,10 @@ class NodeEmbedding(GraphLayer):
          config = super().get_config()
          config.update({
              'dim': self.dim,
+             'intermediate_dim': self._intermediate_dim,
+             'intermediate_activation': keras.activations.serialize(
+                 self._intermediate_activation
+             ),
              'normalize': self._normalize,
              'embed_context': self._embed_context,
          })
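
NodeEmbedding thus moves from a single projection to a wide-then-project pattern. A standalone sketch of the equivalent data flow (an assumed simplification, not the package's code); `EdgeEmbedding` in the next hunk follows the same pattern, additionally passing its learned self-loop and super-edge vectors through the intermediate activation:

    import keras

    def node_embedding_flow(feature, dim, intermediate_dim=None):
        # Wide intermediate dense with activation, then norm, then project to `dim`.
        intermediate_dim = intermediate_dim or dim * 2
        x = keras.layers.Dense(intermediate_dim, activation='relu')(feature)
        x = keras.layers.LayerNormalization()(x)  # Identity/BatchNorm per `normalize`
        return keras.layers.Dense(dim)(x)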
@@ -1381,50 +1403,67 @@ class EdgeEmbedding(GraphLayer):
      def __init__(
          self,
          dim: int = None,
+         intermediate_dim: int | None = None,
+         intermediate_activation: str | keras.layers.Activation | None = 'relu',
          normalize: bool = False,
          **kwargs
      ) -> None:
          super().__init__(**kwargs)
          self.dim = dim
+         self._intermediate_dim = intermediate_dim
+         self._intermediate_activation = keras.activations.get(
+             intermediate_activation
+         )
          self._normalize = normalize

      def build(self, spec: tensors.GraphTensor.Spec) -> None:
          feature_dim = spec.edge['feature'].shape[-1]
          if not self.dim:
              self.dim = feature_dim
-         self._edge_dense = self.get_dense(self.dim)
-
-         self._self_loop_feature = self.get_weight(shape=[self.dim], name='self_loop_edge_feature')
-
+         if not self._intermediate_dim:
+             self._intermediate_dim = self.dim * 2
+         self._edge_dense = self.get_dense(
+             self._intermediate_dim, activation=self._intermediate_activation
+         )
+         self._self_loop_feature = self.get_weight(
+             shape=[self._intermediate_dim], name='self_loop_edge_feature'
+         )
          self._has_super = 'super' in spec.edge
          if self._has_super:
-             self._super_feature = self.get_weight(shape=[self.dim], name='super_edge_feature')
-
+             self._super_feature = self.get_weight(
+                 shape=[self._intermediate_dim], name='super_edge_feature'
+             )
          if not self._normalize:
              self._norm = keras.layers.Identity()
          elif str(self._normalize).lower().startswith('layer'):
              self._norm = keras.layers.LayerNormalization()
          else:
              self._norm = keras.layers.BatchNormalization()
+         self._dense = self.get_dense(self.dim)

      def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
          feature = self._edge_dense(tensor.edge['feature'])

          if self._has_super:
              super_mask = keras.ops.expand_dims(tensor.edge['super'], 1)
-             feature = keras.ops.where(super_mask, self._super_feature, feature)
+             super_feature = self._intermediate_activation(self._super_feature)
+             feature = keras.ops.where(super_mask, super_feature, feature)

          self_loop_mask = keras.ops.expand_dims(tensor.edge['source'] == tensor.edge['target'], 1)
-         feature = keras.ops.where(self_loop_mask, self._self_loop_feature, feature)
-
+         self_loop_feature = self._intermediate_activation(self._self_loop_feature)
+         feature = keras.ops.where(self_loop_mask, self_loop_feature, feature)
          feature = self._norm(feature)
-
+         feature = self._dense(feature)
          return tensor.update({'edge': {'feature': feature}})

      def get_config(self) -> dict:
          config = super().get_config()
          config.update({
              'dim': self.dim,
+             'intermediate_dim': self._intermediate_dim,
+             'intermediate_activation': keras.activations.serialize(
+                 self._intermediate_activation
+             ),
              'normalize': self._normalize,
          })
          return config
@@ -1441,42 +1480,60 @@ class AddContext(GraphLayer):
      def __init__(
          self,
          field: str = 'feature',
+         intermediate_dim: int | None = None,
+         intermediate_activation: str | keras.layers.Activation | None = 'relu',
          drop: bool = False,
          normalize: bool = False,
          **kwargs
      ) -> None:
          super().__init__(**kwargs)
-         self.field = field
-         self.drop = drop
+         self._field = field
+         self._drop = drop
+         self._intermediate_dim = intermediate_dim
+         self._intermediate_activation = keras.activations.get(
+             intermediate_activation
+         )
          self._normalize = normalize

      def build(self, spec: tensors.GraphTensor.Spec) -> None:
          feature_dim = spec.node['feature'].shape[-1]
-         self._context_dense = self.get_dense(feature_dim)
+         if self._intermediate_dim is None:
+             self._intermediate_dim = feature_dim * 2
+         self._intermediate_dense = self.get_dense(
+             self._intermediate_dim, activation=self._intermediate_activation
+         )
+         self._final_dense = self.get_dense(feature_dim)
          if not self._normalize:
-             self._norm = keras.layers.Identity()
+             self._intermediate_norm = keras.layers.Identity()
          elif str(self._normalize).lower().startswith('layer'):
-             self._norm = keras.layers.LayerNormalization()
+             self._intermediate_norm = keras.layers.LayerNormalization()
          else:
-             self._norm = keras.layers.BatchNormalization()
+             self._intermediate_norm = keras.layers.BatchNormalization()

      def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
-         context = tensor.context[self.field]
-         context = self._context_dense(context)
-         context = self._norm(context)
+         context = tensor.context[self._field]
+         context = self._intermediate_dense(context)
+         context = self._intermediate_norm(context)
+         context = self._final_dense(context)
          node_feature = ops.scatter_add(
              tensor.node['feature'], tensor.node['super'], context
          )
          data = {'node': {'feature': node_feature}}
-         if self.drop:
-             data['context'] = {self.field: None}
+         if self._drop:
+             data['context'] = {self._field: None}
          return tensor.update(data)

      def get_config(self) -> dict:
          config = super().get_config()
-         config['field'] = self.field
-         config['drop'] = self.drop
-         config['normalize'] = self._normalize
+         config.update({
+             'field': self._field,
+             'intermediate_dim': self._intermediate_dim,
+             'intermediate_activation': keras.activations.serialize(
+                 self._intermediate_activation
+             ),
+             'drop': self._drop,
+             'normalize': self._normalize,
+         })
          return config

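AddContext now mirrors the embedding layers: the context field passes through an intermediate dense, the (optional) norm, and a final dense back to the node feature width before being scatter-added onto the super-node rows. Illustrative construction (argument values assumed):

    layer = AddContext(field='feature', intermediate_dim=256,
                       intermediate_activation='relu', normalize='layer', drop=True)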
@@ -1738,5 +1795,3 @@ def _spec_from_inputs(inputs):
          return spec
      return tensors.GraphTensor.Spec(**nested_specs)

-
- GraphTransformer = GTConv
{molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft/tensors.py

@@ -16,13 +16,15 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
          if isinstance(f, tf.TensorSpec):
              return tf.TensorSpec(
                  shape=[None] + f.shape[1:],
-                 dtype=f.dtype)
+                 dtype=f.dtype
+             )
          elif isinstance(f, tf.RaggedTensorSpec):
              return tf.RaggedTensorSpec(
                  shape=[batch_size, None] + f.shape[1:],
                  dtype=f.dtype,
                  ragged_rank=1,
-                 row_splits_dtype=f.row_splits_dtype)
+                 row_splits_dtype=f.row_splits_dtype
+             )
          elif isinstance(f, tf.TypeSpec):
              return f.__batch_encoder__.batch(f, batch_size)
          return f
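
Illustrative spec arithmetic for the encoder (values assumed): node and edge fields keep a leading `None` because graphs are concatenated along the first axis, whereas context fields, per the `batch` method in the next hunk, gain an explicit batch dimension.

    import tensorflow as tf

    node_spec = tf.TensorSpec(shape=[None, 11], dtype=tf.float32)
    tf.TensorSpec(shape=[None] + node_spec.shape[1:], dtype=node_spec.dtype)
    # TensorSpec(shape=(None, 11), dtype=tf.float32)

    context_spec = tf.TensorSpec(shape=[8], dtype=tf.float32)
    tf.TensorSpec([32] + context_spec.shape, context_spec.dtype)
    # TensorSpec(shape=(32, 8), dtype=tf.float32)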
@@ -33,7 +35,8 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
          batched_spec = object.__new__(type(spec))
          batched_context_fields = tf.nest.map_structure(
              lambda spec: tf.TensorSpec([batch_size] + spec.shape, spec.dtype),
-             context_fields)
+             context_fields
+         )
          batched_spec.__dict__.update({'context': batched_context_fields})
          batched_spec.__dict__.update(batched_fields)
          return batched_spec
@@ -46,13 +49,15 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
          if isinstance(f, tf.TensorSpec):
              return tf.TensorSpec(
                  shape=[None] + f.shape[1:],
-                 dtype=f.dtype)
+                 dtype=f.dtype
+             )
          elif isinstance(f, tf.RaggedTensorSpec):
              return tf.RaggedTensorSpec(
                  shape=[None] + f.shape[2:],
                  dtype=f.dtype,
                  ragged_rank=0,
-                 row_splits_dtype=f.row_splits_dtype)
+                 row_splits_dtype=f.row_splits_dtype
+             )
          elif isinstance(f, tf.TypeSpec):
              return f.__batch_encoder__.unbatch(f)
          return f
@@ -62,7 +67,8 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
          unbatched_fields = tf.nest.map_structure(unbatch_field, fields)
          unbatched_context_fields = tf.nest.map_structure(
              lambda spec: tf.TensorSpec(spec.shape[1:], spec.dtype),
-             context_fields)
+             context_fields
+         )
          unbatched_spec = object.__new__(type(spec))
          unbatched_spec.__dict__.update({'context': unbatched_context_fields})
          unbatched_spec.__dict__.update(unbatched_fields)
@@ -91,7 +97,8 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
                  shape=([None] if scalar else [None, None]) + f.shape[1:],
                  dtype=f.dtype,
                  ragged_rank=(0 if scalar else 1),
-                 row_splits_dtype=spec.context['size'].dtype)
+                 row_splits_dtype=spec.context['size'].dtype
+             )
              return f
          fields = dict(spec.__dict__)
          context_fields = fields.pop('context')
@@ -99,7 +106,7 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
          encoded_fields = {**{'context': context_fields}, **encoded_fields}
          spec_components = tuple(encoded_fields.values())
          spec_components = tuple(
-             x for x in tf.nest.flatten(spec_components)
+             x for x in tf.nest.flatten(spec_components)
              if isinstance(x, tf.TypeSpec)
          )
          return spec_components
@@ -117,7 +124,6 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
          fields = dict(zip(spec.__dict__.keys(), value_tuple))
          value = object.__new__(spec.value_type)
          value.__dict__.update(fields)
-
          flatten = is_ragged(value) and not is_ragged(spec)
          if flatten:
              value = value.flatten()
@@ -125,7 +131,7 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):


  class GraphTensor(tf.experimental.BatchableExtensionType):
-     context: typing.Mapping[str, typing.Union[tf.Tensor, tf.RaggedTensor]]
+     context: typing.Mapping[str, tf.Tensor]
      node: typing.Mapping[str, typing.Union[tf.Tensor, tf.RaggedTensor]]
      edge: typing.Mapping[str, typing.Union[tf.Tensor, tf.RaggedTensor]]

{molcraft-0.1.0a17 → molcraft-0.1.0a19}/molcraft.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: molcraft
- Version: 0.1.0a17
+ Version: 0.1.0a19
  Summary: Graph Neural Networks for Molecular Machine Learning
  Author-email: Alexander Kensert <alexander.kensert@gmail.com>
  License: MIT License