molcraft 0.1.0a21__tar.gz → 0.1.0a23__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic. See the registry's advisory page for this release for more details.

Files changed (33)
  1. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/PKG-INFO +2 -3
  2. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/README.md +1 -1
  3. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/__init__.py +1 -3
  4. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/chem.py +51 -30
  5. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/datasets.py +1 -0
  6. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/descriptors.py +37 -5
  7. molcraft-0.1.0a23/molcraft/diffusion.py +241 -0
  8. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/features.py +4 -5
  9. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/featurizers.py +13 -2
  10. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/layers.py +20 -21
  11. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/losses.py +1 -0
  12. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/models.py +4 -1
  13. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/ops.py +1 -0
  14. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/records.py +26 -12
  15. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/tensors.py +1 -0
  16. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft.egg-info/PKG-INFO +2 -3
  17. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft.egg-info/SOURCES.txt +1 -0
  18. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft.egg-info/requires.txt +0 -1
  19. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/pyproject.toml +0 -1
  20. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/tests/test_chem.py +3 -0
  21. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/tests/test_featurizers.py +4 -1
  22. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/tests/test_losses.py +5 -1
  23. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/tests/test_models.py +5 -1
  24. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/LICENSE +0 -0
  25. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/applications/__init__.py +0 -0
  26. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/applications/chromatography.py +0 -0
  27. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/applications/proteomics.py +0 -0
  28. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/callbacks.py +1 -1
  29. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft.egg-info/dependency_links.txt +0 -0
  30. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft.egg-info/top_level.txt +0 -0
  31. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/setup.cfg +0 -0
  32. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/tests/test_layers.py +0 -0
  33. {molcraft-0.1.0a21 → molcraft-0.1.0a23}/tests/test_tensors.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: molcraft
3
- Version: 0.1.0a21
3
+ Version: 0.1.0a23
4
4
  Summary: Graph Neural Networks for Molecular Machine Learning
5
5
  Author-email: Alexander Kensert <alexander.kensert@gmail.com>
6
6
  License: MIT License
@@ -35,7 +35,6 @@ Requires-Python: >=3.10
35
35
  Description-Content-Type: text/markdown
36
36
  License-File: LICENSE
37
37
  Requires-Dist: tensorflow>=2.16
38
- Requires-Dist: tensorflow-text>=2.16
39
38
  Requires-Dist: rdkit>=2023.9.5
40
39
  Requires-Dist: pandas>=1.0.3
41
40
  Requires-Dist: ipython>=8.12.0
@@ -43,7 +42,7 @@ Provides-Extra: gpu
43
42
  Requires-Dist: tensorflow[and-cuda]>=2.16; extra == "gpu"
44
43
  Dynamic: license-file
45
44
 
46
- <img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo", width="90%">
45
+ <img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo" width="90%">
47
46
 
48
47
  **Deep Learning on Molecules**: A Minimalistic GNN package for Molecular ML.
49
48
 
@@ -1,4 +1,4 @@
1
- <img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo", width="90%">
1
+ <img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo" width="90%">
2
2
 
3
3
  **Deep Learning on Molecules**: A Minimalistic GNN package for Molecular ML.
4
4
 
@@ -1,4 +1,4 @@
1
- __version__ = '0.1.0a21'
1
+ __version__ = '0.1.0a23'
2
2
 
3
3
  import os
4
4
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -15,5 +15,3 @@ from molcraft import tensors
15
15
  from molcraft import callbacks
16
16
  from molcraft import datasets
17
17
  from molcraft import losses
18
-
19
- from molcraft.applications import proteomics
@@ -3,6 +3,7 @@ import collections
3
3
  import numpy as np
4
4
 
5
5
  from rdkit import Chem
6
+ from rdkit.Chem import AllChem
6
7
  from rdkit.Chem import Lipinski
7
8
  from rdkit.Chem import rdDistGeom
8
9
  from rdkit.Chem import rdDepictor
@@ -22,12 +23,17 @@ class Mol(Chem.Mol):
22
23
  if explicit_hs:
23
24
  rdkit_mol = Chem.AddHs(rdkit_mol)
24
25
  rdkit_mol.__class__ = cls
26
+ setattr(rdkit_mol, '_encoding', encoding)
25
27
  return rdkit_mol
26
28
 
27
29
  @property
28
30
  def canonical_smiles(self) -> str:
29
31
  return Chem.MolToSmiles(self, canonical=True)
30
32
 
33
+ @property
34
+ def encoding(self):
35
+ return getattr(self, '_encoding', None)
36
+
31
37
  @property
32
38
  def bonds(self) -> list['Bond']:
33
39
  if not hasattr(self, '_bonds'):
@@ -60,7 +66,7 @@ class Mol(Chem.Mol):
60
66
  atom = atom.GetIdx()
61
67
  return Atom.cast(self.GetAtomWithIdx(int(atom)))
62
68
 
