molcraft 0.1.0a1__tar.gz → 0.1.0a3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/PKG-INFO +68 -1
- molcraft-0.1.0a3/README.md +81 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/molcraft/__init__.py +2 -1
- molcraft-0.1.0a3/molcraft/datasets.py +123 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/molcraft/experimental/peptides.py +28 -67
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/molcraft/features.py +5 -3
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/molcraft/featurizers.py +68 -27
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/molcraft/layers.py +1299 -647
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/molcraft/models.py +35 -5
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/molcraft/tensors.py +33 -12
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/molcraft.egg-info/PKG-INFO +68 -1
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/molcraft.egg-info/SOURCES.txt +1 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/tests/test_featurizers.py +45 -0
- molcraft-0.1.0a3/tests/test_layers.py +268 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/tests/test_models.py +4 -0
- molcraft-0.1.0a1/README.md +0 -14
- molcraft-0.1.0a1/tests/test_layers.py +0 -119
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/LICENSE +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/molcraft/callbacks.py +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/molcraft/chem.py +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/molcraft/conformers.py +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/molcraft/descriptors.py +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/molcraft/experimental/__init__.py +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/molcraft/ops.py +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/molcraft/records.py +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/molcraft.egg-info/dependency_links.txt +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/molcraft.egg-info/requires.txt +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/molcraft.egg-info/top_level.txt +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/pyproject.toml +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/setup.cfg +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/tests/test_chem.py +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a3}/tests/test_tensors.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: molcraft
|
|
3
|
-
Version: 0.1.0a1
|
|
3
|
+
Version: 0.1.0a3
|
|
4
4
|
Summary: Graph Neural Networks for Molecular Machine Learning
|
|
5
5
|
Author-email: Alexander Kensert <alexander.kensert@gmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -56,3 +56,70 @@ Dynamic: license-file
|
|
|
56
56
|
- Modular graph **layers**
|
|
57
57
|
- Serializable graph **featurizers** and **models**
|
|
58
58
|
- Flexible **GraphTensor**
|
|
59
|
+
|
|
60
|
+
## Examples
|
|
61
|
+
|
|
62
|
+
```python
|
|
63
|
+
from molcraft import features
|
|
64
|
+
from molcraft import descriptors
|
|
65
|
+
from molcraft import featurizers
|
|
66
|
+
from molcraft import layers
|
|
67
|
+
from molcraft import models
|
|
68
|
+
import keras
|
|
69
|
+
|
|
70
|
+
featurizer = featurizers.MolGraphFeaturizer(
|
|
71
|
+
atom_features=[
|
|
72
|
+
features.AtomType(),
|
|
73
|
+
features.TotalNumHs(),
|
|
74
|
+
features.Degree(),
|
|
75
|
+
],
|
|
76
|
+
bond_features=[
|
|
77
|
+
features.BondType(),
|
|
78
|
+
features.IsRotatable(),
|
|
79
|
+
],
|
|
80
|
+
super_atom=True,
|
|
81
|
+
self_loops=False,
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
graph = featurizer([('N[C@@H](C)C(=O)O', 2.0), ('N[C@@H](CS)C(=O)O', 1.0)])
|
|
85
|
+
print(graph)
|
|
86
|
+
|
|
87
|
+
model = models.GraphModel.from_layers(
|
|
88
|
+
[
|
|
89
|
+
layers.Input(graph.spec),
|
|
90
|
+
layers.NodeEmbedding(dim=128),
|
|
91
|
+
layers.EdgeEmbedding(dim=128),
|
|
92
|
+
layers.GraphTransformer(units=128),
|
|
93
|
+
layers.GraphTransformer(units=128),
|
|
94
|
+
layers.GraphTransformer(units=128),
|
|
95
|
+
layers.GraphTransformer(units=128),
|
|
96
|
+
layers.Readout(mode='mean'),
|
|
97
|
+
keras.layers.Dense(units=1024, activation='relu'),
|
|
98
|
+
keras.layers.Dense(units=1024, activation='relu'),
|
|
99
|
+
keras.layers.Dense(1)
|
|
100
|
+
]
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
pred = model(graph)
|
|
104
|
+
print(pred)
|
|
105
|
+
|
|
106
|
+
# featurizers.save_featurizer(featurizer, '/tmp/featurizer.json')
|
|
107
|
+
# models.save_model(model, '/tmp/model.keras')
|
|
108
|
+
|
|
109
|
+
# featurizers.load_featurizer('/tmp/featurizer.json')
|
|
110
|
+
# models.load_model('/tmp/model.keras')
|
|
111
|
+
```
|
|
112
|
+
|
|
113
|
+
## Installation
|
|
114
|
+
|
|
115
|
+
Install the pre-release of molcraft via pip:
|
|
116
|
+
|
|
117
|
+
```bash
|
|
118
|
+
pip install molcraft --pre
|
|
119
|
+
```
|
|
120
|
+
|
|
121
|
+
With GPU support:
|
|
122
|
+
|
|
123
|
+
```bash
|
|
124
|
+
pip install molcraft[gpu] --pre
|
|
125
|
+
```
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
<img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo">
|
|
2
|
+
|
|
3
|
+
**Deep Learning on Molecules**: A minimalistic GNN package for molecular ML.
|
|
4
|
+
|
|
5
|
+
> [!NOTE]
|
|
6
|
+
> Work in progress; not yet finished.
|
|
7
|
+
|
|
8
|
+
## Highlights
|
|
9
|
+
- Compatible with **Keras 3**
|
|
10
|
+
- Simplified API
|
|
11
|
+
- Fast featurization
|
|
12
|
+
- Modular graph **layers**
|
|
13
|
+
- Serializable graph **featurizers** and **models**
|
|
14
|
+
- Flexible **GraphTensor**
|
|
15
|
+
|
|
16
|
+
## Examples
|
|
17
|
+
|
|
18
|
+
```python
|
|
19
|
+
from molcraft import features
|
|
20
|
+
from molcraft import descriptors
|
|
21
|
+
from molcraft import featurizers
|
|
22
|
+
from molcraft import layers
|
|
23
|
+
from molcraft import models
|
|
24
|
+
import keras
|
|
25
|
+
|
|
26
|
+
featurizer = featurizers.MolGraphFeaturizer(
|
|
27
|
+
atom_features=[
|
|
28
|
+
features.AtomType(),
|
|
29
|
+
features.TotalNumHs(),
|
|
30
|
+
features.Degree(),
|
|
31
|
+
],
|
|
32
|
+
bond_features=[
|
|
33
|
+
features.BondType(),
|
|
34
|
+
features.IsRotatable(),
|
|
35
|
+
],
|
|
36
|
+
super_atom=True,
|
|
37
|
+
self_loops=False,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
graph = featurizer([('N[C@@H](C)C(=O)O', 2.0), ('N[C@@H](CS)C(=O)O', 1.0)])
|
|
41
|
+
print(graph)
|
|
42
|
+
|
|
43
|
+
model = models.GraphModel.from_layers(
|
|
44
|
+
[
|
|
45
|
+
layers.Input(graph.spec),
|
|
46
|
+
layers.NodeEmbedding(dim=128),
|
|
47
|
+
layers.EdgeEmbedding(dim=128),
|
|
48
|
+
layers.GraphTransformer(units=128),
|
|
49
|
+
layers.GraphTransformer(units=128),
|
|
50
|
+
layers.GraphTransformer(units=128),
|
|
51
|
+
layers.GraphTransformer(units=128),
|
|
52
|
+
layers.Readout(mode='mean'),
|
|
53
|
+
keras.layers.Dense(units=1024, activation='relu'),
|
|
54
|
+
keras.layers.Dense(units=1024, activation='relu'),
|
|
55
|
+
keras.layers.Dense(1)
|
|
56
|
+
]
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
pred = model(graph)
|
|
60
|
+
print(pred)
|
|
61
|
+
|
|
62
|
+
# featurizers.save_featurizer(featurizer, '/tmp/featurizer.json')
|
|
63
|
+
# models.save_model(model, '/tmp/model.keras')
|
|
64
|
+
|
|
65
|
+
# featurizers.load_featurizer('/tmp/featurizer.json')
|
|
66
|
+
# models.load_model('/tmp/model.keras')
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
## Installation
|
|
70
|
+
|
|
71
|
+
Install the pre-release of molcraft via pip:
|
|
72
|
+
|
|
73
|
+
```bash
|
|
74
|
+
pip install molcraft --pre
|
|
75
|
+
```
|
|
76
|
+
|
|
77
|
+
With GPU support:
|
|
78
|
+
|
|
79
|
+
```bash
|
|
80
|
+
pip install molcraft[gpu] --pre
|
|
81
|
+
```
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
__version__ = '0.1.
|
|
1
|
+
__version__ = '0.1.0a3'
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
4
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
@@ -14,3 +14,4 @@ from molcraft import ops
|
|
|
14
14
|
from molcraft import records
|
|
15
15
|
from molcraft import tensors
|
|
16
16
|
from molcraft import callbacks
|
|
17
|
+
from molcraft import datasets
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def split(
|
|
6
|
+
data: pd.DataFrame | np.ndarray,
|
|
7
|
+
train_size: float | None = None,
|
|
8
|
+
validation_size: float | None = None,
|
|
9
|
+
test_size: float = 0.1,
|
|
10
|
+
shuffle: bool = False,
|
|
11
|
+
random_state: int | None = None,
|
|
12
|
+
) -> pd.DataFrame | np.ndarray:
|
|
13
|
+
"""Splits dataset into subsets.
|
|
14
|
+
|
|
15
|
+
Args:
|
|
16
|
+
data:
|
|
17
|
+
A pd.DataFrame or np.ndarray object.
|
|
18
|
+
train_size:
|
|
19
|
+
Optional train size, as a fraction (`float`) or size (`int`).
|
|
20
|
+
validation_size:
|
|
21
|
+
Optional validation size, as a fraction (`float`) or size (`int`).
|
|
22
|
+
test_size:
|
|
23
|
+
Required test size, as a fraction (`float`) or size (`int`).
|
|
24
|
+
shuffle:
|
|
25
|
+
Whether the dataset should be shuffled prior to splitting.
|
|
26
|
+
random_state:
|
|
27
|
+
The random state (or seed). Only applicable if shuffling.
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
if not isinstance(data, (pd.DataFrame, np.ndarray, list)):
|
|
31
|
+
raise ValueError(
|
|
32
|
+
'`data` needs to be a pd.DataFrame, np.ndarray or a list. '
|
|
33
|
+
f'Found {type(data)}.'
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
size = len(data)
|
|
37
|
+
|
|
38
|
+
if test_size is None:
|
|
39
|
+
raise ValueError('`test_size` is required.')
|
|
40
|
+
elif test_size <= 0:
|
|
41
|
+
raise ValueError(
|
|
42
|
+
f'Test size needs to be positive. Found: {test_size}. '
|
|
43
|
+
'Either specify a positive `float` (fraction) or '
|
|
44
|
+
'a positive `int` (size).'
|
|
45
|
+
)
|
|
46
|
+
if train_size is not None and train_size <= 0:
|
|
47
|
+
raise ValueError(
|
|
48
|
+
f'Train size needs to be None or positive. Found: {train_size}. '
|
|
49
|
+
'Either specify `None`, a positive `float` (fraction) or '
|
|
50
|
+
'a positive `int` (size).'
|
|
51
|
+
)
|
|
52
|
+
if validation_size is not None and validation_size <= 0:
|
|
53
|
+
raise ValueError(
|
|
54
|
+
f'Validation size needs to be None or positive. Found: {validation_size}. '
|
|
55
|
+
'Either specify `None`, a positive `float` (fraction) or '
|
|
56
|
+
'a positive `int` (size).'
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
if isinstance(test_size, float):
|
|
60
|
+
test_size = int(size * test_size)
|
|
61
|
+
if validation_size and isinstance(validation_size, float):
|
|
62
|
+
validation_size = int(size * validation_size)
|
|
63
|
+
elif not validation_size:
|
|
64
|
+
validation_size = 0
|
|
65
|
+
|
|
66
|
+
if train_size and isinstance(train_size, float):
|
|
67
|
+
train_size = int(size * train_size)
|
|
68
|
+
elif not train_size:
|
|
69
|
+
train_size = 0
|
|
70
|
+
|
|
71
|
+
if not train_size:
|
|
72
|
+
train_size = size - test_size
|
|
73
|
+
if not validation_size:
|
|
74
|
+
train_size -= validation_size
|
|
75
|
+
|
|
76
|
+
remainder = size - (train_size + validation_size + test_size)
|
|
77
|
+
|
|
78
|
+
if remainder < 0:
|
|
79
|
+
raise ValueError(
|
|
80
|
+
'Sizes of data subsets add up to more than the size of the original data set: '
|
|
81
|
+
f'{size} < ({train_size} + {validation_size} + {test_size})'
|
|
82
|
+
)
|
|
83
|
+
if test_size <= 0:
|
|
84
|
+
raise ValueError(
|
|
85
|
+
f'Test size needs to be greater than 0. Found: {test_size}.'
|
|
86
|
+
)
|
|
87
|
+
if train_size <= 0:
|
|
88
|
+
raise ValueError(
|
|
89
|
+
f'Train size needs to be greater than 0. Found: {train_size}.'
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
train_size += remainder
|
|
93
|
+
|
|
94
|
+
if isinstance(data, pd.DataFrame):
|
|
95
|
+
if shuffle:
|
|
96
|
+
data = data.sample(
|
|
97
|
+
frac=1.0, replace=False, random_state=random_state
|
|
98
|
+
)
|
|
99
|
+
train_data = data.iloc[:train_size]
|
|
100
|
+
test_data = data.iloc[-test_size:]
|
|
101
|
+
if not validation_size:
|
|
102
|
+
return train_data, test_data
|
|
103
|
+
validation_data = data.iloc[train_size:-test_size]
|
|
104
|
+
return train_data, validation_data, test_data
|
|
105
|
+
|
|
106
|
+
if not isinstance(data, np.ndarray):
|
|
107
|
+
data = np.asarray(data)
|
|
108
|
+
|
|
109
|
+
np.random.seed(random_state)
|
|
110
|
+
|
|
111
|
+
random_indices = np.arange(size)
|
|
112
|
+
np.random.shuffle(random_indices)
|
|
113
|
+
data = data[random_indices]
|
|
114
|
+
|
|
115
|
+
train_data = data[:train_size]
|
|
116
|
+
test_data = data[-test_size:]
|
|
117
|
+
if not validation_size:
|
|
118
|
+
return train_data, test_data
|
|
119
|
+
validation_data = data[train_size:-test_size]
|
|
120
|
+
return train_data, validation_data, test_data
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
|
|
@@ -9,75 +9,36 @@ from molcraft import chem
|
|
|
9
9
|
from molcraft import features
|
|
10
10
|
from molcraft import featurizers
|
|
11
11
|
from molcraft import tensors
|
|
12
|
+
from molcraft import descriptors
|
|
12
13
|
|
|
13
14
|
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
self_loops=self_loops,
|
|
35
|
-
include_hs=include_hs,
|
|
36
|
-
feature_dtype=feature_dtype,
|
|
37
|
-
index_dtype=index_dtype
|
|
38
|
-
)
|
|
15
|
+
def Graph(
|
|
16
|
+
inputs,
|
|
17
|
+
atom_features: list[features.Feature] | str | None = 'auto',
|
|
18
|
+
bond_features: list[features.Feature] | str | None = 'auto',
|
|
19
|
+
super_atom: bool = True,
|
|
20
|
+
radius: int | float | None = None,
|
|
21
|
+
self_loops: bool = False,
|
|
22
|
+
include_hs: bool = False,
|
|
23
|
+
**kwargs,
|
|
24
|
+
):
|
|
25
|
+
featurizer = featurizers.MolGraphFeaturizer(
|
|
26
|
+
atom_features=atom_features,
|
|
27
|
+
bond_features=bond_features,
|
|
28
|
+
molecule_features=[AminoAcidType()],
|
|
29
|
+
super_atom=super_atom,
|
|
30
|
+
radius=radius,
|
|
31
|
+
self_loops=self_loops,
|
|
32
|
+
include_hs=include_hs,
|
|
33
|
+
**kwargs,
|
|
34
|
+
)
|
|
39
35
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
inputs = [
|
|
46
|
-
features.residues[x] for x in ['G'] + inputs
|
|
47
|
-
]
|
|
48
|
-
mols = [
|
|
49
|
-
chem.Mol.from_encoding(x, explicit_hs=self.include_hs) for x in inputs
|
|
50
|
-
]
|
|
51
|
-
mols = [
|
|
52
|
-
mol for mol in mols if mol is not None
|
|
53
|
-
]
|
|
54
|
-
if not mols:
|
|
55
|
-
return None
|
|
56
|
-
tensor_list: list[tensors.GraphTensor] = [super().call(mol) for mol in mols]
|
|
57
|
-
tensor: tensors.GraphTensor = tf.stack(tensor_list, axis=0)
|
|
58
|
-
return tensor
|
|
59
|
-
|
|
60
|
-
def call(self, inputs: str | tuple) -> tensors.GraphTensor:
|
|
61
|
-
args = []
|
|
62
|
-
if isinstance(inputs, (list, tuple, np.ndarray)):
|
|
63
|
-
inputs, *args = inputs
|
|
64
|
-
inputs = [
|
|
65
|
-
features.residues[x] for x in chem.sequence_split(inputs)
|
|
66
|
-
]
|
|
67
|
-
tensor_list: list[tensors.GraphTensor] = [super().call(x) for x in inputs]
|
|
68
|
-
tensor: tensors.GraphTensor = tf.stack(tensor_list, axis=0)
|
|
69
|
-
tensor = tensor._merge()
|
|
70
|
-
context = {
|
|
71
|
-
k: v for (k, v) in zip(['label', 'weight'], args)
|
|
72
|
-
}
|
|
73
|
-
tensor = tensor.update(
|
|
74
|
-
{
|
|
75
|
-
'context': context
|
|
76
|
-
}
|
|
77
|
-
)
|
|
36
|
+
inputs = [
|
|
37
|
+
residues[x] for x in ['G'] + inputs
|
|
38
|
+
]
|
|
39
|
+
tensor_list = [featurizer(x) for x in inputs]
|
|
40
|
+
return tf.stack(tensor_list, axis=0)
|
|
78
41
|
|
|
79
|
-
return tensor
|
|
80
|
-
|
|
81
42
|
|
|
82
43
|
def GraphLookup(graph: tensors.GraphTensor) -> 'GraphLookupLayer':
|
|
83
44
|
lookup = GraphLookupLayer()
|
|
@@ -203,7 +164,7 @@ class Gather(keras.layers.Layer):
|
|
|
203
164
|
|
|
204
165
|
|
|
205
166
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
206
|
-
class AminoAcidType(
|
|
167
|
+
class AminoAcidType(descriptors.Descriptor):
|
|
207
168
|
|
|
208
169
|
def __init__(self, vocab=None, **kwargs):
|
|
209
170
|
vocab = [
|
|
@@ -217,7 +178,7 @@ class AminoAcidType(features.Feature):
|
|
|
217
178
|
if not residue:
|
|
218
179
|
raise KeyError(f'Could not find {mol.canonical_smiles} in `residues_reverse`.')
|
|
219
180
|
mol = chem.remove_hs(mol)
|
|
220
|
-
return
|
|
181
|
+
return _extract_residue_type(residues_reverse[mol.canonical_smiles])
|
|
221
182
|
|
|
222
183
|
def sequence_split(sequence: str):
|
|
223
184
|
patterns = [
|
|
@@ -155,9 +155,11 @@ class Distance(EdgeFeature):
|
|
|
155
155
|
encode_oov: bool = True,
|
|
156
156
|
**kwargs,
|
|
157
157
|
) -> None:
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
158
|
+
vocab = kwargs.pop('vocab', None)
|
|
159
|
+
if not vocab:
|
|
160
|
+
if max_distance is None:
|
|
161
|
+
max_distance = 20
|
|
162
|
+
vocab = list(range(max_distance + 1))
|
|
161
163
|
super().__init__(
|
|
162
164
|
vocab=vocab,
|
|
163
165
|
allow_oov=allow_oov,
|
|
@@ -200,12 +200,13 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
200
200
|
self.feature_dtype = 'float32'
|
|
201
201
|
self.index_dtype = 'int32'
|
|
202
202
|
|
|
203
|
-
def call(self,
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
203
|
+
def call(self, inputs: str | tuple) -> tensors.GraphTensor:
|
|
204
|
+
if isinstance(inputs, (tuple, list, np.ndarray)):
|
|
205
|
+
x, *context = inputs
|
|
206
|
+
if len(context) and isinstance(context[0], dict):
|
|
207
|
+
context = copy.deepcopy(context[0])
|
|
207
208
|
else:
|
|
208
|
-
|
|
209
|
+
x, context = inputs, None
|
|
209
210
|
|
|
210
211
|
mol = chem.Mol.from_encoding(x, explicit_hs=self.include_hs)
|
|
211
212
|
|
|
@@ -220,14 +221,30 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
220
221
|
bond_feature = self.bond_features(mol)
|
|
221
222
|
context_feature = self.context_feature(mol)
|
|
222
223
|
molecule_size = self.num_atoms(mol)
|
|
223
|
-
|
|
224
|
-
context,
|
|
225
|
-
|
|
226
|
-
|
|
224
|
+
|
|
225
|
+
if isinstance(context, dict):
|
|
226
|
+
if 'x' in context:
|
|
227
|
+
context['feature'] = context.pop('x')
|
|
228
|
+
if 'y' in context:
|
|
229
|
+
context['label'] = context.pop('y')
|
|
230
|
+
if 'sample_weight' in context:
|
|
231
|
+
context['weight'] = context.pop('sample_weight')
|
|
232
|
+
context = {
|
|
233
|
+
**{'size': molecule_size},
|
|
234
|
+
**context
|
|
235
|
+
}
|
|
236
|
+
elif isinstance(context, list):
|
|
237
|
+
context = {
|
|
238
|
+
**{'size': molecule_size},
|
|
239
|
+
**{key: value for (key, value) in zip(['label', 'weight'], context)}
|
|
240
|
+
}
|
|
241
|
+
else:
|
|
242
|
+
context = {'size': molecule_size}
|
|
227
243
|
|
|
228
244
|
if context_feature is not None:
|
|
229
245
|
context['feature'] = context_feature
|
|
230
246
|
|
|
247
|
+
node = {}
|
|
231
248
|
node['feature'] = atom_feature
|
|
232
249
|
|
|
233
250
|
if bond_feature is not None and (self.radius > 1 or self.self_loops):
|
|
@@ -239,6 +256,7 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
239
256
|
[bond_feature, zero_bond_feature], axis=0
|
|
240
257
|
)
|
|
241
258
|
|
|
259
|
+
edge = {}
|
|
242
260
|
if self.radius == 1:
|
|
243
261
|
edge['source'], edge['target'] = mol.adjacency(
|
|
244
262
|
fill='full', sparse=True, self_loops=self.self_loops, dtype=self.index_dtype
|
|
@@ -384,6 +402,7 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
384
402
|
return cls(**config)
|
|
385
403
|
|
|
386
404
|
|
|
405
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
387
406
|
class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
388
407
|
|
|
389
408
|
"""Molecular 3d-graph featurizer.
|
|
@@ -494,19 +513,25 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
494
513
|
self.embed_conformer = self.conformer_generator is not None
|
|
495
514
|
self.radius = float(radius) if radius else None
|
|
496
515
|
|
|
497
|
-
def call(self,
|
|
516
|
+
def call(self, inputs: str | tuple) -> tensors.GraphTensor:
|
|
498
517
|
|
|
499
|
-
if isinstance(
|
|
500
|
-
x, *
|
|
518
|
+
if isinstance(inputs, (tuple, list, np.ndarray)):
|
|
519
|
+
x, *context = inputs
|
|
520
|
+
if len(context) and isinstance(context[0], dict):
|
|
521
|
+
context = copy.deepcopy(context[0])
|
|
501
522
|
else:
|
|
502
|
-
|
|
523
|
+
x, context = inputs, None
|
|
503
524
|
|
|
504
525
|
explicit_hs = (self.include_hs or self.embed_conformer)
|
|
505
526
|
mol = chem.Mol.from_encoding(x, explicit_hs=explicit_hs)
|
|
506
|
-
|
|
527
|
+
|
|
507
528
|
if mol is None:
|
|
529
|
+
warn(
|
|
530
|
+
f'Could not obtain `chem.Mol` from {x}. '
|
|
531
|
+
'Proceeding without it.'
|
|
532
|
+
)
|
|
508
533
|
return None
|
|
509
|
-
|
|
534
|
+
|
|
510
535
|
if self.embed_conformer:
|
|
511
536
|
mol = self.conformer_generator(mol)
|
|
512
537
|
if not self.include_hs:
|
|
@@ -519,21 +544,38 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
519
544
|
'of the `Featurizer` or input a 3D representation of the molecule. '
|
|
520
545
|
)
|
|
521
546
|
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
547
|
+
context_feature = self.context_feature(mol)
|
|
548
|
+
molecule_size = self.num_atoms(mol) + int(self.super_atom)
|
|
549
|
+
|
|
550
|
+
if isinstance(context, dict):
|
|
551
|
+
if 'x' in context:
|
|
552
|
+
context['feature'] = context.pop('x')
|
|
553
|
+
if 'y' in context:
|
|
554
|
+
context['label'] = context.pop('y')
|
|
555
|
+
if 'sample_weight' in context:
|
|
556
|
+
context['weight'] = context.pop('sample_weight')
|
|
557
|
+
context = {
|
|
558
|
+
**{'size': molecule_size},
|
|
559
|
+
**context
|
|
560
|
+
}
|
|
561
|
+
elif isinstance(context, list):
|
|
562
|
+
context = {
|
|
563
|
+
**{'size': molecule_size},
|
|
564
|
+
**{key: value for (key, value) in zip(['label', 'weight'], context)}
|
|
565
|
+
}
|
|
566
|
+
else:
|
|
567
|
+
context = {'size': molecule_size}
|
|
527
568
|
|
|
569
|
+
if context_feature is not None:
|
|
570
|
+
context['feature'] = context_feature
|
|
571
|
+
|
|
572
|
+
node = {}
|
|
528
573
|
node['feature'] = self.atom_features(mol)
|
|
529
574
|
|
|
530
575
|
if self._bond_features:
|
|
531
576
|
edge_feature = self.bond_features(mol)
|
|
532
577
|
|
|
533
|
-
|
|
534
|
-
if context_feature is not None:
|
|
535
|
-
context['feature'] = context_feature
|
|
536
|
-
|
|
578
|
+
edge = {}
|
|
537
579
|
mols = chem._split_mol_by_confs(mol)
|
|
538
580
|
tensor_list = []
|
|
539
581
|
for i, mol in enumerate(mols):
|
|
@@ -563,11 +605,10 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
563
605
|
node_conformer['coordinate'] = np.concatenate(
|
|
564
606
|
[node_conformer['coordinate'], conformer.centroid[None]], axis=0
|
|
565
607
|
)
|
|
566
|
-
|
|
567
608
|
tensor_list.append(
|
|
568
609
|
tensors.GraphTensor(context, node_conformer, edge_conformer)
|
|
569
610
|
)
|
|
570
|
-
|
|
611
|
+
|
|
571
612
|
return tensor_list
|
|
572
613
|
|
|
573
614
|
def stack(self, outputs):
|
|
@@ -587,7 +628,7 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
587
628
|
config['conformer_generator'] = keras.saving.deserialize_keras_object(
|
|
588
629
|
config['conformer_generator']
|
|
589
630
|
)
|
|
590
|
-
return super().from_config(
|
|
631
|
+
return super().from_config(config)
|
|
591
632
|
|
|
592
633
|
|
|
593
634
|
def save_featurizer(
|