molcraft 0.1.0a15__tar.gz → 0.1.0a17__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/PKG-INFO +14 -12
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/README.md +12 -11
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/molcraft/__init__.py +1 -2
- molcraft-0.1.0a17/molcraft/applications/chromatography.py +0 -0
- molcraft-0.1.0a17/molcraft/applications/proteomics.py +194 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/molcraft/chem.py +17 -22
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/molcraft/datasets.py +6 -6
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/molcraft/descriptors.py +14 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/molcraft/features.py +50 -58
- molcraft-0.1.0a17/molcraft/featurizers.py +523 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/molcraft/layers.py +50 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/molcraft/models.py +2 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/molcraft/records.py +24 -15
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/molcraft.egg-info/PKG-INFO +14 -12
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/molcraft.egg-info/SOURCES.txt +3 -4
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/molcraft.egg-info/requires.txt +1 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/pyproject.toml +1 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/tests/test_featurizers.py +10 -17
- molcraft-0.1.0a15/molcraft/apps/peptides.py +0 -429
- molcraft-0.1.0a15/molcraft/apps/qsrr.py +0 -47
- molcraft-0.1.0a15/molcraft/conformers.py +0 -151
- molcraft-0.1.0a15/molcraft/featurizers.py +0 -753
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/LICENSE +0 -0
- {molcraft-0.1.0a15/molcraft/apps → molcraft-0.1.0a17/molcraft/applications}/__init__.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/molcraft/callbacks.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/molcraft/losses.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/molcraft/ops.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/molcraft/tensors.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/molcraft.egg-info/dependency_links.txt +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/molcraft.egg-info/top_level.txt +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/setup.cfg +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/tests/test_chem.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/tests/test_layers.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/tests/test_losses.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/tests/test_models.py +0 -0
- {molcraft-0.1.0a15 → molcraft-0.1.0a17}/tests/test_tensors.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: molcraft
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.0a17
|
|
4
4
|
Summary: Graph Neural Networks for Molecular Machine Learning
|
|
5
5
|
Author-email: Alexander Kensert <alexander.kensert@gmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -35,6 +35,7 @@ Requires-Python: >=3.10
|
|
|
35
35
|
Description-Content-Type: text/markdown
|
|
36
36
|
License-File: LICENSE
|
|
37
37
|
Requires-Dist: tensorflow>=2.16
|
|
38
|
+
Requires-Dist: tensorflow-text>=2.16
|
|
38
39
|
Requires-Dist: rdkit>=2023.9.5
|
|
39
40
|
Requires-Dist: pandas>=1.0.3
|
|
40
41
|
Requires-Dist: ipython>=8.12.0
|
|
@@ -42,9 +43,9 @@ Provides-Extra: gpu
|
|
|
42
43
|
Requires-Dist: tensorflow[and-cuda]>=2.16; extra == "gpu"
|
|
43
44
|
Dynamic: license-file
|
|
44
45
|
|
|
45
|
-
<img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo">
|
|
46
|
+
<img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo", width="90%">
|
|
46
47
|
|
|
47
|
-
**Deep Learning on Molecules**: A Minimalistic GNN package for Molecular ML.
|
|
48
|
+
**Deep Learning on Molecules**: A Minimalistic GNN package for Molecular ML.
|
|
48
49
|
|
|
49
50
|
> [!NOTE]
|
|
50
51
|
> In progress.
|
|
@@ -82,11 +83,12 @@ featurizer = featurizers.MolGraphFeaturizer(
|
|
|
82
83
|
features.BondType(),
|
|
83
84
|
features.IsRotatable(),
|
|
84
85
|
],
|
|
85
|
-
|
|
86
|
+
super_node=True,
|
|
86
87
|
self_loops=True,
|
|
88
|
+
include_hydrogens=False,
|
|
87
89
|
)
|
|
88
90
|
|
|
89
|
-
graph = featurizer([('N[C@@H](C)C(=O)O', 2.
|
|
91
|
+
graph = featurizer([('N[C@@H](C)C(=O)O', 2.5), ('N[C@@H](CS)C(=O)O', 1.5)])
|
|
90
92
|
print(graph)
|
|
91
93
|
|
|
92
94
|
model = models.GraphModel.from_layers(
|
|
@@ -94,13 +96,13 @@ model = models.GraphModel.from_layers(
|
|
|
94
96
|
layers.Input(graph.spec),
|
|
95
97
|
layers.NodeEmbedding(dim=128),
|
|
96
98
|
layers.EdgeEmbedding(dim=128),
|
|
97
|
-
layers.
|
|
98
|
-
layers.
|
|
99
|
-
layers.
|
|
100
|
-
layers.
|
|
101
|
-
layers.Readout(
|
|
102
|
-
keras.layers.Dense(units=1024, activation='
|
|
103
|
-
keras.layers.Dense(units=1024, activation='
|
|
99
|
+
layers.GraphConv(units=128),
|
|
100
|
+
layers.GraphConv(units=128),
|
|
101
|
+
layers.GraphConv(units=128),
|
|
102
|
+
layers.GraphConv(units=128),
|
|
103
|
+
layers.Readout(),
|
|
104
|
+
keras.layers.Dense(units=1024, activation='elu'),
|
|
105
|
+
keras.layers.Dense(units=1024, activation='elu'),
|
|
104
106
|
keras.layers.Dense(1)
|
|
105
107
|
]
|
|
106
108
|
)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
<img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo">
|
|
1
|
+
<img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo", width="90%">
|
|
2
2
|
|
|
3
|
-
**Deep Learning on Molecules**: A Minimalistic GNN package for Molecular ML.
|
|
3
|
+
**Deep Learning on Molecules**: A Minimalistic GNN package for Molecular ML.
|
|
4
4
|
|
|
5
5
|
> [!NOTE]
|
|
6
6
|
> In progress.
|
|
@@ -38,11 +38,12 @@ featurizer = featurizers.MolGraphFeaturizer(
|
|
|
38
38
|
features.BondType(),
|
|
39
39
|
features.IsRotatable(),
|
|
40
40
|
],
|
|
41
|
-
|
|
41
|
+
super_node=True,
|
|
42
42
|
self_loops=True,
|
|
43
|
+
include_hydrogens=False,
|
|
43
44
|
)
|
|
44
45
|
|
|
45
|
-
graph = featurizer([('N[C@@H](C)C(=O)O', 2.
|
|
46
|
+
graph = featurizer([('N[C@@H](C)C(=O)O', 2.5), ('N[C@@H](CS)C(=O)O', 1.5)])
|
|
46
47
|
print(graph)
|
|
47
48
|
|
|
48
49
|
model = models.GraphModel.from_layers(
|
|
@@ -50,13 +51,13 @@ model = models.GraphModel.from_layers(
|
|
|
50
51
|
layers.Input(graph.spec),
|
|
51
52
|
layers.NodeEmbedding(dim=128),
|
|
52
53
|
layers.EdgeEmbedding(dim=128),
|
|
53
|
-
layers.
|
|
54
|
-
layers.
|
|
55
|
-
layers.
|
|
56
|
-
layers.
|
|
57
|
-
layers.Readout(
|
|
58
|
-
keras.layers.Dense(units=1024, activation='
|
|
59
|
-
keras.layers.Dense(units=1024, activation='
|
|
54
|
+
layers.GraphConv(units=128),
|
|
55
|
+
layers.GraphConv(units=128),
|
|
56
|
+
layers.GraphConv(units=128),
|
|
57
|
+
layers.GraphConv(units=128),
|
|
58
|
+
layers.Readout(),
|
|
59
|
+
keras.layers.Dense(units=1024, activation='elu'),
|
|
60
|
+
keras.layers.Dense(units=1024, activation='elu'),
|
|
60
61
|
keras.layers.Dense(1)
|
|
61
62
|
]
|
|
62
63
|
)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
__version__ = '0.1.
|
|
1
|
+
__version__ = '0.1.0a17'
|
|
2
2
|
|
|
3
3
|
import os
|
|
4
4
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
@@ -6,7 +6,6 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
|
|
6
6
|
from molcraft import chem
|
|
7
7
|
from molcraft import features
|
|
8
8
|
from molcraft import descriptors
|
|
9
|
-
from molcraft import conformers
|
|
10
9
|
from molcraft import featurizers
|
|
11
10
|
from molcraft import layers
|
|
12
11
|
from molcraft import models
|
|
File without changes
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import keras
|
|
3
|
+
import numpy as np
|
|
4
|
+
import tensorflow as tf
|
|
5
|
+
import tensorflow_text as tf_text
|
|
6
|
+
|
|
7
|
+
from molcraft import featurizers
|
|
8
|
+
from molcraft import tensors
|
|
9
|
+
from molcraft import layers
|
|
10
|
+
from molcraft import models
|
|
11
|
+
from molcraft import chem
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# TODO: Add regex pattern for residue (C-term mod + N-term mod)?
|
|
15
|
+
# TODO: Add regex pattern for residue (C-term mod + N-term mod + mod)?
|
|
16
|
+
residue_pattern: str = "|".join([
|
|
17
|
+
r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])', # residue (N-term mod + mod)
|
|
18
|
+
r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])', # residue (C-term mod + mod)
|
|
19
|
+
r'([A-Z]-\[[A-Za-z0-9]+\])', # residue (C-term mod)
|
|
20
|
+
r'(\[[A-Za-z0-9]+\]-[A-Z])', # residue (N-term mod)
|
|
21
|
+
r'([A-Z]\[[A-Za-z0-9]+\])', # residue (mod)
|
|
22
|
+
r'([A-Z])', # residue (no mod)
|
|
23
|
+
])
|
|
24
|
+
|
|
25
|
+
default_residues: dict[str, str] = {
|
|
26
|
+
"A": "N[C@@H](C)C(=O)O",
|
|
27
|
+
"C": "N[C@@H](CS)C(=O)O",
|
|
28
|
+
"D": "N[C@@H](CC(=O)O)C(=O)O",
|
|
29
|
+
"E": "N[C@@H](CCC(=O)O)C(=O)O",
|
|
30
|
+
"F": "N[C@@H](Cc1ccccc1)C(=O)O",
|
|
31
|
+
"G": "NCC(=O)O",
|
|
32
|
+
"H": "N[C@@H](CC1=CN=C-N1)C(=O)O",
|
|
33
|
+
"I": "N[C@@H](C(CC)C)C(=O)O",
|
|
34
|
+
"K": "N[C@@H](CCCCN)C(=O)O",
|
|
35
|
+
"L": "N[C@@H](CC(C)C)C(=O)O",
|
|
36
|
+
"M": "N[C@@H](CCSC)C(=O)O",
|
|
37
|
+
"N": "N[C@@H](CC(=O)N)C(=O)O",
|
|
38
|
+
"P": "N1[C@@H](CCC1)C(=O)O",
|
|
39
|
+
"Q": "N[C@@H](CCC(=O)N)C(=O)O",
|
|
40
|
+
"R": "N[C@@H](CCCNC(=N)N)C(=O)O",
|
|
41
|
+
"S": "N[C@@H](CO)C(=O)O",
|
|
42
|
+
"T": "N[C@@H](C(O)C)C(=O)O",
|
|
43
|
+
"V": "N[C@@H](C(C)C)C(=O)O",
|
|
44
|
+
"W": "N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O",
|
|
45
|
+
"Y": "N[C@@H](Cc1ccc(O)cc1)C(=O)O",
|
|
46
|
+
}
|
|
47
|
+
|
|
48
|
+
def register_residues(residues: dict[str, str]) -> None:
|
|
49
|
+
# TODO: Implement functions that check if residue has N- or C-terminal mod
|
|
50
|
+
# if C-terminal mod, no need to enforce concatenatable perm.
|
|
51
|
+
# if N-terminal mod, enforce only 'C(=O)O'
|
|
52
|
+
# if normal mod, enforce concatenateable perm ('N[C@@H]' and 'C(=O)O)).
|
|
53
|
+
for residue, smiles in residues.items():
|
|
54
|
+
if residue.startswith('P'):
|
|
55
|
+
smiles.startswith('N'), f'Incorrect SMILES permutation for {residue}.'
|
|
56
|
+
elif not residue.startswith('['):
|
|
57
|
+
smiles.startswith('N[C@@H]'), f'Incorrect SMILES permutation for {residue}.'
|
|
58
|
+
if len(residue) > 1 and not residue[1] == "-":
|
|
59
|
+
assert smiles.endswith('C(=O)O'), f'Incorrect SMILES permutation for {residue}.'
|
|
60
|
+
registered_residues[residue] = smiles
|
|
61
|
+
registered_residues[residue + '*'] = smiles.strip('O')
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class Peptide(chem.Mol):
|
|
65
|
+
|
|
66
|
+
@classmethod
|
|
67
|
+
def from_sequence(cls, sequence: str, **kwargs) -> 'Peptide':
|
|
68
|
+
sequence = [
|
|
69
|
+
match.group(0) for match in re.finditer(residue_pattern, sequence)
|
|
70
|
+
]
|
|
71
|
+
peptide_smiles = []
|
|
72
|
+
for i, residue in enumerate(sequence):
|
|
73
|
+
if i < len(sequence) - 1:
|
|
74
|
+
residue_smiles = registered_residues[residue + '*']
|
|
75
|
+
else:
|
|
76
|
+
residue_smiles = registered_residues[residue]
|
|
77
|
+
peptide_smiles.append(residue_smiles)
|
|
78
|
+
peptide_smiles = ''.join(peptide_smiles)
|
|
79
|
+
return super().from_encoding(peptide_smiles, **kwargs)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@keras.saving.register_keras_serializable(package='proteomics')
|
|
83
|
+
class ResidueEmbedding(keras.layers.Layer):
|
|
84
|
+
|
|
85
|
+
def __init__(
|
|
86
|
+
self,
|
|
87
|
+
featurizer: featurizers.MolGraphFeaturizer,
|
|
88
|
+
embedder: models.GraphModel,
|
|
89
|
+
residues: dict[str, str] | None = None,
|
|
90
|
+
**kwargs
|
|
91
|
+
) -> None:
|
|
92
|
+
super().__init__(**kwargs)
|
|
93
|
+
if residues is None:
|
|
94
|
+
residues = {}
|
|
95
|
+
self._residue_dict = {**default_residues, **residues}
|
|
96
|
+
self.embedder = embedder
|
|
97
|
+
self.featurizer = featurizer
|
|
98
|
+
self.embedding_dim = self.embedder.output.shape[-1]
|
|
99
|
+
self.ragged_split = SequenceSplitter(pad=False)
|
|
100
|
+
self.split = SequenceSplitter(pad=True)
|
|
101
|
+
self.use_cached_embeddings = tf.Variable(False)
|
|
102
|
+
self.supports_masking = True
|
|
103
|
+
|
|
104
|
+
@property
|
|
105
|
+
def residues(self) -> dict[str, str]:
|
|
106
|
+
return self._residue_dict
|
|
107
|
+
|
|
108
|
+
@residues.setter
|
|
109
|
+
def residues(self, residues: dict[str, str]) -> None:
|
|
110
|
+
self._residue_dict = residues
|
|
111
|
+
num_residues = len(residues)
|
|
112
|
+
residue_keys = sorted(residues.keys())
|
|
113
|
+
oov_value = np.where(np.array(residue_keys) == "G")[0][0]
|
|
114
|
+
self.mapping = tf.lookup.StaticHashTable(
|
|
115
|
+
tf.lookup.KeyValueTensorInitializer(
|
|
116
|
+
keys=residue_keys,
|
|
117
|
+
values=range(num_residues)
|
|
118
|
+
),
|
|
119
|
+
default_value=oov_value,
|
|
120
|
+
)
|
|
121
|
+
self.graph = tf.stack([
|
|
122
|
+
self.featurizer(residues[residue]) for residue in residue_keys
|
|
123
|
+
], axis=0)
|
|
124
|
+
self.cached_embeddings = tf.Variable(
|
|
125
|
+
initial_value=tf.zeros((num_residues, self.embedding_dim))
|
|
126
|
+
)
|
|
127
|
+
_ = self.cache_and_get_embeddings()
|
|
128
|
+
|
|
129
|
+
def build(self, input_shape) -> None:
|
|
130
|
+
self.residues = self._residue_dict
|
|
131
|
+
super().build(input_shape)
|
|
132
|
+
|
|
133
|
+
def call(self, sequences: tf.Tensor, training: bool = None) -> tf.Tensor:
|
|
134
|
+
if training is False:
|
|
135
|
+
self.use_cached_embeddings.assign(True)
|
|
136
|
+
else:
|
|
137
|
+
self.use_cached_embeddings.assign(False)
|
|
138
|
+
embeddings = tf.cond(
|
|
139
|
+
pred=self.use_cached_embeddings,
|
|
140
|
+
true_fn=lambda: self.cached_embeddings,
|
|
141
|
+
false_fn=lambda: self.cache_and_get_embeddings(),
|
|
142
|
+
)
|
|
143
|
+
sequences = self.ragged_split(sequences)
|
|
144
|
+
sequences = keras.ops.concatenate([
|
|
145
|
+
tf.strings.join([sequences[:, :-1], '*']), sequences[:, -1:]
|
|
146
|
+
], axis=1)
|
|
147
|
+
indices = self.mapping.lookup(sequences)
|
|
148
|
+
return tf.gather(embeddings, indices).to_tensor()
|
|
149
|
+
|
|
150
|
+
def cache_and_get_embeddings(self) -> tf.Tensor:
|
|
151
|
+
embeddings = self.embedder(self.graph)
|
|
152
|
+
self.cached_embeddings.assign(embeddings)
|
|
153
|
+
return embeddings
|
|
154
|
+
|
|
155
|
+
def compute_mask(
|
|
156
|
+
self,
|
|
157
|
+
inputs: tensors.GraphTensor,
|
|
158
|
+
mask: bool | None = None
|
|
159
|
+
) -> tf.Tensor | None:
|
|
160
|
+
sequences = self.split(inputs)
|
|
161
|
+
return keras.ops.not_equal(sequences, '')
|
|
162
|
+
|
|
163
|
+
def get_config(self) -> dict:
|
|
164
|
+
config = super().get_config()
|
|
165
|
+
config.update({
|
|
166
|
+
'featurizer': keras.saving.serialize_keras_object(self.featurizer),
|
|
167
|
+
'embedder': keras.saving.serialize_keras_object(self.embedder),
|
|
168
|
+
'residues': self._residue_dict,
|
|
169
|
+
})
|
|
170
|
+
return config
|
|
171
|
+
|
|
172
|
+
@classmethod
|
|
173
|
+
def from_config(cls, config: dict) -> 'ResidueEmbedding':
|
|
174
|
+
config['featurizer'] = keras.saving.deserialize_keras_object(config['featurizer'])
|
|
175
|
+
config['embedder'] = keras.saving.deserialize_keras_object(config['embedder'])
|
|
176
|
+
return super().from_config(config)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
@keras.saving.register_keras_serializable(package='proteomics')
|
|
180
|
+
class SequenceSplitter(keras.layers.Layer):
|
|
181
|
+
|
|
182
|
+
def __init__(self, pad: bool, **kwargs):
|
|
183
|
+
super().__init__(**kwargs)
|
|
184
|
+
self.pad = pad
|
|
185
|
+
|
|
186
|
+
def call(self, inputs: tf.Tensor) -> tf.Tensor | tf.RaggedTensor:
|
|
187
|
+
inputs = tf_text.regex_split(inputs, residue_pattern, residue_pattern)
|
|
188
|
+
if self.pad:
|
|
189
|
+
inputs = inputs.to_tensor()
|
|
190
|
+
return inputs
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
registered_residues: dict[str, str] = {}
|
|
194
|
+
register_residues(default_residues)
|
|
@@ -19,8 +19,6 @@ class Mol(Chem.Mol):
|
|
|
19
19
|
@classmethod
|
|
20
20
|
def from_encoding(cls, encoding: str, explicit_hs: bool = False, **kwargs) -> 'Mol':
|
|
21
21
|
rdkit_mol = get_mol(encoding, **kwargs)
|
|
22
|
-
if not rdkit_mol:
|
|
23
|
-
return None
|
|
24
22
|
if explicit_hs:
|
|
25
23
|
rdkit_mol = Chem.AddHs(rdkit_mol)
|
|
26
24
|
rdkit_mol.__class__ = cls
|
|
@@ -102,21 +100,13 @@ class Mol(Chem.Mol):
|
|
|
102
100
|
|
|
103
101
|
def get_conformer(self, index: int = 0) -> 'Conformer':
|
|
104
102
|
if self.num_conformers == 0:
|
|
105
|
-
warnings.warn(
|
|
106
|
-
'Molecule has no conformer. To embed conformer(s), invoke the `embed` method, '
|
|
107
|
-
'and optionally followed by `minimize()` to perform force field minimization.',
|
|
108
|
-
stacklevel=2
|
|
109
|
-
)
|
|
103
|
+
warnings.warn('Molecule has no conformer.')
|
|
110
104
|
return None
|
|
111
105
|
return Conformer.cast(self.GetConformer(index))
|
|
112
106
|
|
|
113
107
|
def get_conformers(self) -> list['Conformer']:
|
|
114
108
|
if self.num_conformers == 0:
|
|
115
|
-
warnings.warn(
|
|
116
|
-
'Molecule has no conformers. To embed conformers, invoke the `embed` method, '
|
|
117
|
-
'and optionally followed by `minimize()` to perform force field minimization.',
|
|
118
|
-
stacklevel=2
|
|
119
|
-
)
|
|
109
|
+
warnings.warn('Molecule has no conformer.')
|
|
120
110
|
return []
|
|
121
111
|
return [Conformer.cast(x) for x in self.GetConformers()]
|
|
122
112
|
|
|
@@ -222,11 +212,10 @@ def get_mol(
|
|
|
222
212
|
else:
|
|
223
213
|
mol = Chem.MolFromSmiles(encoding, sanitize=False)
|
|
224
214
|
if mol is not None:
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
)
|
|
215
|
+
mol = sanitize_mol(mol, strict, assign_stereo_chemistry)
|
|
216
|
+
if mol is not None:
|
|
217
|
+
return mol
|
|
218
|
+
raise ValueError(f'Could not obtain `chem.Mol` from {encoding}.')
|
|
230
219
|
|
|
231
220
|
def get_adjacency_matrix(
|
|
232
221
|
mol: Chem.Mol,
|
|
@@ -402,8 +391,9 @@ def embed_conformers(
|
|
|
402
391
|
mol: Mol,
|
|
403
392
|
num_conformers: int,
|
|
404
393
|
method: str = 'ETKDGv3',
|
|
394
|
+
random_seed: int | None = None,
|
|
405
395
|
**kwargs
|
|
406
|
-
) ->
|
|
396
|
+
) -> Mol:
|
|
407
397
|
available_embedding_methods = {
|
|
408
398
|
'ETDG': rdDistGeom.ETDG(),
|
|
409
399
|
'ETKDG': rdDistGeom.ETKDG(),
|
|
@@ -423,6 +413,9 @@ def embed_conformers(
|
|
|
423
413
|
for key, value in kwargs.items():
|
|
424
414
|
setattr(embedding_method, key, value)
|
|
425
415
|
|
|
416
|
+
if random_seed is not None:
|
|
417
|
+
embedding_method.randomSeed = random_seed
|
|
418
|
+
|
|
426
419
|
success = rdDistGeom.EmbedMultipleConfs(
|
|
427
420
|
mol, numConfs=num_conformers, params=embedding_method
|
|
428
421
|
)
|
|
@@ -440,6 +433,8 @@ def embed_conformers(
|
|
|
440
433
|
fallback_embedding_method.useRandomCoords = True
|
|
441
434
|
fallback_embedding_method.maxAttempts = max_attempts
|
|
442
435
|
fallback_embedding_method.clearConfs = False
|
|
436
|
+
if random_seed is not None:
|
|
437
|
+
fallback_embedding_method.randomSeed = random_seed
|
|
443
438
|
success = rdDistGeom.EmbedMultipleConfs(
|
|
444
439
|
mol, numConfs=(num_conformers - num_successes), params=fallback_embedding_method
|
|
445
440
|
)
|
|
@@ -459,7 +454,7 @@ def optimize_conformers(
|
|
|
459
454
|
num_threads: bool = 1,
|
|
460
455
|
ignore_interfragment_interactions: bool = True,
|
|
461
456
|
vdw_threshold: float = 10.0,
|
|
462
|
-
):
|
|
457
|
+
) -> Mol:
|
|
463
458
|
available_force_field_methods = [
|
|
464
459
|
'MMFF', 'MMFF94', 'MMFF94s', 'UFF'
|
|
465
460
|
]
|
|
@@ -502,7 +497,7 @@ def prune_conformers(
|
|
|
502
497
|
keep: int = 1,
|
|
503
498
|
threshold: float = 0.0,
|
|
504
499
|
energy_force_field: str = 'UFF',
|
|
505
|
-
):
|
|
500
|
+
) -> Mol:
|
|
506
501
|
if mol.num_conformers == 0:
|
|
507
502
|
warnings.warn(
|
|
508
503
|
'Molecule has no conformers. To embed conformers, invoke the `embed` method, '
|
|
@@ -539,7 +534,7 @@ def _uff_optimize_conformers(
|
|
|
539
534
|
vdw_threshold: float = 10.0,
|
|
540
535
|
ignore_interfragment_interactions: bool = True,
|
|
541
536
|
**kwargs,
|
|
542
|
-
) ->
|
|
537
|
+
) -> tuple[list[float], list[bool]]:
|
|
543
538
|
"""Universal Force Field Minimization.
|
|
544
539
|
"""
|
|
545
540
|
results = rdForceFieldHelpers.UFFOptimizeMoleculeConfs(
|
|
@@ -560,7 +555,7 @@ def _mmff_optimize_conformers(
|
|
|
560
555
|
variant: str = 'MMFF94',
|
|
561
556
|
ignore_interfragment_interactions: bool = True,
|
|
562
557
|
**kwargs,
|
|
563
|
-
) ->
|
|
558
|
+
) -> tuple[list[float], list[bool]]:
|
|
564
559
|
"""Merck Molecular Force Field Minimization.
|
|
565
560
|
"""
|
|
566
561
|
if not rdForceFieldHelpers.MMFFHasAllMoleculeParams(mol):
|
|
@@ -11,7 +11,7 @@ def split(
|
|
|
11
11
|
test_size: float | None = None,
|
|
12
12
|
groups: str | np.ndarray = None,
|
|
13
13
|
shuffle: bool = False,
|
|
14
|
-
|
|
14
|
+
random_seed: int | None = None,
|
|
15
15
|
) -> tuple[np.ndarray | pd.DataFrame, ...]:
|
|
16
16
|
"""Splits the dataset into subsets.
|
|
17
17
|
|
|
@@ -28,7 +28,7 @@ def split(
|
|
|
28
28
|
The groups to perform the splitting on.
|
|
29
29
|
shuffle:
|
|
30
30
|
Whether the dataset should be shuffled prior to splitting.
|
|
31
|
-
|
|
31
|
+
random_seed:
|
|
32
32
|
The random state/seed. Only applicable if shuffling.
|
|
33
33
|
"""
|
|
34
34
|
if not isinstance(data, (pd.DataFrame, np.ndarray)):
|
|
@@ -69,7 +69,7 @@ def split(
|
|
|
69
69
|
train_size += remainder
|
|
70
70
|
|
|
71
71
|
if shuffle:
|
|
72
|
-
np.random.seed(
|
|
72
|
+
np.random.seed(random_seed)
|
|
73
73
|
np.random.shuffle(indices)
|
|
74
74
|
|
|
75
75
|
train_mask = np.isin(groups, indices[:train_size])
|
|
@@ -84,7 +84,7 @@ def cv_split(
|
|
|
84
84
|
num_splits: int = 10,
|
|
85
85
|
groups: str | np.ndarray = None,
|
|
86
86
|
shuffle: bool = False,
|
|
87
|
-
|
|
87
|
+
random_seed: int | None = None,
|
|
88
88
|
) -> typing.Iterator[
|
|
89
89
|
tuple[np.ndarray | pd.DataFrame, np.ndarray | pd.DataFrame]
|
|
90
90
|
]:
|
|
@@ -99,7 +99,7 @@ def cv_split(
|
|
|
99
99
|
The groups to perform the splitting on.
|
|
100
100
|
shuffle:
|
|
101
101
|
Whether the dataset should be shuffled prior to splitting.
|
|
102
|
-
|
|
102
|
+
random_seed:
|
|
103
103
|
The random state/seed. Only applicable if shuffling.
|
|
104
104
|
"""
|
|
105
105
|
if not isinstance(data, (pd.DataFrame, np.ndarray)):
|
|
@@ -119,7 +119,7 @@ def cv_split(
|
|
|
119
119
|
f'the data size or the number of groups ({size}).'
|
|
120
120
|
)
|
|
121
121
|
if shuffle:
|
|
122
|
-
np.random.seed(
|
|
122
|
+
np.random.seed(random_seed)
|
|
123
123
|
np.random.shuffle(indices)
|
|
124
124
|
|
|
125
125
|
indices_splits = np.array_split(indices, num_splits)
|
|
@@ -91,3 +91,17 @@ class NumRings(Descriptor):
|
|
|
91
91
|
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
92
92
|
return rdMolDescriptors.CalcNumRings(mol)
|
|
93
93
|
|
|
94
|
+
|
|
95
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
96
|
+
class AtomCount(Descriptor):
|
|
97
|
+
|
|
98
|
+
def __init__(self, atom_type: str, **kwargs):
|
|
99
|
+
super().__init__(**kwargs)
|
|
100
|
+
self.atom_type = atom_type
|
|
101
|
+
|
|
102
|
+
def call(self, mol: chem.Mol) -> np.ndarray:
|
|
103
|
+
count = 0
|
|
104
|
+
for atom in mol.atoms:
|
|
105
|
+
if atom.GetSymbol() == self.atom_type:
|
|
106
|
+
count += 1
|
|
107
|
+
return count
|
|
@@ -41,11 +41,7 @@ class Feature(abc.ABC):
|
|
|
41
41
|
|
|
42
42
|
def __call__(self, mol: chem.Mol) -> np.ndarray:
|
|
43
43
|
if not isinstance(mol, chem.Mol):
|
|
44
|
-
raise
|
|
45
|
-
f'Input to {self.name} needs to be a `chem.Mol`, which '
|
|
46
|
-
'implements two properties that should be iterated over '
|
|
47
|
-
'to compute features: `atoms` and `bonds`.'
|
|
48
|
-
)
|
|
44
|
+
raise TypeError(f'Input to {self.name} must be a `chem.Mol` instance.')
|
|
49
45
|
features = self.call(mol)
|
|
50
46
|
if len(features) != mol.num_atoms and len(features) != mol.num_bonds:
|
|
51
47
|
raise ValueError(
|
|
@@ -119,59 +115,6 @@ class Feature(abc.ABC):
|
|
|
119
115
|
return np.asarray([value], dtype=self.dtype)
|
|
120
116
|
|
|
121
117
|
|
|
122
|
-
@keras.saving.register_keras_serializable(package='molcraft')
|
|
123
|
-
class EdgeFeature(Feature):
|
|
124
|
-
|
|
125
|
-
def __call__(self, mol: chem.Mol) -> np.ndarray:
|
|
126
|
-
if not isinstance(mol, chem.Mol):
|
|
127
|
-
raise ValueError(
|
|
128
|
-
f'Input to {self.name} needs to be a `chem.Mol`, which '
|
|
129
|
-
'implements two properties that should be iterated over '
|
|
130
|
-
'to compute features: `atoms` and `bonds`.'
|
|
131
|
-
)
|
|
132
|
-
features = self.call(mol)
|
|
133
|
-
if len(features) != int(mol.num_atoms**2):
|
|
134
|
-
raise ValueError(
|
|
135
|
-
f'The number of features computed by {self.name} does not '
|
|
136
|
-
'match the number of node pairs in the `chem.Mol` object. '
|
|
137
|
-
f'Make sure the list of items returned by {self.name}(input) '
|
|
138
|
-
'correspond to node/atom pairs: '
|
|
139
|
-
'[(0, 0), (0, 1), ..., (0, N), (1, 0), ... (N, N)], '
|
|
140
|
-
'where N denotes the number of nodes/atoms.'
|
|
141
|
-
)
|
|
142
|
-
func = (
|
|
143
|
-
self._featurize_categorical if self.vocab else
|
|
144
|
-
self._featurize_floating
|
|
145
|
-
)
|
|
146
|
-
return np.asarray([func(x) for x in features], dtype=self.dtype)
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
@keras.saving.register_keras_serializable(package='molcraft')
|
|
150
|
-
class Distance(EdgeFeature):
|
|
151
|
-
|
|
152
|
-
def __init__(
|
|
153
|
-
self,
|
|
154
|
-
max_distance: int = None,
|
|
155
|
-
allow_oov: int = True,
|
|
156
|
-
encode_oov: bool = True,
|
|
157
|
-
**kwargs,
|
|
158
|
-
) -> None:
|
|
159
|
-
vocab = kwargs.pop('vocab', None)
|
|
160
|
-
if not vocab:
|
|
161
|
-
if max_distance is None:
|
|
162
|
-
max_distance = 20
|
|
163
|
-
vocab = list(range(max_distance + 1))
|
|
164
|
-
super().__init__(
|
|
165
|
-
vocab=vocab,
|
|
166
|
-
allow_oov=allow_oov,
|
|
167
|
-
encode_oov=encode_oov,
|
|
168
|
-
**kwargs
|
|
169
|
-
)
|
|
170
|
-
|
|
171
|
-
def call(self, mol: chem.Mol) -> list[int]:
|
|
172
|
-
return [int(x) for x in chem.get_distances(mol).reshape(-1)]
|
|
173
|
-
|
|
174
|
-
|
|
175
118
|
@keras.saving.register_keras_serializable(package='molcraft')
|
|
176
119
|
class AtomType(Feature):
|
|
177
120
|
def call(self, mol: chem.Mol) -> list[int, float, str]:
|
|
@@ -340,6 +283,55 @@ class IsRotatable(Feature):
|
|
|
340
283
|
return chem.rotatable_bonds(mol)
|
|
341
284
|
|
|
342
285
|
|
|
286
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
287
|
+
class PairFeature(Feature):
|
|
288
|
+
|
|
289
|
+
def __call__(self, mol: chem.Mol) -> np.ndarray:
|
|
290
|
+
if not isinstance(mol, chem.Mol):
|
|
291
|
+
raise TypeError(f'Input to {self.name} must be a `chem.Mol` instance.')
|
|
292
|
+
features = self.call(mol)
|
|
293
|
+
if len(features) != int(mol.num_atoms**2):
|
|
294
|
+
raise ValueError(
|
|
295
|
+
f'The number of features computed by {self.name} does not '
|
|
296
|
+
'match the number of node/atom pairs in the `chem.Mol` object. '
|
|
297
|
+
f'Make sure the list of items returned by {self.name}(input) '
|
|
298
|
+
'correspond to node/atom pairs: '
|
|
299
|
+
'[(0, 0), (0, 1), ..., (0, N), (1, 0), ... (N, N)], '
|
|
300
|
+
'where N denotes the number of nodes/atoms.'
|
|
301
|
+
)
|
|
302
|
+
func = (
|
|
303
|
+
self._featurize_categorical if self.vocab else
|
|
304
|
+
self._featurize_floating
|
|
305
|
+
)
|
|
306
|
+
return np.asarray([func(x) for x in features], dtype=self.dtype)
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
@keras.saving.register_keras_serializable(package='molcraft')
|
|
310
|
+
class PairDistance(PairFeature):
|
|
311
|
+
|
|
312
|
+
def __init__(
|
|
313
|
+
self,
|
|
314
|
+
max_distance: int = None,
|
|
315
|
+
allow_oov: int = True,
|
|
316
|
+
encode_oov: bool = True,
|
|
317
|
+
**kwargs,
|
|
318
|
+
) -> None:
|
|
319
|
+
vocab = kwargs.pop('vocab', None)
|
|
320
|
+
if not vocab:
|
|
321
|
+
if max_distance is None:
|
|
322
|
+
max_distance = 10
|
|
323
|
+
vocab = list(range(max_distance + 1))
|
|
324
|
+
super().__init__(
|
|
325
|
+
vocab=vocab,
|
|
326
|
+
allow_oov=allow_oov,
|
|
327
|
+
encode_oov=encode_oov,
|
|
328
|
+
**kwargs
|
|
329
|
+
)
|
|
330
|
+
|
|
331
|
+
def call(self, mol: chem.Mol) -> list[int]:
|
|
332
|
+
return [int(x) for x in chem.get_distances(mol).reshape(-1)]
|
|
333
|
+
|
|
334
|
+
|
|
343
335
|
default_vocabulary = {
|
|
344
336
|
'AtomType': [
|
|
345
337
|
'*', 'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na',
|