molcraft 0.1.0a21__tar.gz → 0.1.0a23__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of molcraft might be problematic.
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/PKG-INFO +2 -3
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/README.md +1 -1
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/__init__.py +1 -3
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/chem.py +51 -30
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/datasets.py +1 -0
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/descriptors.py +37 -5
- molcraft-0.1.0a23/molcraft/diffusion.py +241 -0
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/features.py +4 -5
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/featurizers.py +13 -2
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/layers.py +20 -21
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/losses.py +1 -0
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/models.py +4 -1
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/ops.py +1 -0
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/records.py +26 -12
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/tensors.py +1 -0
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft.egg-info/PKG-INFO +2 -3
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft.egg-info/SOURCES.txt +1 -0
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft.egg-info/requires.txt +0 -1
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/pyproject.toml +0 -1
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/tests/test_chem.py +3 -0
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/tests/test_featurizers.py +4 -1
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/tests/test_losses.py +5 -1
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/tests/test_models.py +5 -1
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/LICENSE +0 -0
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/applications/__init__.py +0 -0
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/applications/chromatography.py +0 -0
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/applications/proteomics.py +0 -0
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/callbacks.py +1 -1
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft.egg-info/dependency_links.txt +0 -0
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft.egg-info/top_level.txt +0 -0
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/setup.cfg +0 -0
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/tests/test_layers.py +0 -0
- {molcraft-0.1.0a21 → molcraft-0.1.0a23}/tests/test_tensors.py +0 -0
{molcraft-0.1.0a21 → molcraft-0.1.0a23}/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: molcraft
-Version: 0.1.0a21
+Version: 0.1.0a23
 Summary: Graph Neural Networks for Molecular Machine Learning
 Author-email: Alexander Kensert <alexander.kensert@gmail.com>
 License: MIT License
@@ -35,7 +35,6 @@ Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: tensorflow>=2.16
-Requires-Dist: tensorflow-text>=2.16
 Requires-Dist: rdkit>=2023.9.5
 Requires-Dist: pandas>=1.0.3
 Requires-Dist: ipython>=8.12.0
@@ -43,7 +42,7 @@ Provides-Extra: gpu
 Requires-Dist: tensorflow[and-cuda]>=2.16; extra == "gpu"
 Dynamic: license-file
 
-<img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo">
+<img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo" width="90%">
 
 **Deep Learning on Molecules**: A Minimalistic GNN package for Molecular ML.
 
```
{molcraft-0.1.0a21 → molcraft-0.1.0a23}/README.md

```diff
@@ -1,4 +1,4 @@
-<img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo">
+<img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo" width="90%">
 
 **Deep Learning on Molecules**: A Minimalistic GNN package for Molecular ML.
 
```
{molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/__init__.py

```diff
@@ -1,4 +1,4 @@
-__version__ = '0.1.0a21'
+__version__ = '0.1.0a23'
 
 import os
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -15,5 +15,3 @@ from molcraft import tensors
 from molcraft import callbacks
 from molcraft import datasets
 from molcraft import losses
-
-from molcraft.applications import proteomics
```
{molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/chem.py

```diff
@@ -3,6 +3,7 @@ import collections
 import numpy as np
 
 from rdkit import Chem
+from rdkit.Chem import AllChem
 from rdkit.Chem import Lipinski
 from rdkit.Chem import rdDistGeom
 from rdkit.Chem import rdDepictor
@@ -22,12 +23,17 @@ class Mol(Chem.Mol):
         if explicit_hs:
             rdkit_mol = Chem.AddHs(rdkit_mol)
         rdkit_mol.__class__ = cls
+        setattr(rdkit_mol, '_encoding', encoding)
         return rdkit_mol
 
     @property
     def canonical_smiles(self) -> str:
         return Chem.MolToSmiles(self, canonical=True)
 
+    @property
+    def encoding(self):
+        return getattr(self, '_encoding', None)
+
     @property
     def bonds(self) -> list['Bond']:
         if not hasattr(self, '_bonds'):
@@ -60,7 +66,7 @@ class Mol(Chem.Mol):
             atom = atom.GetIdx()
         return Atom.cast(self.GetAtomWithIdx(int(atom)))
 
-    def
+    def get_shortest_path_between_atoms(
         self,
         atom_i: int | Chem.Atom,
         atom_j: int | Chem.Atom
@@ -100,13 +106,13 @@ class Mol(Chem.Mol):
 
     def get_conformer(self, index: int = 0) -> 'Conformer':
         if self.num_conformers == 0:
-            warnings.warn('
+            warnings.warn(f'{self} has no conformer. Returning None.')
             return None
         return Conformer.cast(self.GetConformer(index))
 
     def get_conformers(self) -> list['Conformer']:
         if self.num_conformers == 0:
-            warnings.warn('
+            warnings.warn(f'{self} has no conformers. Returning an empty list.')
             return []
         return [Conformer.cast(x) for x in self.GetConformers()]
 
@@ -117,7 +123,8 @@ class Mol(Chem.Mol):
         return None
 
     def __repr__(self) -> str:
-
+        encoding = self.encoding or self.canonical_smiles
+        return f'<{self.__class__.__name__} {encoding} at {hex(id(self))}>'
 
 
 class Conformer(Chem.Conformer):
@@ -244,7 +251,10 @@ def sanitize_mol(
     flag = Chem.SanitizeMol(mol, catchErrors=True)
     if flag != Chem.SanitizeFlags.SANITIZE_NONE:
         if strict:
-
+            raise ValueError(f'Could not sanitize {mol}.')
+        warnings.warn(
+            f'Could not sanitize {mol}. Proceeding with partial sanitization.'
+        )
         # Sanitize mol, excluding the steps causing the error previously
         Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL^flag)
     if assign_stereo_chemistry:
@@ -391,6 +401,7 @@ def embed_conformers(
     mol: Mol,
     num_conformers: int,
     method: str = 'ETKDGv3',
+    timeout: int | None = None,
     random_seed: int | None = None,
     **kwargs
 ) -> Mol:
@@ -405,16 +416,22 @@ def embed_conformers(
     mol = Mol(mol)
     embedding_method = available_embedding_methods.get(method)
     if embedding_method is None:
-
-            f'
-            '`ETDG`, `ETKDG`, `ETKDGv2`, `ETKDGv3`, `srETKDGv3` or `KDG`.'
+        warnings.warn(
+            f'{method} is not available. Proceeding with ETKDGv3.'
         )
+        embedding_method = available_embedding_methods['ETKDGv3']
 
     for key, value in kwargs.items():
         setattr(embedding_method, key, value)
 
-    if
-
+    if not timeout:
+        timeout = 0  # No timeout
+
+    if not random_seed:
+        random_seed = -1  # No random seed
+
+    embedding_method.randomSeed = random_seed
+    embedding_method.timeout = timeout
 
     success = rdDistGeom.EmbedMultipleConfs(
         mol, numConfs=num_conformers, params=embedding_method
@@ -422,17 +439,18 @@ def embed_conformers(
     num_successes = len(success)
     if num_successes < num_conformers:
         warnings.warn(
-            f'Could only embed {num_successes} out of {num_conformers} conformer(s) '
-            f'
-            f'{num_conformers
-            stacklevel=2
+            f'Could only embed {num_successes} out of {num_conformers} conformer(s) for '
+            f'{mol} using the specified method ({method}) and parameters. Attempting to '
+            f'embed the remaining {num_conformers-num_successes} using fallback methods.',
         )
+        max_iters = 20 * mol.num_atoms  # Doubling the number of iterations
         for fallback_method in [method, 'ETDG', 'KDG']:
             fallback_embedding_method = available_embedding_methods[fallback_method]
             fallback_embedding_method.useRandomCoords = True
+            fallback_embedding_method.maxIterations = int(max_iters)
             fallback_embedding_method.clearConfs = False
-
-
+            fallback_embedding_method.timeout = int(timeout)
+            fallback_embedding_method.randomSeed = int(random_seed)
             success = rdDistGeom.EmbedMultipleConfs(
                 mol, numConfs=(num_conformers - num_successes), params=fallback_embedding_method
             )
@@ -441,10 +459,13 @@ def embed_conformers(
             break
     else:
         raise RuntimeError(
-            f'Could not embed {num_conformers} conformer(s) for {mol
+            f'Could not embed {num_conformers} conformer(s) for {mol}. '
         )
     return mol
 
+
+import warnings
+
 def optimize_conformers(
     mol: Mol,
     method: str = 'UFF',
@@ -453,14 +474,17 @@ def optimize_conformers(
     ignore_interfragment_interactions: bool = True,
     vdw_threshold: float = 10.0,
 ) -> Mol:
-
-
-
+    if mol.num_conformers == 0:
+        warnings.warn(
+            f'{mol} has no conformers to optimize. Proceeding without it.'
+        )
+        return Mol(mol)
+    available_force_field_methods = ['MMFF', 'MMFF94', 'MMFF94s', 'UFF']
     if method not in available_force_field_methods:
-
-            f'
-            '`UFF`, `MMFF`, `MMFF94` or `MMFF94s`.'
+        warnings.warn(
+            f'{method} is not available. Proceeding with universal force field (UFF).'
         )
+        method = 'UFF'
     mol_optimized = Mol(mol)
     try:
         if method.startswith('MMFF'):
@@ -484,10 +508,9 @@ def optimize_conformers(
         )
     except RuntimeError as e:
         warnings.warn(
-            f'{method} force field minimization
-            stacklevel=2
+            f'Unsuccessful {method} force field minimization for {mol}. Proceeding without it.',
         )
-        return mol
+        return Mol(mol)
     return mol_optimized
 
 def prune_conformers(
@@ -498,11 +521,9 @@ def prune_conformers(
 ) -> Mol:
     if mol.num_conformers == 0:
         warnings.warn(
-            '
-            'and optionally followed by `minimize()` to perform force field minimization.',
-            stacklevel=2
+            f'{mol} has no conformers to prune. Proceeding without it.'
        )
-        return mol
+        return Chem.Mol(mol)
 
     threshold = threshold or 0.0
     deviations = conformer_deviations(mol)
```
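The conformer changes above add a `timeout` parameter to `embed_conformers` and make unknown embedding or force-field methods warn and fall back instead of raising. A minimal usage sketch, assuming `Mol.from_encoding` as used in the commented-out example in the new diffusion.py; the SMILES string and parameter values are illustrative:

```python
# Sketch of the updated conformer API; `Mol.from_encoding` is assumed from
# diffusion.py's commented-out example, and all values are illustrative.
from molcraft import chem

mol = chem.Mol.from_encoding('CCO')  # ethanol, as a molcraft Mol
# New in this release: `timeout` bounds embedding time (0 means no timeout);
# an unknown `method` now warns and falls back to ETKDGv3 instead of raising.
mol = chem.embed_conformers(mol, num_conformers=5, timeout=30, random_seed=42)
# An unknown force field now warns and falls back to UFF instead of raising.
mol = chem.optimize_conformers(mol, method='UFF')
print(mol.num_conformers)
```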
{molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/descriptors.py

```diff
@@ -1,5 +1,7 @@
+import warnings
 import keras
 import numpy as np
+
 from rdkit.Chem import rdMolDescriptors
 
 from molcraft import chem
@@ -12,9 +14,7 @@ class Descriptor(features.Feature):
     def __call__(self, mol: chem.Mol) -> np.ndarray:
         if not isinstance(mol, chem.Mol):
             raise ValueError(
-                f'Input to {self.name}
-                'implements two properties that should be iterated over '
-                'to compute features: `atoms` and `bonds`.'
+                f'Input to {self.name} must be a `chem.Mol` object.'
             )
         descriptor = self.call(mol)
         func = (
@@ -30,6 +30,23 @@ class Descriptor(features.Feature):
         return np.concatenate(descriptors)
 
 
+@keras.saving.register_keras_serializable(package='molcraft')
+class Descriptor3D(Descriptor):
+
+    def __call__(self, mol: chem.Mol) -> np.ndarray:
+        if not isinstance(mol, chem.Mol):
+            raise ValueError(
+                f'Input to {self.name} must be a `chem.Mol` object.'
+            )
+        if mol.num_conformers == 0:
+            raise ValueError(
+                f'The inputted `chem.Mol` to {self.name} must embed a conformer. '
+                f'It is recommended that {self.name} is used as a molecule feature '
+                'for `MolGraphFeaturizer3D`, which by default embeds a conformer.'
+            )
+        return super().__call__(mol)
+
+
 @keras.saving.register_keras_serializable(package='molcraft')
 class MolWeight(Descriptor):
     def call(self, mol: chem.Mol) -> np.ndarray:
@@ -77,7 +94,7 @@ class NumHydrogenDonors(Descriptor):
 @keras.saving.register_keras_serializable(package='molcraft')
 class NumHydrogenAcceptors(Descriptor):
     def call(self, mol: chem.Mol) -> np.ndarray:
-        return rdMolDescriptors.CalcNumHBA(mol)
+        return rdMolDescriptors.CalcNumHBA(mol)
 
 
 @keras.saving.register_keras_serializable(package='molcraft')
@@ -89,7 +106,7 @@ class NumRotatableBonds(Descriptor):
 @keras.saving.register_keras_serializable(package='molcraft')
 class NumRings(Descriptor):
     def call(self, mol: chem.Mol) -> np.ndarray:
-        return rdMolDescriptors.CalcNumRings(mol)
+        return rdMolDescriptors.CalcNumRings(mol)
 
 
 @keras.saving.register_keras_serializable(package='molcraft')
@@ -105,3 +122,18 @@ class AtomCount(Descriptor):
         if atom.GetSymbol() == self.atom_type:
             count += 1
         return count
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config['atom_type'] = self.atom_type
+        return config
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class ForceFieldEnergy(Descriptor3D):
+    """Universal Force Field (UFF) Energy."""
+    def call(self, mol: chem.Mol) -> np.ndarray:
+        mol_copy = chem.Mol(mol)
+        mol_copy = chem.add_hs(mol_copy)
+        return chem.conformer_energies(mol_copy, method="UFF")
+
```
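The new `Descriptor3D` base class gates 3D descriptors on the presence of a conformer, and `ForceFieldEnergy` is its first concrete subclass. A hedged sketch of how it is wired in, mirroring the updated tests/test_featurizers.py; the SMILES inputs are invented, and `MolGraphFeaturizer3D` is assumed to embed a conformer by default (as its error message above states):

```python
# Mirrors the updated tests/test_featurizers.py; inputs are placeholders.
from molcraft import descriptors, featurizers

featurizer = featurizers.MolGraphFeaturizer3D(
    molecule_features=[descriptors.ForceFieldEnergy()],
)
graph = featurizer(['CCO', 'c1ccccc1'])  # UFF energy attached per molecule
```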
molcraft-0.1.0a23/molcraft/diffusion.py (new file)

```diff
@@ -0,0 +1,241 @@
+import warnings
+import keras
+import tensorflow as tf
+import numpy as np
+
+from molcraft import ops
+from molcraft import tensors
+from molcraft import layers
+from molcraft import models
+
+
+# smiles = pd.read_csv('../../data/rt/RIKEN.csv')['smiles'].values
+
+
+# graph = featurizers.MolGraphFeaturizer3D(super_node=False)(smiles)
+# graph.node['coordinate']
+
+
+# encoder = molcraft.models.GraphModel.from_layers(
+#     [
+#         diffusion.CoordinateNoise(),
+#         molcraft.layers.NodeEmbedding(128),
+#         molcraft.layers.EdgeEmbedding(128),
+#         molcraft.layers.AddContext('position'),
+#         molcraft.layers.MPConv(128),
+#         molcraft.layers.AddContext('position'),
+#         molcraft.layers.MPConv(128),
+#         molcraft.layers.AddContext('position'),
+#     ]
+# )
+
+# decoder = keras.Sequential([
+#     keras.layers.Dense(128, activation='relu'),
+#     keras.layers.Dense(3),
+# ])
+
+# model = diffusion.CoordinateNoisePredictor(encoder, decoder)
+
+# model(graph)
+
+# model.save('/tmp/model.keras')
+# model = molcraft.models.load_model('/tmp/model.keras')
+
+# model.compile(keras.optimizers.Adam(1e-3), 'mse')
+# model.fit(graph, epochs=100)
+
+
+# from rdkit.Geometry import Point3D
+
+# def energy(smiles, coordinate):
+#     m = chem.Mol.from_encoding(smiles)
+#     m = chem.embed_conformers(m, 1)
+#     conf = m.GetConformer()
+#     for i in range(m.GetNumAtoms()):
+#         x, y, z = coordinate[i]
+#         conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z)))
+#     return m, chem.conformer_energies(m)[0]
+
+# def denoise(
+#     graph: tensors.GraphTensor,
+#     model,
+# ):
+
+#     print("----")
+#     print(energy(smiles[0], graph[0].node['coordinate'])[-1])
+#     print("----")
+
+#     beta = keras.ops.linspace(1e-4, 1e-2, 100)
+#     alpha = 1 - beta
+#     alpha_bar = keras.ops.cumprod(alpha)
+#     sigma = keras.ops.sqrt(beta[1:] * (1.0 - alpha_bar[:-1]) / (1.0 - alpha_bar[1:]))
+
+#     graph = graph.update(
+#         {
+#             'context': {
+#                 'position': keras.ops.ones_like(graph.context['size']) * 99
+#             },
+#             'node': {
+#                 'coordinate': keras.random.normal(graph.node['coordinate'].shape)
+#             }
+#         }
+#     )
+
+#     for t in reversed(range(100)):
+#         alpha_t = alpha[t]
+#         alpha_bar_t = alpha_bar[t]
+
+#         a = 1 / keras.ops.sqrt(alpha_t)
+
+#         b = (1 - alpha_t) / keras.ops.sqrt(1 - alpha_bar_t)
+
+#         if t > 0:
+#             z = keras.random.normal(()) * sigma[t-1]
+#         else:
+#             z = 0.0
+
+#         graph = graph.update({
+#             'node': {
+#                 'coordinate': (
+#                     a * (graph.node['coordinate'] - b * model(graph)) + z
+#                 )
+#             }
+#         })
+
+
+#         print(energy(smiles[0], graph[0].node['coordinate'])[-1])
+
+#     return graph
+
+# graph_updated = denoise(graph[:1], model)
+# mol, e = energy(smiles[0], graph_updated[0].node['coordinate'])
+# print(e)
+# Chem.Mol(mol)
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class CoordinateNoisePredictor(models.GraphModel):
+
+    def __init__(self, encoder, decoder, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.encoder = encoder
+        self.decoder = decoder
+
+    def propagate(self, tensor):
+        return self.decoder(self.encoder(tensor).node['feature'])
+
+    def train_step(self, tensor: tensors.GraphTensor) -> dict[str, float]:
+        with tf.GradientTape() as tape:
+            tensor = self.encoder(tensor)
+            feature = tensor.node['feature']
+            noise_true = tensor.node['label']
+            noise_pred = self.decoder(feature)
+            loss = self.compute_loss(tensor, noise_true, noise_pred)
+            loss = self.optimizer.scale_loss(loss)
+        trainable_weights = self.trainable_weights
+        gradients = tape.gradient(loss, trainable_weights)
+        self.optimizer.apply_gradients(zip(gradients, trainable_weights))
+        return self.compute_metrics(tensor, noise_true, noise_pred)
+
+    def test_step(self, tensor: tensors.GraphTensor) -> dict[str, float]:
+        tensor = self.encoder(tensor)
+        feature = tensor.node['feature']
+        noise_true = tensor.node['label']
+        noise_pred = self.decoder(feature)
+        return self.compute_metrics(tensor, noise_true, noise_pred)
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config['encoder'] = keras.saving.serialize_keras_object(self.encoder)
+        config['decoder'] = keras.saving.serialize_keras_object(self.decoder)
+        return config
+
+    @classmethod
+    def from_config(cls, config: dict):
+        config['encoder'] = keras.saving.deserialize_keras_object(config['encoder'])
+        config['decoder'] = keras.saving.deserialize_keras_object(config['decoder'])
+        return super().from_config(config)
+
+
+@keras.saving.register_keras_serializable(package='molcraft')
+class CoordinateNoise(layers.GraphLayer):
+
+    def __init__(
+        self,
+        beta: tuple[float, float] = (1e-4, 1e-2),
+        position_dim: int = 128,
+        max_timesteps: int = 100,
+        **kwargs
+    ) -> None:
+        super().__init__(**kwargs)
+        self._beta = beta
+        self._max_timesteps = max_timesteps
+        beta = keras.ops.linspace(*self._beta, self._max_timesteps)
+        alpha = 1 - beta
+        alpha_cumprod = keras.ops.cumprod(alpha)
+        alpha_cumprod = keras.ops.expand_dims(alpha_cumprod, -1)
+        self._alpha_cumprod = alpha_cumprod
+        self._timestep_embedding = TimestepEmbedding(dim=position_dim)
+
+    def propagate(self, graph: tensors.GraphTensor) -> tensors.GraphTensor:
+        if 'position' in graph.context:
+            return graph.update({'context': {'position': self._timestep_embedding(graph.context['position'])}})
+
+        timestep = keras.random.randint(
+            shape=(graph.num_subgraphs,), minval=0, maxval=self._max_timesteps
+        )
+        alpha_cumprod = ops.gather(
+            ops.gather(self._alpha_cumprod, timestep), graph.graph_indicator
+        )
+        epsilon = keras.random.normal(
+            shape=keras.ops.shape(graph.node['coordinate']), mean=0, stddev=1
+        )
+        noisy_coordinate = (
+            keras.ops.sqrt(alpha_cumprod) * graph.node['coordinate'] +
+            keras.ops.sqrt(1 - alpha_cumprod) * epsilon
+        )
+        timestep = self._timestep_embedding(timestep)
+        return graph.update(
+            {
+                'context': {
+                    'position': timestep,
+                },
+                'node': {
+                    'coordinate': noisy_coordinate,
+                    'label': epsilon
+                },
+            }
+        )
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config['beta'] = self._beta
+        config['max_timesteps'] = self._max_timesteps
+        return config
+
+
+class TimestepEmbedding(keras.layers.Layer):
+
+    def __init__(self, dim: int, max_wavelength: int = 10000, **kwargs) -> None:
+        super().__init__(**kwargs)
+        self._dim = dim
+        self._max_wavelength = max_wavelength
+
+    def call(self, inputs: tf.Tensor) -> tf.Tensor:
+        timestep = keras.ops.cast(inputs, 'float32')
+        embedding = keras.ops.log(self._max_wavelength) / (self._dim // 2 - 1)
+        embedding = keras.ops.exp(
+            -embedding * keras.ops.arange(self._dim // 2, dtype='float32')
+        )
+        embedding = timestep[:, None] * embedding[None, :]
+        embedding = keras.ops.concatenate(
+            [keras.ops.sin(embedding), keras.ops.cos(embedding)], axis=-1
+        )
+        return embedding
+
+
+    def get_config(self) -> dict:
+        config = super().get_config()
+        config['dim'] = self._dim
+        config['max_wavelength'] = self._max_wavelength
+        return config
+
```
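For orientation, `CoordinateNoise.propagate` implements the standard DDPM forward (noising) process in closed form; this reading follows from the code above, not from package documentation. With beta_t linearly spaced over `max_timesteps` and alpha_t = 1 - beta_t:

```latex
\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s,
\qquad
x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon,
\qquad
\epsilon \sim \mathcal{N}(0, I)
```

The layer stores epsilon per node under `label`, which is exactly the regression target `CoordinateNoisePredictor` fits its decoder against, and `TimestepEmbedding` is the usual transformer-style sinusoidal encoding of the sampled timestep.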
{molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/features.py

```diff
@@ -1,7 +1,7 @@
+import warnings
 import abc
 import math
 import keras
-import warnings
 import numpy as np
 
 from molcraft import chem
@@ -41,14 +41,14 @@ class Feature(abc.ABC):
 
     def __call__(self, mol: chem.Mol) -> np.ndarray:
         if not isinstance(mol, chem.Mol):
-            raise TypeError(f'Input to {self.name} must be a `chem.Mol`
+            raise TypeError(f'Input to {self.name} must be a `chem.Mol` object.')
         features = self.call(mol)
         if len(features) != mol.num_atoms and len(features) != mol.num_bonds:
             raise ValueError(
                 f'The number of features computed by {self.name} does not '
                 'match the number of atoms or bonds of the `chem.Mol` object. '
-                'Make sure to iterate over `atoms` or `bonds` of `chem.Mol` '
-                'when computing features.'
+                'Make sure to iterate over `atoms` or `bonds` of the `chem.Mol` '
+                'object when computing features.'
             )
         if len(features) == 0:
             # Edge case: no atoms or bonds in the molecule.
@@ -109,7 +109,6 @@ class Feature(abc.ABC):
             warnings.warn(
                 f'Found value of {self.name} to be non-finite. '
                 f'Value received: {value}. Converting it to a value of 0.',
-                stacklevel=2
             )
             value = 0.0
         return np.asarray([value], dtype=self.dtype)
```
{molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/featurizers.py

```diff
@@ -1,9 +1,8 @@
+import warnings
 import keras
 import json
 import abc
 import typing
-import copy
-import warnings
 import numpy as np
 import pandas as pd
 import tensorflow as tf
@@ -13,6 +12,7 @@ from pathlib import Path
 
 from molcraft import tensors
 from molcraft import features
+from molcraft import records
 from molcraft import chem
 from molcraft import descriptors
 
@@ -41,6 +41,17 @@ class GraphFeaturizer(abc.ABC):
     def load(filepath: str | Path, *args, **kwargs) -> 'GraphFeaturizer':
         return load_featurizer(filepath, *args, **kwargs)
 
+    def write_records(self, inputs: str | chem.Mol | tuple, path: str | Path, **kwargs) -> None:
+        records.write(
+            inputs, featurizer=self, path=path, **kwargs
+        )
+
+    @staticmethod
+    def read_records(path: str | Path, **kwargs) -> tf.data.Dataset:
+        return records.read(
+            path=path, **kwargs
+        )
+
     def __call__(
         self,
         inputs: str | chem.Mol | tuple | typing.Iterable,
```
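The new `write_records`/`read_records` methods are thin wrappers around `records.write` and `records.read`. A hedged round-trip sketch; the featurizer choice, path, and SMILES are placeholders, and extra keyword arguments (such as `exist_ok` or `overwrite`) forward to `records.write`:

```python
# Hypothetical round trip through the new record API.
from molcraft import featurizers

featurizer = featurizers.MolGraphFeaturizer3D()  # assumed default-constructible
featurizer.write_records(['CCO', 'CCN', 'CCC'], path='/tmp/molcraft_records')
dataset = featurizers.GraphFeaturizer.read_records('/tmp/molcraft_records')  # tf.data.Dataset
```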
{molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/layers.py

```diff
@@ -1,6 +1,6 @@
+import warnings
 import keras
 import tensorflow as tf
-import warnings
 import functools
 from keras.src.models import functional
 
@@ -350,11 +350,8 @@ class GraphConv(GraphLayer):
         )
         if self._project_residual:
             warnings.warn(
-                '
-                '
-                'Automatically applying a projection layer to residual to '
-                'match input and output. ',
-                stacklevel=2,
+                'Found incompatible dim between input and output. Applying '
+                'a projection layer to residual to match input and output dim.',
             )
             self._residual_dense = self.get_dense(
                 self.units, name='residual_dense'
@@ -613,10 +610,8 @@ class GIConv(GraphConv):
         if not self._update_edge_feature:
             if (edge_feature_dim != node_feature_dim):
                 warnings.warn(
-                    'Found edge feature dim to be incompatible
-                    '
-                    'the dim of node features.',
-                    stacklevel=2,
+                    'Found edge and node feature dim to be incompatible. Applying a '
+                    'projection layer to edge features to match the dim of the node features.',
                 )
                 self._update_edge_feature = True
 
@@ -870,10 +865,10 @@ class MPConv(GraphConv):
         self._project_previous_node_feature = node_feature_dim != self.units
         if self._project_previous_node_feature:
             warnings.warn(
-                '
-                '
-                '
-
+                'Inputted node feature dim does not match updated node feature dim, '
+                'which is required for the GRU update. Applying a projection layer to '
+                'the inputted node features prior to the GRU update, to match dim '
+                'of the updated node feature dim.'
             )
             self._previous_node_dense = self.get_dense(self.units)
 
@@ -1497,6 +1492,7 @@ class AddContext(GraphLayer):
 
     def build(self, spec: tensors.GraphTensor.Spec) -> None:
         feature_dim = spec.node['feature'].shape[-1]
+        self._has_super_node = 'super' in spec.node
         if self._intermediate_dim is None:
             self._intermediate_dim = feature_dim * 2
         self._intermediate_dense = self.get_dense(
@@ -1515,9 +1511,14 @@ class AddContext(GraphLayer):
         context = self._intermediate_dense(context)
         context = self._intermediate_norm(context)
         context = self._final_dense(context)
-
-
-
+        if self._has_super_node:
+            node_feature = ops.scatter_add(
+                tensor.node['feature'], tensor.node['super'], context
+            )
+        else:
+            node_feature = (
+                tensor.node['feature'] + ops.gather(context, tensor.graph_indicator)
+            )
         data = {'node': {'feature': node_feature}}
         if self._drop:
             data['context'] = {self._field: None}
@@ -1561,8 +1562,7 @@ class GraphNetwork(GraphLayer):
         if self._update_node_feature:
             warnings.warn(
                 'Node feature dim does not match `units` of the first layer. '
-                '
-                stacklevel=2
+                'Applying a projection layer to node features to match `units`.',
             )
             self._node_dense = self.get_dense(units)
         self._has_edge_feature = 'feature' in spec.edge
@@ -1572,8 +1572,7 @@ class GraphNetwork(GraphLayer):
         if self._update_edge_feature:
             warnings.warn(
                 'Edge feature dim does not match `units` of the first layer. '
-                '
-                stacklevel=2
+                'Applying projection layer to edge features to match `units`.'
             )
             self._edge_dense = self.get_dense(units)
 
```
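The `AddContext` change makes context injection depend on whether the graph has a super node: with one, the projected context is scatter-added onto the super node only; without one, it is broadcast to every node of the corresponding subgraph. A small numpy illustration of the two paths; the arrays and the index interpretation of `node['super']` are assumptions for illustration, not taken from molcraft:

```python
# Illustration only: shapes follow the diff, contents are invented.
import numpy as np

node_feature = np.zeros((5, 4))                     # 5 nodes, feature dim 4
graph_indicator = np.array([0, 0, 0, 1, 1])         # node -> subgraph id
context = np.arange(8, dtype=float).reshape(2, 4)   # one vector per subgraph

# Without a super node: gather + add broadcasts context to all nodes.
broadcast = node_feature + context[graph_indicator]

# With a super node: scatter-add targets only the super-node rows
# (here assumed to be node 2 of subgraph 0 and node 4 of subgraph 1).
super_index = np.array([2, 4])
scattered = node_feature.copy()
np.add.at(scattered, super_index, context)
```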
{molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/models.py

```diff
@@ -1,3 +1,4 @@
+import warnings
 import typing
 import keras
 import numpy as np
@@ -111,7 +112,7 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
     def __new__(cls, *args, **kwargs):
         if _functional_init_arguments(args, kwargs) and cls == GraphModel:
             return FunctionalGraphModel(*args, **kwargs)
-        return
+        return super().__new__(cls)
 
     def __init__(self, *args, **kwargs):
         self._model_layers = kwargs.pop('model_layers', None)
@@ -137,6 +138,8 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
         """
         if not tensors.is_graph(graph_layers[0]):
             return cls(model_layers=graph_layers)
+        elif cls != GraphModel:
+            return cls(model_layers=graph_layers[1:])
         inputs: dict = graph_layers.pop(0)
         x = inputs
         for layer in graph_layers:
```
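The `from_layers` change routes subclass calls away from the functional-construction path, and the corrected `__new__` now returns an instance instead of falling through. Together these let subclasses such as `CoordinateNoisePredictor` be built normally. A usage sketch adapted from the commented-out example in the new diffusion.py; layer choices and units are copied from that comment and not otherwise verified:

```python
# Adapted from diffusion.py's commented-out example.
import keras
import molcraft
from molcraft import diffusion

encoder = molcraft.models.GraphModel.from_layers([
    diffusion.CoordinateNoise(),
    molcraft.layers.NodeEmbedding(128),
    molcraft.layers.EdgeEmbedding(128),
    molcraft.layers.AddContext('position'),
    molcraft.layers.MPConv(128),
])
decoder = keras.Sequential([
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(3),  # predicts 3D coordinate noise
])
model = diffusion.CoordinateNoisePredictor(encoder, decoder)
model.compile(keras.optimizers.Adam(1e-3), 'mse')
```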
{molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft/records.py

```diff
@@ -1,3 +1,4 @@
+import warnings
 import os
 import math
 import glob
@@ -9,14 +10,17 @@ import pandas as pd
 import multiprocessing as mp
 
 from molcraft import tensors
-
+
+if typing.TYPE_CHECKING:
+    from molcraft import featurizers
 
 
 def write(
     inputs: list[str | tuple],
-    featurizer: featurizers.GraphFeaturizer,
+    featurizer: 'featurizers.GraphFeaturizer',
     path: str,
-
+    exist_ok: bool = False,
+    overwrite: bool = False,
     num_files: typing.Optional[int] = None,
     num_processes: typing.Optional[int] = None,
     multiprocessing: bool = False,
@@ -24,6 +28,8 @@ def write(
 ) -> None:
 
     if os.path.isdir(path):
+        if not exist_ok:
+            raise FileExistsError(f'Records already exist: {path}')
         if not overwrite:
             return
         else:
@@ -60,9 +66,11 @@ def write(
         chunk_sizes[i % num_files] += 1
 
     input_chunks = []
+    start_indices = []
     current_index = 0
     for size in chunk_sizes:
         input_chunks.append(inputs[current_index: current_index + size])
+        start_indices.append(current_index)
         current_index += size
 
     assert current_index == num_examples
@@ -73,13 +81,13 @@ def write(
     ]
 
     if not multiprocessing:
-        for path, input_chunk in zip(paths, input_chunks):
-            _write_tfrecord(input_chunk, path, featurizer)
+        for path, input_chunk, start_index in zip(paths, input_chunks, start_indices):
+            _write_tfrecord(input_chunk, path, featurizer, start_index)
         return
 
     processes = []
 
-    for path, input_chunk in zip(paths, input_chunks):
+    for path, input_chunk, start_index in zip(paths, input_chunks, start_indices):
 
         while len(processes) >= num_processes:
             for process in processes:
@@ -91,7 +99,7 @@ def write(
 
         process = mp.Process(
             target=_write_tfrecord,
-            args=(input_chunk, path, featurizer)
+            args=(input_chunk, path, featurizer, start_index)
        )
         processes.append(process)
         process.start()
@@ -134,9 +142,10 @@ def load_spec(path: str) -> tensors.GraphTensor.Spec:
     return spec
 
 def _write_tfrecord(
-    inputs,
+    inputs: list[str, tuple],
     path: str,
-    featurizer: featurizers.GraphFeaturizer,
+    featurizer: 'featurizers.GraphFeaturizer',
+    start_index: int,
 ) -> None:
 
     def _write_example(tensor):
@@ -147,12 +156,17 @@ def _write_tfrecord(
         writer.write(serialized_feature)
 
     with tf.io.TFRecordWriter(path) as writer:
-        for x in inputs:
+        for i, x in enumerate(inputs):
             if isinstance(x, (list, np.ndarray)):
                 x = tuple(x)
-
-
+            try:
+                tensor = featurizer(x)
                 _write_example(tensor)
+            except Exception as e:
+                warnings.warn(
+                    f'Could not write record for index {i + start_index}, proceeding without it.'
+                    f'Exception raised:\n{e}'
+                )
 
 def _serialize_example(
     feature: dict[str, tf.train.Feature]
```
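`records.write` now shards inputs together with their global start indices, so a featurization failure is reported with the index of the offending input and skipped instead of aborting the shard. A hedged sketch of a direct call with the updated signature; the inputs, path, and featurizer are placeholders:

```python
# Hypothetical direct call into molcraft.records with the new keyword args.
from molcraft import featurizers, records

featurizer = featurizers.MolGraphFeaturizer3D()  # assumed default-constructible

records.write(
    ['CCO', 'CCN', 'CCC', 'c1ccccc1'],
    featurizer=featurizer,
    path='/tmp/shards',
    exist_ok=True,   # new: do not raise if the directory already exists
    overwrite=True,  # rewrite existing records
    num_files=2,     # split across two TFRecord shards
)
dataset = records.read('/tmp/shards')
```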
{molcraft-0.1.0a21 → molcraft-0.1.0a23}/molcraft.egg-info/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: molcraft
-Version: 0.1.0a21
+Version: 0.1.0a23
 Summary: Graph Neural Networks for Molecular Machine Learning
 Author-email: Alexander Kensert <alexander.kensert@gmail.com>
 License: MIT License
@@ -35,7 +35,6 @@ Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: tensorflow>=2.16
-Requires-Dist: tensorflow-text>=2.16
 Requires-Dist: rdkit>=2023.9.5
 Requires-Dist: pandas>=1.0.3
 Requires-Dist: ipython>=8.12.0
@@ -43,7 +42,7 @@ Provides-Extra: gpu
 Requires-Dist: tensorflow[and-cuda]>=2.16; extra == "gpu"
 Dynamic: license-file
 
-<img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo">
+<img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo" width="90%">
 
 **Deep Learning on Molecules**: A Minimalistic GNN package for Molecular ML.
 
```
{molcraft-0.1.0a21 → molcraft-0.1.0a23}/tests/test_featurizers.py

```diff
@@ -3,6 +3,7 @@ import tempfile
 import shutil
 
 from molcraft import features
+from molcraft import descriptors
 from molcraft import featurizers
 
 
@@ -129,6 +130,9 @@ class TestFeaturizer(unittest.TestCase):
             pair_features=[
                 features.PairDistance(max_distance=20)
             ],
+            molecule_features=[
+                descriptors.ForceFieldEnergy(),
+            ],
             super_node=True,
             self_loops=False,
             include_hydrogens=False,
@@ -199,6 +203,5 @@ class TestFeaturizer(unittest.TestCase):
         self.assertEqual(graph.edge['target'].dtype.name, 'int32')
 
 
-
 if __name__ == '__main__':
     unittest.main()
```
{molcraft-0.1.0a21 → molcraft-0.1.0a23}/tests/test_losses.py

```diff
@@ -21,4 +21,8 @@ class TestLoss(unittest.TestCase):
             keras.ops.array([[2., 0.1], [4., 0.2], [5., 0.3]])
         )
         self.assertGreater(value, 0)
-        self.assertEqual(len(keras.ops.shape(value)), 0)
+        self.assertEqual(len(keras.ops.shape(value)), 0)
+
+
+if __name__ == '__main__':
+    unittest.main()
```
{molcraft-0.1.0a21 → molcraft-0.1.0a23}/tests/test_models.py

```diff
@@ -266,4 +266,8 @@ class TestModel(unittest.TestCase):
         model = get_model(tensor)
         out = model.embedding()(tensor)
         self.assertTrue(out.shape[0] == tensor.context['size'].shape[0])
-        self.assertTrue(out.shape[1] == units)
+        self.assertTrue(out.shape[1] == units)
+
+
+if __name__ == '__main__':
+    unittest.main()
```