63
- def get_path_between_atoms(
69
+ def get_shortest_path_between_atoms(
64
70
  self,
65
71
  atom_i: int | Chem.Atom,
66
72
  atom_j: int | Chem.Atom
@@ -100,13 +106,13 @@ class Mol(Chem.Mol):
100
106
 
101
107
  def get_conformer(self, index: int = 0) -> 'Conformer':
102
108
  if self.num_conformers == 0:
103
- warnings.warn('Molecule has no conformer.')
109
+ warnings.warn(f'{self} has no conformer. Returning None.')
104
110
  return None
105
111
  return Conformer.cast(self.GetConformer(index))
106
112
 
107
113
  def get_conformers(self) -> list['Conformer']:
108
114
  if self.num_conformers == 0:
109
- warnings.warn('Molecule has no conformer.')
115
+ warnings.warn(f'{self} has no conformers. Returning an empty list.')
110
116
  return []
111
117
  return [Conformer.cast(x) for x in self.GetConformers()]
112
118
 
@@ -117,7 +123,8 @@ class Mol(Chem.Mol):
117
123
  return None
118
124
 
119
125
  def __repr__(self) -> str:
120
- return f'<{self.__class__.__name__} {self.canonical_smiles} at {hex(id(self))}>'
126
+ encoding = self.encoding or self.canonical_smiles
127
+ return f'<{self.__class__.__name__} {encoding} at {hex(id(self))}>'
121
128
 
122
129
 
123
130
  class Conformer(Chem.Conformer):
@@ -244,7 +251,10 @@ def sanitize_mol(
244
251
  flag = Chem.SanitizeMol(mol, catchErrors=True)
245
252
  if flag != Chem.SanitizeFlags.SANITIZE_NONE:
246
253
  if strict:
247
- return None
254
+ raise ValueError(f'Could not sanitize {mol}.')
255
+ warnings.warn(
256
+ f'Could not sanitize {mol}. Proceeding with partial sanitization.'
257
+ )
248
258
  # Sanitize mol, excluding the steps causing the error previously
249
259
  Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL^flag)
250
260
  if assign_stereo_chemistry:
@@ -391,6 +401,7 @@ def embed_conformers(
391
401
  mol: Mol,
392
402
  num_conformers: int,
393
403
  method: str = 'ETKDGv3',
404
+ timeout: int | None = None,
394
405
  random_seed: int | None = None,
395
406
  **kwargs
396
407
  ) -> Mol:
@@ -405,16 +416,22 @@ def embed_conformers(
405
416
  mol = Mol(mol)
406
417
  embedding_method = available_embedding_methods.get(method)
407
418
  if embedding_method is None:
408
- raise ValueError(
409
- f'Could not find `method` {method!r}. Specify either of: '
410
- '`ETDG`, `ETKDG`, `ETKDGv2`, `ETKDGv3`, `srETKDGv3` or `KDG`.'
419
+ warnings.warn(
420
+ f'{method} is not available. Proceeding with ETKDGv3.'
411
421
  )
422
+ embedding_method = available_embedding_methods['ETKDGv3']
412
423
 
413
424
  for key, value in kwargs.items():
414
425
  setattr(embedding_method, key, value)
415
426
 
416
- if random_seed is not None:
417
- embedding_method.randomSeed = random_seed
427
+ if not timeout:
428
+ timeout = 0 # No timeout
429
+
430
+ if not random_seed:
431
+ random_seed = -1 # No random seed
432
+
433
+ embedding_method.randomSeed = random_seed
434
+ embedding_method.timeout = timeout
418
435
 
419
436
  success = rdDistGeom.EmbedMultipleConfs(
420
437
  mol, numConfs=num_conformers, params=embedding_method
@@ -422,17 +439,18 @@ def embed_conformers(
422
439
  num_successes = len(success)
423
440
  if num_successes < num_conformers:
424
441
  warnings.warn(
425
- f'Could only embed {num_successes} out of {num_conformers} conformer(s) '
426
- f'for {mol.canonical_smiles!r} using {method}. Embedding the remaining '
427
- f'{num_conformers - num_successes} conformer(s) using different embedding methods.',
428
- stacklevel=2
442
+ f'Could only embed {num_successes} out of {num_conformers} conformer(s) for '
443
+ f'{mol} using the specified method ({method}) and parameters. Attempting to '
444
+ f'embed the remaining {num_conformers-num_successes} using fallback methods.',
429
445
  )
446
+ max_iters = 20 * mol.num_atoms # Doubling the number of iterations
430
447
  for fallback_method in [method, 'ETDG', 'KDG']:
431
448
  fallback_embedding_method = available_embedding_methods[fallback_method]
432
449
  fallback_embedding_method.useRandomCoords = True
450
+ fallback_embedding_method.maxIterations = int(max_iters)
433
451
  fallback_embedding_method.clearConfs = False
434
- if random_seed is not None:
435
- fallback_embedding_method.randomSeed = random_seed
452
+ fallback_embedding_method.timeout = int(timeout)
453
+ fallback_embedding_method.randomSeed = int(random_seed)
436
454
  success = rdDistGeom.EmbedMultipleConfs(
437
455
  mol, numConfs=(num_conformers - num_successes), params=fallback_embedding_method
438
456
  )
@@ -441,10 +459,13 @@ def embed_conformers(
441
459
  break
442
460
  else:
443
461
  raise RuntimeError(
444
- f'Could not embed {num_conformers} conformer(s) for {mol.canonical_smiles!r}. '
462
+ f'Could not embed {num_conformers} conformer(s) for {mol}. '
445
463
  )
446
464
  return mol
447
465
 
466
+
467
+ import warnings
468
+
448
469
  def optimize_conformers(
449
470
  mol: Mol,
450
471
  method: str = 'UFF',
@@ -453,14 +474,17 @@ def optimize_conformers(
453
474
  ignore_interfragment_interactions: bool = True,
454
475
  vdw_threshold: float = 10.0,
455
476
  ) -> Mol:
456
- available_force_field_methods = [
457
- 'MMFF', 'MMFF94', 'MMFF94s', 'UFF'
458
- ]
477
+ if mol.num_conformers == 0:
478
+ warnings.warn(
479
+ f'{mol} has no conformers to optimize. Proceeding without it.'
480
+ )
481
+ return Mol(mol)
482
+ available_force_field_methods = ['MMFF', 'MMFF94', 'MMFF94s', 'UFF']
459
483
  if method not in available_force_field_methods:
460
- raise ValueError(
461
- f'Could not find `method` {method!r}. Specify either of: '
462
- '`UFF`, `MMFF`, `MMFF94` or `MMFF94s`.'
484
+ warnings.warn(
485
+ f'{method} is not available. Proceeding with universal force field (UFF).'
463
486
  )
487
+ method = 'UFF'
464
488
  mol_optimized = Mol(mol)
465
489
  try:
466
490
  if method.startswith('MMFF'):
@@ -484,10 +508,9 @@ def optimize_conformers(
484
508
  )
485
509
  except RuntimeError as e:
486
510
  warnings.warn(
487
- f'{method} force field minimization did not succeed. Proceeding without it.',
488
- stacklevel=2
511
+ f'Unsuccessful {method} force field minimization for {mol}. Proceeding without it.',
489
512
  )
490
- return mol
513
+ return Mol(mol)
491
514
  return mol_optimized
492
515
 
493
516
  def prune_conformers(
@@ -498,11 +521,9 @@ def prune_conformers(
498
521
  ) -> Mol:
499
522
  if mol.num_conformers == 0:
500
523
  warnings.warn(
501
- 'Molecule has no conformers. To embed conformers, invoke the `embed` method, '
502
- 'and optionally followed by `minimize()` to perform force field minimization.',
503
- stacklevel=2
524
+ f'{mol} has no conformers to prune. Proceeding without it.'
504
525
  )
505
- return mol
526
+ return Chem.Mol(mol)
506
527
 
507
528
  threshold = threshold or 0.0
508
529
  deviations = conformer_deviations(mol)
@@ -1,3 +1,4 @@
1
+ import warnings
1
2
  import numpy as np
2
3
  import pandas as pd
3
4
  import typing
@@ -1,5 +1,7 @@
1
+ import warnings
1
2
  import keras
2
3
  import numpy as np
4
+
3
5
  from rdkit.Chem import rdMolDescriptors
4
6
 
5
7
  from molcraft import chem
@@ -12,9 +14,7 @@ class Descriptor(features.Feature):
12
14
  def __call__(self, mol: chem.Mol) -> np.ndarray:
13
15
  if not isinstance(mol, chem.Mol):
14
16
  raise ValueError(
15
- f'Input to {self.name} needs to be a `chem.Mol`, which '
16
- 'implements two properties that should be iterated over '
17
- 'to compute features: `atoms` and `bonds`.'
17
+ f'Input to {self.name} must be a `chem.Mol` object.'
18
18
  )
19
19
  descriptor = self.call(mol)
20
20
  func = (
@@ -30,6 +30,23 @@ class Descriptor(features.Feature):
30
30
  return np.concatenate(descriptors)
31
31
 
32
32
 
33
+ @keras.saving.register_keras_serializable(package='molcraft')
34
+ class Descriptor3D(Descriptor):
35
+
36
+ def __call__(self, mol: chem.Mol) -> np.ndarray:
37
+ if not isinstance(mol, chem.Mol):
38
+ raise ValueError(
39
+ f'Input to {self.name} must be a `chem.Mol` object.'
40
+ )
41
+ if mol.num_conformers == 0:
42
+ raise ValueError(
43
+ f'The inputted `chem.Mol` to {self.name} must embed a conformer. '
44
+ f'It is recommended that {self.name} is used as a molecule feature '
45
+ 'for `MolGraphFeaturizer3D`, which by default embeds a conformer.'
46
+ )
47
+ return super().__call__(mol)
48
+
49
+
33
50
  @keras.saving.register_keras_serializable(package='molcraft')
34
51
  class MolWeight(Descriptor):
35
52
  def call(self, mol: chem.Mol) -> np.ndarray:
@@ -77,7 +94,7 @@ class NumHydrogenDonors(Descriptor):
77
94
  @keras.saving.register_keras_serializable(package='molcraft')
78
95
  class NumHydrogenAcceptors(Descriptor):
79
96
  def call(self, mol: chem.Mol) -> np.ndarray:
80
- return rdMolDescriptors.CalcNumHBA(mol)
97
+ return rdMolDescriptors.CalcNumHBA(mol)
81
98
 
82
99
 
83
100
  @keras.saving.register_keras_serializable(package='molcraft')
@@ -89,7 +106,7 @@ class NumRotatableBonds(Descriptor):
89
106
  @keras.saving.register_keras_serializable(package='molcraft')
90
107
  class NumRings(Descriptor):
91
108
  def call(self, mol: chem.Mol) -> np.ndarray:
92
- return rdMolDescriptors.CalcNumRings(mol)
109
+ return rdMolDescriptors.CalcNumRings(mol)
93
110
 
94
111
 
95
112
  @keras.saving.register_keras_serializable(package='molcraft')
@@ -105,3 +122,18 @@ class AtomCount(Descriptor):
105
122
  if atom.GetSymbol() == self.atom_type:
106
123
  count += 1
107
124
  return count
125
+
126
+ def get_config(self) -> dict:
127
+ config = super().get_config()
128
+ config['atom_type'] = self.atom_type
129
+ return config
130
+
131
+
132
+ @keras.saving.register_keras_serializable(package='molcraft')
133
+ class ForceFieldEnergy(Descriptor3D):
134
+ """Universal Force Field (UFF) Energy."""
135
+ def call(self, mol: chem.Mol) -> np.ndarray:
136
+ mol_copy = chem.Mol(mol)
137
+ mol_copy = chem.add_hs(mol_copy)
138
+ return chem.conformer_energies(mol_copy, method="UFF")
139
+
@@ -0,0 +1,241 @@
1
+ import warnings
2
+ import keras
3
+ import tensorflow as tf
4
+ import numpy as np
5
+
6
+ from molcraft import ops
7
+ from molcraft import tensors
8
+ from molcraft import layers
9
+ from molcraft import models
10
+
11
+
12
+ # smiles = pd.read_csv('../../data/rt/RIKEN.csv')['smiles'].values
13
+
14
+
15
+ # graph = featurizers.MolGraphFeaturizer3D(super_node=False)(smiles)
16
+ # graph.node['coordinate']
17
+
18
+
19
+ # encoder = molcraft.models.GraphModel.from_layers(
20
+ # [
21
+ # diffusion.CoordinateNoise(),
22
+ # molcraft.layers.NodeEmbedding(128),
23
+ # molcraft.layers.EdgeEmbedding(128),
24
+ # molcraft.layers.AddContext('position'),
25
+ # molcraft.layers.MPConv(128),
26
+ # molcraft.layers.AddContext('position'),
27
+ # molcraft.layers.MPConv(128),
28
+ # molcraft.layers.AddContext('position'),
29
+ # ]
30
+ # )
31
+
32
+ # decoder = keras.Sequential([
33
+ # keras.layers.Dense(128, activation='relu'),
34
+ # keras.layers.Dense(3),
35
+ # ])
36
+
37
+ # model = diffusion.CoordinateNoisePredictor(encoder, decoder)
38
+
39
+ # model(graph)
40
+
41
+ # model.save('/tmp/model.keras')
42
+ # model = molcraft.models.load_model('/tmp/model.keras')
43
+
44
+ # model.compile(keras.optimizers.Adam(1e-3), 'mse')
45
+ # model.fit(graph, epochs=100)
46
+
47
+
48
+ # from rdkit.Geometry import Point3D
49
+
50
+ # def energy(smiles, coordinate):
51
+ # m = chem.Mol.from_encoding(smiles)
52
+ # m = chem.embed_conformers(m, 1)
53
+ # conf = m.GetConformer()
54
+ # for i in range(m.GetNumAtoms()):
55
+ # x, y, z = coordinate[i]
56
+ # conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z)))
57
+ # return m, chem.conformer_energies(m)[0]
58
+
59
+ # def denoise(
60
+ # graph: tensors.GraphTensor,
61
+ # model,
62
+ # ):
63
+
64
+ # print("----")
65
+ # print(energy(smiles[0], graph[0].node['coordinate'])[-1])
66
+ # print("----")
67
+
68
+ # beta = keras.ops.linspace(1e-4, 1e-2, 100)
69
+ # alpha = 1 - beta
70
+ # alpha_bar = keras.ops.cumprod(alpha)
71
+ # sigma = keras.ops.sqrt(beta[1:] * (1.0 - alpha_bar[:-1]) / (1.0 - alpha_bar[1:]))
72
+
73
+ # graph = graph.update(
74
+ # {
75
+ # 'context': {
76
+ # 'position': keras.ops.ones_like(graph.context['size']) * 99
77
+ # },
78
+ # 'node': {
79
+ # 'coordinate': keras.random.normal(graph.node['coordinate'].shape)
80
+ # }
81
+ # }
82
+ # )
83
+
84
+ # for t in reversed(range(100)):
85
+ # alpha_t = alpha[t]
86
+ # alpha_bar_t = alpha_bar[t]
87
+
88
+ # a = 1 / keras.ops.sqrt(alpha_t)
89
+
90
+ # b = (1 - alpha_t) / keras.ops.sqrt(1 - alpha_bar_t)
91
+
92
+ # if t > 0:
93
+ # z = keras.random.normal(()) * sigma[t-1]
94
+ # else:
95
+ # z = 0.0
96
+
97
+ # graph = graph.update({
98
+ # 'node': {
99
+ # 'coordinate': (
100
+ # a * (graph.node['coordinate'] - b * model(graph)) + z
101
+ # )
102
+ # }
103
+ # })
104
+
105
+
106
+ # print(energy(smiles[0], graph[0].node['coordinate'])[-1])
107
+
108
+ # return graph
109
+
110
+ # graph_updated = denoise(graph[:1], model)x
111
+ # mol, e = energy(smiles[0], graph_updated[0].node['coordinate'])
112
+ # print(e)
113
+ # Chem.Mol(mol)
114
+
115
+ @keras.saving.register_keras_serializable(package='molcraft')
116
+ class CoordinateNoisePredictor(models.GraphModel):
117
+
118
+ def __init__(self, encoder, decoder, *args, **kwargs):
119
+ super().__init__(*args, **kwargs)
120
+ self.encoder = encoder
121
+ self.decoder = decoder
122
+
123
+ def propagate(self, tensor):
124
+ return self.decoder(self.encoder(tensor).node['feature'])
125
+
126
+ def train_step(self, tensor: tensors.GraphTensor) -> dict[str, float]:
127
+ with tf.GradientTape() as tape:
128
+ tensor = self.encoder(tensor)
129
+ feature = tensor.node['feature']
130
+ noise_true = tensor.node['label']
131
+ noise_pred = self.decoder(feature)
132
+ loss = self.compute_loss(tensor, noise_true, noise_pred)
133
+ loss = self.optimizer.scale_loss(loss)
134
+ trainable_weights = self.trainable_weights
135
+ gradients = tape.gradient(loss, trainable_weights)
136
+ self.optimizer.apply_gradients(zip(gradients, trainable_weights))
137
+ return self.compute_metrics(tensor, noise_true, noise_pred)
138
+
139
+ def test_step(self, tensor: tensors.GraphTensor) -> dict[str, float]:
140
+ tensor = self.encoder(tensor)
141
+ feature = tensor.node['feature']
142
+ noise_true = tensor.node['label']
143
+ noise_pred = self.decoder(feature)
144
+ return self.compute_metrics(tensor, noise_true, noise_pred)
145
+
146
+ def get_config(self) -> dict:
147
+ config = super().get_config()
148
+ config['encoder'] = keras.saving.serialize_keras_object(self.encoder)
149
+ config['decoder'] = keras.saving.serialize_keras_object(self.decoder)
150
+ return config
151
+
152
+ @classmethod
153
+ def from_config(cls, config: dict):
154
+ config['encoder'] = keras.saving.deserialize_keras_object(config['encoder'])
155
+ config['decoder'] = keras.saving.deserialize_keras_object(config['decoder'])
156
+ return super().from_config(config)
157
+
158
+
159
+ @keras.saving.register_keras_serializable(package='molcraft')
160
+ class CoordinateNoise(layers.GraphLayer):
161
+
162
+ def __init__(
163
+ self,
164
+ beta: tuple[float, float] = (1e-4, 1e-2),
165
+ position_dim: int = 128,
166
+ max_timesteps: int = 100,
167
+ **kwargs
168
+ ) -> None:
169
+ super().__init__(**kwargs)
170
+ self._beta = beta
171
+ self._max_timesteps = max_timesteps
172
+ beta = keras.ops.linspace(*self._beta, self._max_timesteps)
173
+ alpha = 1 - beta
174
+ alpha_cumprod = keras.ops.cumprod(alpha)
175
+ alpha_cumprod = keras.ops.expand_dims(alpha_cumprod, -1)
176
+ self._alpha_cumprod = alpha_cumprod
177
+ self._timestep_embedding = TimestepEmbedding(dim=position_dim)
178
+
179
+ def propagate(self, graph: tensors.GraphTensor) -> tensors.GraphTensor:
180
+ if 'position' in graph.context:
181
+ return graph.update({'context': {'position': self._timestep_embedding(graph.context['position'])}})
182
+
183
+ timestep = keras.random.randint(
184
+ shape=(graph.num_subgraphs,), minval=0, maxval=self._max_timesteps
185
+ )
186
+ alpha_cumprod = ops.gather(
187
+ ops.gather(self._alpha_cumprod, timestep), graph.graph_indicator
188
+ )
189
+ epsilon = keras.random.normal(
190
+ shape=keras.ops.shape(graph.node['coordinate']), mean=0, stddev=1
191
+ )
192
+ noisy_coordinate = (
193
+ keras.ops.sqrt(alpha_cumprod) * graph.node['coordinate'] +
194
+ keras.ops.sqrt(1 - alpha_cumprod) * epsilon
195
+ )
196
+ timestep = self._timestep_embedding(timestep)
197
+ return graph.update(
198
+ {
199
+ 'context': {
200
+ 'position': timestep,
201
+ },
202
+ 'node': {
203
+ 'coordinate': noisy_coordinate,
204
+ 'label': epsilon
205
+ },
206
+ }
207
+ )
208
+
209
+ def get_config(self) -> dict:
210
+ config = super().get_config()
211
+ config['beta'] = self._beta
212
+ config['max_timesteps'] = self._max_timesteps
213
+ return config
214
+
215
+
216
+ class TimestepEmbedding(keras.layers.Layer):
217
+
218
+ def __init__(self, dim: int, max_wavelength: int = 10000, **kwargs) -> None:
219
+ super().__init__(**kwargs)
220
+ self._dim = dim
221
+ self._max_wavelength = max_wavelength
222
+
223
+ def call(self, inputs: tf.Tensor) -> tf.Tensor:
224
+ timestep = keras.ops.cast(inputs, 'float32')
225
+ embedding = keras.ops.log(self._max_wavelength) / (self._dim // 2 - 1)
226
+ embedding = keras.ops.exp(
227
+ -embedding * keras.ops.arange(self._dim // 2, dtype='float32')
228
+ )
229
+ embedding = timestep[:, None] * embedding[None, :]
230
+ embedding = keras.ops.concatenate(
231
+ [keras.ops.sin(embedding), keras.ops.cos(embedding)], axis=-1
232
+ )
233
+ return embedding
234
+
235
+
236
+ def get_config(self) -> dict:
237
+ config = super().get_config()
238
+ config['dim'] = self._dim
239
+ config['max_wavelength'] = self._max_wavelength
240
+ return config
241
+
@@ -1,7 +1,7 @@
1
+ import warnings
1
2
  import abc
2
3
  import math
3
4
  import keras
4
- import warnings
5
5
  import numpy as np
6
6
 
7
7
  from molcraft import chem
@@ -41,14 +41,14 @@ class Feature(abc.ABC):
41
41
 
42
42
  def __call__(self, mol: chem.Mol) -> np.ndarray:
43
43
  if not isinstance(mol, chem.Mol):
44
- raise TypeError(f'Input to {self.name} must be a `chem.Mol` instance.')
44
+ raise TypeError(f'Input to {self.name} must be a `chem.Mol` object.')
45
45
  features = self.call(mol)
46
46
  if len(features) != mol.num_atoms and len(features) != mol.num_bonds:
47
47
  raise ValueError(
48
48
  f'The number of features computed by {self.name} does not '
49
49
  'match the number of atoms or bonds of the `chem.Mol` object. '
50
- 'Make sure to iterate over `atoms` or `bonds` of `chem.Mol` '
51
- 'when computing features.'
50
+ 'Make sure to iterate over `atoms` or `bonds` of the `chem.Mol` '
51
+ 'object when computing features.'
52
52
  )
53
53
  if len(features) == 0:
54
54
  # Edge case: no atoms or bonds in the molecule.
@@ -109,7 +109,6 @@ class Feature(abc.ABC):
109
109
  warnings.warn(
110
110
  f'Found value of {self.name} to be non-finite. '
111
111
  f'Value received: {value}. Converting it to a value of 0.',
112
- stacklevel=2
113
112
  )
114
113
  value = 0.0
115
114
  return np.asarray([value], dtype=self.dtype)
@@ -1,9 +1,8 @@
1
+ import warnings
1
2
  import keras
2
3
  import json
3
4
  import abc
4
5
  import typing
5
- import copy
6
- import warnings
7
6
  import numpy as np
8
7
  import pandas as pd
9
8
  import tensorflow as tf
@@ -13,6 +12,7 @@ from pathlib import Path
13
12
 
14
13
  from molcraft import tensors
15
14
  from molcraft import features
15
+ from molcraft import records
16
16
  from molcraft import chem
17
17
  from molcraft import descriptors
18
18
 
@@ -41,6 +41,17 @@ class GraphFeaturizer(abc.ABC):
41
41
  def load(filepath: str | Path, *args, **kwargs) -> 'GraphFeaturizer':
42
42
  return load_featurizer(filepath, *args, **kwargs)
43
43
 
44
+ def write_records(self, inputs: str | chem.Mol | tuple, path: str | Path, **kwargs) -> None:
45
+ records.write(
46
+ inputs, featurizer=self, path=path, **kwargs
47
+ )
48
+
49
+ @staticmethod
50
+ def read_records(path: str | Path, **kwargs) -> tf.data.Dataset:
51
+ return records.read(
52
+ path=path, **kwargs
53
+ )
54
+
44
55
  def __call__(
45
56
  self,
46
57
  inputs: str | chem.Mol | tuple | typing.Iterable,
@@ -1,6 +1,6 @@
1
+ import warnings
1
2
  import keras
2
3
  import tensorflow as tf
3
- import warnings
4
4
  import functools
5
5
  from keras.src.models import functional
6
6
 
@@ -350,11 +350,8 @@ class GraphConv(GraphLayer):
350
350
  )
351
351
  if self._project_residual:
352
352
  warnings.warn(
353
- '`skip_connect` is set to `True`, but found incompatible dim '
354
- 'between input (node feature dim) and output (`self.units`). '
355
- 'Automatically applying a projection layer to residual to '
356
- 'match input and output. ',
357
- stacklevel=2,
353
+ 'Found incompatible dim between input and output. Applying '
354
+ 'a projection layer to residual to match input and output dim.',
358
355
  )
359
356
  self._residual_dense = self.get_dense(
360
357
  self.units, name='residual_dense'
@@ -613,10 +610,8 @@ class GIConv(GraphConv):
613
610
  if not self._update_edge_feature:
614
611
  if (edge_feature_dim != node_feature_dim):
615
612
  warnings.warn(
616
- 'Found edge feature dim to be incompatible with node feature dim. '
617
- 'Automatically adding a edge feature projection layer to match '
618
- 'the dim of node features.',
619
- stacklevel=2,
613
+ 'Found edge and node feature dim to be incompatible. Applying a '
614
+ 'projection layer to edge features to match the dim of the node features.',
620
615
  )
621
616
  self._update_edge_feature = True
622
617
 
@@ -870,10 +865,10 @@ class MPConv(GraphConv):
870
865
  self._project_previous_node_feature = node_feature_dim != self.units
871
866
  if self._project_previous_node_feature:
872
867
  warnings.warn(
873
- 'Input node feature dim does not match updated node feature dim. '
874
- 'To make sure input node feature can be passed as `states` to the '
875
- 'GRU cell, it will automatically be projected prior to it.',
876
- stacklevel=2
868
+ 'Inputted node feature dim does not match updated node feature dim, '
869
+ 'which is required for the GRU update. Applying a projection layer to '
870
+ 'the inputted node features prior to the GRU update, to match dim '
871
+ 'of the updated node feature dim.'
877
872
  )
878
873
  self._previous_node_dense = self.get_dense(self.units)
879
874
 
@@ -1497,6 +1492,7 @@ class AddContext(GraphLayer):
1497
1492
 
1498
1493
  def build(self, spec: tensors.GraphTensor.Spec) -> None:
1499
1494
  feature_dim = spec.node['feature'].shape[-1]
1495
+ self._has_super_node = 'super' in spec.node
1500
1496
  if self._intermediate_dim is None:
1501
1497
  self._intermediate_dim = feature_dim * 2
1502
1498
  self._intermediate_dense = self.get_dense(
@@ -1515,9 +1511,14 @@ class AddContext(GraphLayer):
1515
1511
  context = self._intermediate_dense(context)
1516
1512
  context = self._intermediate_norm(context)
1517
1513
  context = self._final_dense(context)
1518
- node_feature = ops.scatter_add(
1519
- tensor.node['feature'], tensor.node['super'], context
1520
- )
1514
+ if self._has_super_node:
1515
+ node_feature = ops.scatter_add(
1516
+ tensor.node['feature'], tensor.node['super'], context
1517
+ )
1518
+ else:
1519
+ node_feature = (
1520
+ tensor.node['feature'] + ops.gather(context, tensor.graph_indicator)
1521
+ )
1521
1522
  data = {'node': {'feature': node_feature}}
1522
1523
  if self._drop:
1523
1524
  data['context'] = {self._field: None}
@@ -1561,8 +1562,7 @@ class GraphNetwork(GraphLayer):
1561
1562
  if self._update_node_feature:
1562
1563
  warnings.warn(
1563
1564
  'Node feature dim does not match `units` of the first layer. '
1564
- 'Automatically adding a node projection layer to match `units`.',
1565
- stacklevel=2
1565
+ 'Applying a projection layer to node features to match `units`.',
1566
1566
  )
1567
1567
  self._node_dense = self.get_dense(units)
1568
1568
  self._has_edge_feature = 'feature' in spec.edge
@@ -1572,8 +1572,7 @@ class GraphNetwork(GraphLayer):
1572
1572
  if self._update_edge_feature:
1573
1573
  warnings.warn(
1574
1574
  'Edge feature dim does not match `units` of the first layer. '
1575
- 'Automatically adding a edge projection layer to match `units`.',
1576
- stacklevel=2
1575
+ 'Applying projection layer to edge features to match `units`.'
1577
1576
  )
1578
1577
  self._edge_dense = self.get_dense(units)
1579
1578
 
@@ -1,3 +1,4 @@
1
+ import warnings
1
2
  import keras
2
3
  import numpy as np
3
4
 
@@ -1,3 +1,4 @@
1
+ import warnings
1
2
  import typing
2
3
  import keras
3
4
  import numpy as np
@@ -111,7 +112,7 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
111
112
  def __new__(cls, *args, **kwargs):
112
113
  if _functional_init_arguments(args, kwargs) and cls == GraphModel:
113
114
  return FunctionalGraphModel(*args, **kwargs)
114
- return typing.cast(GraphModel, super().__new__(cls))
115
+ return super().__new__(cls)
115
116
 
116
117
  def __init__(self, *args, **kwargs):
117
118
  self._model_layers = kwargs.pop('model_layers', None)
@@ -137,6 +138,8 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
137
138
  """
138
139
  if not tensors.is_graph(graph_layers[0]):
139
140
  return cls(model_layers=graph_layers)
141
+ elif cls != GraphModel:
142
+ return cls(model_layers=graph_layers[1:])
140
143
  inputs: dict = graph_layers.pop(0)
141
144
  x = inputs
142
145
  for layer in graph_layers:
@@ -1,3 +1,4 @@
1
+ import warnings
1
2
  import keras
2
3
  import numpy as np
3
4
  import tensorflow as tf
@@ -1,3 +1,4 @@
1
+ import warnings
1
2
  import os
2
3
  import math
3
4
  import glob
@@ -9,14 +10,17 @@ import pandas as pd
9
10
  import multiprocessing as mp
10
11
 
11
12
  from molcraft import tensors
12
- from molcraft import featurizers
13
+
14
+ if typing.TYPE_CHECKING:
15
+ from molcraft import featurizers
13
16
 
14
17
 
15
18
  def write(
16
19
  inputs: list[str | tuple],
17
- featurizer: featurizers.GraphFeaturizer,
20
+ featurizer: 'featurizers.GraphFeaturizer',
18
21
  path: str,
19
- overwrite: bool = True,
22
+ exist_ok: bool = False,
23
+ overwrite: bool = False,
20
24
  num_files: typing.Optional[int] = None,
21
25
  num_processes: typing.Optional[int] = None,
22
26
  multiprocessing: bool = False,
@@ -24,6 +28,8 @@ def write(
24
28
  ) -> None:
25
29
 
26
30
  if os.path.isdir(path):
31
+ if not exist_ok:
32
+ raise FileExistsError(f'Records already exist: {path}')
27
33
  if not overwrite:
28
34
  return
29
35
  else:
@@ -60,9 +66,11 @@ def write(
60
66
  chunk_sizes[i % num_files] += 1
61
67
 
62
68
  input_chunks = []
69
+ start_indices = []
63
70
  current_index = 0
64
71
  for size in chunk_sizes:
65
72
  input_chunks.append(inputs[current_index: current_index + size])
73
+ start_indices.append(current_index)
66
74
  current_index += size
67
75
 
68
76
  assert current_index == num_examples
@@ -73,13 +81,13 @@ def write(
73
81
  ]
74
82
 
75
83
  if not multiprocessing:
76
- for path, input_chunk in zip(paths, input_chunks):
77
- _write_tfrecord(input_chunk, path, featurizer)
84
+ for path, input_chunk, start_index in zip(paths, input_chunks, start_indices):
85
+ _write_tfrecord(input_chunk, path, featurizer, start_index)
78
86
  return
79
87
 
80
88
  processes = []
81
89
 
82
- for path, input_chunk in zip(paths, input_chunks):
90
+ for path, input_chunk, start_index in zip(paths, input_chunks, start_indices):
83
91
 
84
92
  while len(processes) >= num_processes:
85
93
  for process in processes:
@@ -91,7 +99,7 @@ def write(
91
99
 
92
100
  process = mp.Process(
93
101
  target=_write_tfrecord,
94
- args=(input_chunk, path, featurizer)
102
+ args=(input_chunk, path, featurizer, start_index)
95
103
  )
96
104
  processes.append(process)
97
105
  process.start()
@@ -134,9 +142,10 @@ def load_spec(path: str) -> tensors.GraphTensor.Spec:
134
142
  return spec
135
143
 
136
144
  def _write_tfrecord(
137
- inputs,
145
+ inputs: list[str, tuple],
138
146
  path: str,
139
- featurizer: featurizers.GraphFeaturizer,
147
+ featurizer: 'featurizers.GraphFeaturizer',
148
+ start_index: int,
140
149
  ) -> None:
141
150
 
142
151
  def _write_example(tensor):
@@ -147,12 +156,17 @@ def _write_tfrecord(
147
156
  writer.write(serialized_feature)
148
157
 
149
158
  with tf.io.TFRecordWriter(path) as writer:
150
- for x in inputs:
159
+ for i, x in enumerate(inputs):
151
160
  if isinstance(x, (list, np.ndarray)):
152
161
  x = tuple(x)
153
- tensor = featurizer(x)
154
- if tensor is not None:
162
+ try:
163
+ tensor = featurizer(x)
155
164
  _write_example(tensor)
165
+ except Exception as e:
166
+ warnings.warn(
167
+ f'Could not write record for index {i + start_index}, proceeding without it.'
168
+ f'Exception raised:\n{e}'
169
+ )
156
170
 
157
171
  def _serialize_example(
158
172
  feature: dict[str, tf.train.Feature]
@@ -1,3 +1,4 @@
1
+ import warnings
1
2
  import tensorflow as tf
2
3
  import keras
3
4
  import typing
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: molcraft
3
- Version: 0.1.0a21
3
+ Version: 0.1.0a23
4
4
  Summary: Graph Neural Networks for Molecular Machine Learning
5
5
  Author-email: Alexander Kensert <alexander.kensert@gmail.com>
6
6
  License: MIT License
@@ -35,7 +35,6 @@ Requires-Python: >=3.10
35
35
  Description-Content-Type: text/markdown
36
36
  License-File: LICENSE
37
37
  Requires-Dist: tensorflow>=2.16
38
- Requires-Dist: tensorflow-text>=2.16
39
38
  Requires-Dist: rdkit>=2023.9.5
40
39
  Requires-Dist: pandas>=1.0.3
41
40
  Requires-Dist: ipython>=8.12.0
@@ -43,7 +42,7 @@ Provides-Extra: gpu
43
42
  Requires-Dist: tensorflow[and-cuda]>=2.16; extra == "gpu"
44
43
  Dynamic: license-file
45
44
 
46
- <img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo", width="90%">
45
+ <img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo" width="90%">
47
46
 
48
47
  **Deep Learning on Molecules**: A Minimalistic GNN package for Molecular ML.
49
48
 
@@ -6,6 +6,7 @@ molcraft/callbacks.py
6
6
  molcraft/chem.py
7
7
  molcraft/datasets.py
8
8
  molcraft/descriptors.py
9
+ molcraft/diffusion.py
9
10
  molcraft/features.py
10
11
  molcraft/featurizers.py
11
12
  molcraft/layers.py
@@ -1,5 +1,4 @@
1
1
  tensorflow>=2.16
2
- tensorflow-text>=2.16
3
2
  rdkit>=2023.9.5
4
3
  pandas>=1.0.3
5
4
  ipython>=8.12.0
@@ -26,7 +26,6 @@ classifiers = [
26
26
  requires-python = ">=3.10"
27
27
  dependencies = [
28
28
  "tensorflow>=2.16",
29
- "tensorflow-text>=2.16",
30
29
  "rdkit>=2023.9.5",
31
30
  "pandas>=1.0.3",
32
31
  "ipython>=8.12.0"
@@ -12,3 +12,6 @@ class TestChem(unittest.TestCase):
12
12
  "N1[C@@H](CCC1)C(=O)O",
13
13
  ]
14
14
 
15
+
16
+ if __name__ == '__main__':
17
+ unittest.main()
@@ -3,6 +3,7 @@ import tempfile
3
3
  import shutil
4
4
 
5
5
  from molcraft import features
6
+ from molcraft import descriptors
6
7
  from molcraft import featurizers
7
8
 
8
9
 
@@ -129,6 +130,9 @@ class TestFeaturizer(unittest.TestCase):
129
130
  pair_features=[
130
131
  features.PairDistance(max_distance=20)
131
132
  ],
133
+ molecule_features=[
134
+ descriptors.ForceFieldEnergy(),
135
+ ],
132
136
  super_node=True,
133
137
  self_loops=False,
134
138
  include_hydrogens=False,
@@ -199,6 +203,5 @@ class TestFeaturizer(unittest.TestCase):
199
203
  self.assertEqual(graph.edge['target'].dtype.name, 'int32')
200
204
 
201
205
 
202
-
203
206
  if __name__ == '__main__':
204
207
  unittest.main()
@@ -21,4 +21,8 @@ class TestLoss(unittest.TestCase):
21
21
  keras.ops.array([[2., 0.1], [4., 0.2], [5., 0.3]])
22
22
  )
23
23
  self.assertGreater(value, 0)
24
- self.assertEqual(len(keras.ops.shape(value)), 0)
24
+ self.assertEqual(len(keras.ops.shape(value)), 0)
25
+
26
+
27
+ if __name__ == '__main__':
28
+ unittest.main()
@@ -266,4 +266,8 @@ class TestModel(unittest.TestCase):
266
266
  model = get_model(tensor)
267
267
  out = model.embedding()(tensor)
268
268
  self.assertTrue(out.shape[0] == tensor.context['size'].shape[0])
269
- self.assertTrue(out.shape[1] == units)
269
+ self.assertTrue(out.shape[1] == units)
270
+
271
+
272
+ if __name__ == '__main__':
273
+ unittest.main()
File without changes
@@ -1,5 +1,5 @@
1
- import keras
2
1
  import warnings
2
+ import keras
3
3
  import numpy as np
4
4
 
5
5
 
File without changes