molcraft 0.1.0a1__tar.gz → 0.1.0a2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/PKG-INFO +68 -1
- molcraft-0.1.0a2/README.md +81 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/molcraft/__init__.py +1 -1
- molcraft-0.1.0a2/molcraft/datasets.py +123 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/molcraft/experimental/peptides.py +28 -67
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/molcraft/featurizers.py +66 -26
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/molcraft/layers.py +792 -592
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/molcraft/models.py +1 -2
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/molcraft/tensors.py +33 -12
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/molcraft.egg-info/PKG-INFO +68 -1
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/molcraft.egg-info/SOURCES.txt +1 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/tests/test_featurizers.py +45 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/tests/test_layers.py +30 -6
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/tests/test_models.py +4 -0
- molcraft-0.1.0a1/README.md +0 -14
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/LICENSE +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/molcraft/callbacks.py +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/molcraft/chem.py +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/molcraft/conformers.py +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/molcraft/descriptors.py +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/molcraft/experimental/__init__.py +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/molcraft/features.py +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/molcraft/ops.py +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/molcraft/records.py +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/molcraft.egg-info/dependency_links.txt +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/molcraft.egg-info/requires.txt +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/molcraft.egg-info/top_level.txt +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/pyproject.toml +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/setup.cfg +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/tests/test_chem.py +0 -0
- {molcraft-0.1.0a1 → molcraft-0.1.0a2}/tests/test_tensors.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: molcraft
|
|
3
|
-
Version: 0.1.0a1
|
|
3
|
+
Version: 0.1.0a2
|
|
4
4
|
Summary: Graph Neural Networks for Molecular Machine Learning
|
|
5
5
|
Author-email: Alexander Kensert <alexander.kensert@gmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -56,3 +56,70 @@ Dynamic: license-file
|
|
|
56
56
|
- Modular graph **layers**
|
|
57
57
|
- Serializable graph **featurizers** and **models**
|
|
58
58
|
- Flexible **GraphTensor**
|
|
59
|
+
|
|
60
|
+
## Examples
|
|
61
|
+
|
|
62
|
+
```python
|
|
63
|
+
from molcraft import features
|
|
64
|
+
from molcraft import descriptors
|
|
65
|
+
from molcraft import featurizers
|
|
66
|
+
from molcraft import layers
|
|
67
|
+
from molcraft import models
|
|
68
|
+
import keras
|
|
69
|
+
|
|
70
|
+
featurizer = featurizers.MolGraphFeaturizer(
|
|
71
|
+
atom_features=[
|
|
72
|
+
features.AtomType(),
|
|
73
|
+
features.TotalNumHs(),
|
|
74
|
+
features.Degree(),
|
|
75
|
+
],
|
|
76
|
+
bond_features=[
|
|
77
|
+
features.BondType(),
|
|
78
|
+
features.IsRotatable(),
|
|
79
|
+
],
|
|
80
|
+
super_atom=True,
|
|
81
|
+
self_loops=False,
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
graph = featurizer([('N[C@@H](C)C(=O)O', 2.0), ('N[C@@H](CS)C(=O)O', 1.0)])
|
|
85
|
+
print(graph)
|
|
86
|
+
|
|
87
|
+
model = models.GraphModel.from_layers(
|
|
88
|
+
[
|
|
89
|
+
layers.Input(graph.spec),
|
|
90
|
+
layers.NodeEmbedding(dim=128),
|
|
91
|
+
layers.EdgeEmbedding(dim=128),
|
|
92
|
+
layers.GraphTransformer(units=128),
|
|
93
|
+
layers.GraphTransformer(units=128),
|
|
94
|
+
layers.GraphTransformer(units=128),
|
|
95
|
+
layers.GraphTransformer(units=128),
|
|
96
|
+
layers.Readout(mode='mean'),
|
|
97
|
+
keras.layers.Dense(units=1024, activation='relu'),
|
|
98
|
+
keras.layers.Dense(units=1024, activation='relu'),
|
|
99
|
+
keras.layers.Dense(1)
|
|
100
|
+
]
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
pred = model(graph)
|
|
104
|
+
print(pred)
|
|
105
|
+
|
|
106
|
+
# featurizers.save_featurizer(featurizer, '/tmp/featurizer.json')
|
|
107
|
+
# models.save_model(model, '/tmp/model.keras')
|
|
108
|
+
|
|
109
|
+
# featurizers.load_featurizer('/tmp/featurizer.json')
|
|
110
|
+
# models.load_model('/tmp/model.keras')
|
|
111
|
+
```
|
|
112
|
+
|
|
113
|
+
## Installation
|
|
114
|
+
|
|
115
|
+
Install the pre-release of molcraft via pip:
|
|
116
|
+
|
|
117
|
+
```bash
|
|
118
|
+
pip install molcraft --pre
|
|
119
|
+
```
|
|
120
|
+
|
|
121
|
+
with GPU support:
|
|
122
|
+
|
|
123
|
+
```bash
|
|
124
|
+
pip install molcraft[gpu] --pre
|
|
125
|
+
```
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
<img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo">
|
|
2
|
+
|
|
3
|
+
**Deep Learning on Molecules**: A Minimalistic GNN package for Molecular ML.
|
|
4
|
+
|
|
5
|
+
> [!NOTE]
|
|
6
|
+
> In progress/Unfinished.
|
|
7
|
+
|
|
8
|
+
## Highlights
|
|
9
|
+
- Compatible with **Keras 3**
|
|
10
|
+
- Simplified API
|
|
11
|
+
- Fast featurization
|
|
12
|
+
- Modular graph **layers**
|
|
13
|
+
- Serializable graph **featurizers** and **models**
|
|
14
|
+
- Flexible **GraphTensor**
|
|
15
|
+
|
|
16
|
+
## Examples
|
|
17
|
+
|
|
18
|
+
```python
|
|
19
|
+
from molcraft import features
|
|
20
|
+
from molcraft import descriptors
|
|
21
|
+
from molcraft import featurizers
|
|
22
|
+
from molcraft import layers
|
|
23
|
+
from molcraft import models
|
|
24
|
+
import keras
|
|
25
|
+
|
|
26
|
+
featurizer = featurizers.MolGraphFeaturizer(
|
|
27
|
+
atom_features=[
|
|
28
|
+
features.AtomType(),
|
|
29
|
+
features.TotalNumHs(),
|
|
30
|
+
features.Degree(),
|
|
31
|
+
],
|
|
32
|
+
bond_features=[
|
|
33
|
+
features.BondType(),
|
|
34
|
+
features.IsRotatable(),
|
|
35
|
+
],
|
|
36
|
+
super_atom=True,
|
|
37
|
+
self_loops=False,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
graph = featurizer([('N[C@@H](C)C(=O)O', 2.0), ('N[C@@H](CS)C(=O)O', 1.0)])
|
|
41
|
+
print(graph)
|
|
42
|
+
|
|
43
|
+
model = models.GraphModel.from_layers(
|
|
44
|
+
[
|
|
45
|
+
layers.Input(graph.spec),
|
|
46
|
+
layers.NodeEmbedding(dim=128),
|
|
47
|
+
layers.EdgeEmbedding(dim=128),
|
|
48
|
+
layers.GraphTransformer(units=128),
|
|
49
|
+
layers.GraphTransformer(units=128),
|
|
50
|
+
layers.GraphTransformer(units=128),
|
|
51
|
+
layers.GraphTransformer(units=128),
|
|
52
|
+
layers.Readout(mode='mean'),
|
|
53
|
+
keras.layers.Dense(units=1024, activation='relu'),
|
|
54
|
+
keras.layers.Dense(units=1024, activation='relu'),
|
|
55
|
+
keras.layers.Dense(1)
|
|
56
|
+
]
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
pred = model(graph)
|
|
60
|
+
print(pred)
|
|
61
|
+
|
|
62
|
+
# featurizers.save_featurizer(featurizer, '/tmp/featurizer.json')
|
|
63
|
+
# models.save_model(model, '/tmp/model.keras')
|
|
64
|
+
|
|
65
|
+
# featurizers.load_featurizer('/tmp/featurizer.json')
|
|
66
|
+
# models.load_model('/tmp/model.keras')
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
## Installation
|
|
70
|
+
|
|
71
|
+
Install the pre-release of molcraft via pip:
|
|
72
|
+
|
|
73
|
+
```bash
|
|
74
|
+
pip install molcraft --pre
|
|
75
|
+
```
|
|
76
|
+
|
|
77
|
+
with GPU support:
|
|
78
|
+
|
|
79
|
+
```bash
|
|
80
|
+
pip install molcraft[gpu] --pre
|
|
81
|
+
```
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def split(
|
|
6
|
+
data: pd.DataFrame | np.ndarray,
|
|
7
|
+
train_size: float | None = None,
|
|
8
|
+
validation_size: float | None = None,
|
|
9
|
+
test_size: float = 0.1,
|
|
10
|
+
shuffle: bool = False,
|
|
11
|
+
random_state: int | None = None,
|
|
12
|
+
) -> pd.DataFrame | np.ndarray:
|
|
13
|
+
"""Splits dataset into subsets.
|
|
14
|
+
|
|
15
|
+
Args:
|
|
16
|
+
data:
|
|
17
|
+
A pd.DataFrame or np.ndarray object.
|
|
18
|
+
train_size:
|
|
19
|
+
Optional train size, as a fraction (`float`) or size (`int`).
|
|
20
|
+
validation_size:
|
|
21
|
+
Optional validation size, as a fraction (`float`) or size (`int`).
|
|
22
|
+
test_size:
|
|
23
|
+
Required test size, as a fraction (`float`) or size (`int`).
|
|
24
|
+
shuffle:
|
|
25
|
+
Whether the dataset should be shuffled prior to splitting.
|
|
26
|
+
random_state:
|
|
27
|
+
The random state (or seed). Only applicable if shuffling.
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
if not isinstance(data, (pd.DataFrame, np.ndarray, list)):
|
|
31
|
+
raise ValueError(
|
|
32
|
+
'`data` needs to be a pd.DataFrame, np.ndarray or a list. '
|
|
33
|
+
f'Found {type(data)}.'
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
size = len(data)
|
|
37
|
+
|
|
38
|
+
if test_size is None:
|
|
39
|
+
raise ValueError('`test_size` is required.')
|
|
40
|
+
elif test_size <= 0:
|
|
41
|
+
raise ValueError(
|
|
42
|
+
f'Test size needs to be positive. Found: {test_size}. '
|
|
43
|
+
'Either specify a positive `float` (fraction) or '
|
|
44
|
+
'a positive `int` (size).'
|
|
45
|
+
)
|
|
46
|
+
if train_size is not None and train_size <= 0:
|
|
47
|
+
raise ValueError(
|
|
48
|
+
f'Train size needs to be None or positive. Found: {train_size}. '
|
|
49
|
+
'Either specify `None`, a positive `float` (fraction) or '
|
|
50
|
+
'a positive `int` (size).'
|
|
51
|
+
)
|
|
52
|
+
if validation_size is not None and validation_size <= 0:
|
|
53
|
+
raise ValueError(
|
|
54
|
+
f'Validation size needs to be None or positive. Found: {validation_size}. '
|
|
55
|
+
'Either specify `None`, a positive `float` (fraction) or '
|
|
56
|
+
'a positive `int` (size).'
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
if isinstance(test_size, float):
|
|
60
|
+
test_size = int(size * test_size)
|
|
61
|
+
if validation_size and isinstance(validation_size, float):
|
|
62
|
+
validation_size = int(size * validation_size)
|
|
63
|
+
elif not validation_size:
|
|
64
|
+
validation_size = 0
|
|
65
|
+
|
|
66
|
+
if train_size and isinstance(train_size, float):
|
|
67
|
+
train_size = int(size * train_size)
|
|
68
|
+
elif not train_size:
|
|
69
|
+
train_size = 0
|
|
70
|
+
|
|
71
|
+
if not train_size:
|
|
72
|
+
train_size = size - test_size
|
|
73
|
+
if not validation_size:
|
|
74
|
+
train_size -= validation_size
|
|
75
|
+
|
|
76
|
+
remainder = size - (train_size + validation_size + test_size)
|
|
77
|
+
|
|
78
|
+
if remainder < 0:
|
|
79
|
+
raise ValueError(
|
|
80
|
+
'Sizes of data subsets add up to more than the size of the original data set: '
|
|
81
|
+
f'{size} < ({train_size} + {validation_size} + {test_size})'
|
|
82
|
+
)
|
|
83
|
+
if test_size <= 0:
|
|
84
|
+
raise ValueError(
|
|
85
|
+
f'Test size needs to be greater than 0. Found: {test_size}.'
|
|
86
|
+
)
|
|
87
|
+
if train_size <= 0:
|
|
88
|
+
raise ValueError(
|
|
89
|
+
f'Train size needs to be greater than 0. Found: {train_size}.'
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
train_size += remainder
|
|
93
|
+
|
|
94
|
+
if isinstance(data, pd.DataFrame):
|
|
95
|
+
if shuffle:
|
|
96
|
+
data = data.sample(
|
|
97
|
+
frac=1.0, replace=False, random_state=random_state
|
|
98
|
+
)
|
|
99
|
+
train_data = data.iloc[:train_size]
|
|
100
|
+
test_data = data.iloc[-test_size:]
|
|
101
|
+
if not validation_size:
|
|
102
|
+
return train_data, test_data
|
|
103
|
+
validation_data = data.iloc[train_size:-test_size]
|
|
104
|
+
return train_data, validation_data, test_data
|
|
105
|
+
|
|
106
|
+
if not isinstance(data, np.ndarray):
|
|
107
|
+
data = np.asarray(data)
|
|
108
|
+
|
|
109
|
+
np.random.seed(random_state)
|
|
110
|
+
|
|
111
|
+
random_indices = np.arange(size)
|
|
112
|
+
np.random.shuffle(random_indices)
|
|
113
|
+
data = data[random_indices]
|
|
114
|
+
|
|
115
|
+
train_data = data[:train_size]
|
|
116
|
+
test_data = data[-test_size:]
|
|
117
|
+
if not validation_size:
|
|
118
|
+
return train_data, test_data
|
|
119
|
+
validation_data = data[train_size:-test_size]
|
|
120
|
+
return train_data, validation_data, test_data
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
|
|
@@ -9,75 +9,36 @@ from molcraft import chem
|
|
|
9
9
|
from molcraft import features
|
|
10
10
|
from molcraft import featurizers
|
|
11
11
|
from molcraft import tensors
|
|
12
|
+
from molcraft import descriptors
|
|
12
13
|
|
|
13
14
|
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
self_loops=self_loops,
|
|
35
|
-
include_hs=include_hs,
|
|
36
|
-
feature_dtype=feature_dtype,
|
|
37
|
-
index_dtype=index_dtype
|
|
38
|
-
)
|
|
15
|
+
def Graph(
|
|
16
|
+
inputs,
|
|
17
|
+
atom_features: list[features.Feature] | str | None = 'auto',
|
|
18
|
+
bond_features: list[features.Feature] | str | None = 'auto',
|
|
19
|
+
super_atom: bool = True,
|
|
20
|
+
radius: int | float | None = None,
|
|
21
|
+
self_loops: bool = False,
|
|
22
|
+
include_hs: bool = False,
|
|
23
|
+
**kwargs,
|
|
24
|
+
):
|
|
25
|
+
featurizer = featurizers.MolGraphFeaturizer(
|
|
26
|
+
atom_features=atom_features,
|
|
27
|
+
bond_features=bond_features,
|
|
28
|
+
molecule_features=[AminoAcidType()],
|
|
29
|
+
super_atom=super_atom,
|
|
30
|
+
radius=radius,
|
|
31
|
+
self_loops=self_loops,
|
|
32
|
+
include_hs=include_hs,
|
|
33
|
+
**kwargs,
|
|
34
|
+
)
|
|
39
35
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
inputs = [
|
|
46
|
-
features.residues[x] for x in ['G'] + inputs
|
|
47
|
-
]
|
|
48
|
-
mols = [
|
|
49
|
-
chem.Mol.from_encoding(x, explicit_hs=self.include_hs) for x in inputs
|
|
50
|
-
]
|
|
51
|
-
mols = [
|
|
52
|
-
mol for mol in mols if mol is not None
|
|
53
|
-
]
|
|
54
|
-
if not mols:
|
|
55
|
-
return None
|
|
56
|
-
tensor_list: list[tensors.GraphTensor] = [super().call(mol) for mol in mols]
|
|
57
|
-
tensor: tensors.GraphTensor = tf.stack(tensor_list, axis=0)
|
|
58
|
-
return tensor
|
|
59
|
-
|
|
60
|
-
def call(self, inputs: str | tuple) -> tensors.GraphTensor:
|
|
61
|
-
args = []
|
|
62
|
-
if isinstance(inputs, (list, tuple, np.ndarray)):
|
|
63
|
-
inputs, *args = inputs
|
|
64
|
-
inputs = [
|
|
65
|
-
features.residues[x] for x in chem.sequence_split(inputs)
|
|
66
|
-
]
|
|
67
|
-
tensor_list: list[tensors.GraphTensor] = [super().call(x) for x in inputs]
|
|
68
|
-
tensor: tensors.GraphTensor = tf.stack(tensor_list, axis=0)
|
|
69
|
-
tensor = tensor._merge()
|
|
70
|
-
context = {
|
|
71
|
-
k: v for (k, v) in zip(['label', 'weight'], args)
|
|
72
|
-
}
|
|
73
|
-
tensor = tensor.update(
|
|
74
|
-
{
|
|
75
|
-
'context': context
|
|
76
|
-
}
|
|
77
|
-
)
|
|
36
|
+
inputs = [
|
|
37
|
+
residues[x] for x in ['G'] + inputs
|
|
38
|
+
]
|
|
39
|
+
tensor_list = [featurizer(x) for x in inputs]
|
|
40
|
+
return tf.stack(tensor_list, axis=0)
|
|
78
41
|
|
|
79
|
-
return tensor
|
|
80
|
-
|
|
81
42
|
|
|
82
43
|
def GraphLookup(graph: tensors.GraphTensor) -> 'GraphLookupLayer':
|
|
83
44
|
lookup = GraphLookupLayer()
|
|
@@ -203,7 +164,7 @@ class Gather(keras.layers.Layer):
|
|
|
203
164
|
|
|
204
165
|
|
|
205
166
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
206
|
-
class AminoAcidType(
|
|
167
|
+
class AminoAcidType(descriptors.Descriptor):
|
|
207
168
|
|
|
208
169
|
def __init__(self, vocab=None, **kwargs):
|
|
209
170
|
vocab = [
|
|
@@ -217,7 +178,7 @@ class AminoAcidType(features.Feature):
|
|
|
217
178
|
if not residue:
|
|
218
179
|
raise KeyError(f'Could not find {mol.canonical_smiles} in `residues_reverse`.')
|
|
219
180
|
mol = chem.remove_hs(mol)
|
|
220
|
-
return
|
|
181
|
+
return _extract_residue_type(residues_reverse[mol.canonical_smiles])
|
|
221
182
|
|
|
222
183
|
def sequence_split(sequence: str):
|
|
223
184
|
patterns = [
|
|
@@ -200,12 +200,13 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
200
200
|
self.feature_dtype = 'float32'
|
|
201
201
|
self.index_dtype = 'int32'
|
|
202
202
|
|
|
203
|
-
def call(self,
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
203
|
+
def call(self, inputs: str | tuple) -> tensors.GraphTensor:
|
|
204
|
+
if isinstance(inputs, (tuple, list, np.ndarray)):
|
|
205
|
+
x, *context = inputs
|
|
206
|
+
if len(context) and isinstance(context[0], dict):
|
|
207
|
+
context = copy.deepcopy(context[0])
|
|
207
208
|
else:
|
|
208
|
-
|
|
209
|
+
x, context = inputs, None
|
|
209
210
|
|
|
210
211
|
mol = chem.Mol.from_encoding(x, explicit_hs=self.include_hs)
|
|
211
212
|
|
|
@@ -220,14 +221,30 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
220
221
|
bond_feature = self.bond_features(mol)
|
|
221
222
|
context_feature = self.context_feature(mol)
|
|
222
223
|
molecule_size = self.num_atoms(mol)
|
|
223
|
-
|
|
224
|
-
context,
|
|
225
|
-
|
|
226
|
-
|
|
224
|
+
|
|
225
|
+
if isinstance(context, dict):
|
|
226
|
+
if 'x' in context:
|
|
227
|
+
context['feature'] = context.pop('x')
|
|
228
|
+
if 'y' in context:
|
|
229
|
+
context['label'] = context.pop('y')
|
|
230
|
+
if 'sample_weight' in context:
|
|
231
|
+
context['weight'] = context.pop('sample_weight')
|
|
232
|
+
context = {
|
|
233
|
+
**{'size': molecule_size},
|
|
234
|
+
**context
|
|
235
|
+
}
|
|
236
|
+
elif isinstance(context, list):
|
|
237
|
+
context = {
|
|
238
|
+
**{'size': molecule_size},
|
|
239
|
+
**{key: value for (key, value) in zip(['label', 'weight'], context)}
|
|
240
|
+
}
|
|
241
|
+
else:
|
|
242
|
+
context = {'size': molecule_size}
|
|
227
243
|
|
|
228
244
|
if context_feature is not None:
|
|
229
245
|
context['feature'] = context_feature
|
|
230
246
|
|
|
247
|
+
node = {}
|
|
231
248
|
node['feature'] = atom_feature
|
|
232
249
|
|
|
233
250
|
if bond_feature is not None and (self.radius > 1 or self.self_loops):
|
|
@@ -239,6 +256,7 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
239
256
|
[bond_feature, zero_bond_feature], axis=0
|
|
240
257
|
)
|
|
241
258
|
|
|
259
|
+
edge = {}
|
|
242
260
|
if self.radius == 1:
|
|
243
261
|
edge['source'], edge['target'] = mol.adjacency(
|
|
244
262
|
fill='full', sparse=True, self_loops=self.self_loops, dtype=self.index_dtype
|
|
@@ -494,19 +512,25 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
494
512
|
self.embed_conformer = self.conformer_generator is not None
|
|
495
513
|
self.radius = float(radius) if radius else None
|
|
496
514
|
|
|
497
|
-
def call(self,
|
|
515
|
+
def call(self, inputs: str | tuple) -> tensors.GraphTensor:
|
|
498
516
|
|
|
499
|
-
if isinstance(
|
|
500
|
-
x, *
|
|
517
|
+
if isinstance(inputs, (tuple, list, np.ndarray)):
|
|
518
|
+
x, *context = inputs
|
|
519
|
+
if len(context) and isinstance(context[0], dict):
|
|
520
|
+
context = copy.deepcopy(context[0])
|
|
501
521
|
else:
|
|
502
|
-
|
|
522
|
+
x, context = inputs, None
|
|
503
523
|
|
|
504
524
|
explicit_hs = (self.include_hs or self.embed_conformer)
|
|
505
525
|
mol = chem.Mol.from_encoding(x, explicit_hs=explicit_hs)
|
|
506
|
-
|
|
526
|
+
|
|
507
527
|
if mol is None:
|
|
528
|
+
warn(
|
|
529
|
+
f'Could not obtain `chem.Mol` from {x}. '
|
|
530
|
+
'Proceeding without it.'
|
|
531
|
+
)
|
|
508
532
|
return None
|
|
509
|
-
|
|
533
|
+
|
|
510
534
|
if self.embed_conformer:
|
|
511
535
|
mol = self.conformer_generator(mol)
|
|
512
536
|
if not self.include_hs:
|
|
@@ -519,21 +543,38 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
519
543
|
'of the `Featurizer` or input a 3D representation of the molecule. '
|
|
520
544
|
)
|
|
521
545
|
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
546
|
+
context_feature = self.context_feature(mol)
|
|
547
|
+
molecule_size = self.num_atoms(mol) + int(self.super_atom)
|
|
548
|
+
|
|
549
|
+
if isinstance(context, dict):
|
|
550
|
+
if 'x' in context:
|
|
551
|
+
context['feature'] = context.pop('x')
|
|
552
|
+
if 'y' in context:
|
|
553
|
+
context['label'] = context.pop('y')
|
|
554
|
+
if 'sample_weight' in context:
|
|
555
|
+
context['weight'] = context.pop('sample_weight')
|
|
556
|
+
context = {
|
|
557
|
+
**{'size': molecule_size},
|
|
558
|
+
**context
|
|
559
|
+
}
|
|
560
|
+
elif isinstance(context, list):
|
|
561
|
+
context = {
|
|
562
|
+
**{'size': molecule_size},
|
|
563
|
+
**{key: value for (key, value) in zip(['label', 'weight'], context)}
|
|
564
|
+
}
|
|
565
|
+
else:
|
|
566
|
+
context = {'size': molecule_size}
|
|
527
567
|
|
|
568
|
+
if context_feature is not None:
|
|
569
|
+
context['feature'] = context_feature
|
|
570
|
+
|
|
571
|
+
node = {}
|
|
528
572
|
node['feature'] = self.atom_features(mol)
|
|
529
573
|
|
|
530
574
|
if self._bond_features:
|
|
531
575
|
edge_feature = self.bond_features(mol)
|
|
532
576
|
|
|
533
|
-
|
|
534
|
-
if context_feature is not None:
|
|
535
|
-
context['feature'] = context_feature
|
|
536
|
-
|
|
577
|
+
edge = {}
|
|
537
578
|
mols = chem._split_mol_by_confs(mol)
|
|
538
579
|
tensor_list = []
|
|
539
580
|
for i, mol in enumerate(mols):
|
|
@@ -563,11 +604,10 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
563
604
|
node_conformer['coordinate'] = np.concatenate(
|
|
564
605
|
[node_conformer['coordinate'], conformer.centroid[None]], axis=0
|
|
565
606
|
)
|
|
566
|
-
|
|
567
607
|
tensor_list.append(
|
|
568
608
|
tensors.GraphTensor(context, node_conformer, edge_conformer)
|
|
569
609
|
)
|
|
570
|
-
|
|
610
|
+
|
|
571
611
|
return tensor_list
|
|
572
612
|
|
|
573
613
|
def stack(self, outputs):
|