molcraft 0.1.0a18__tar.gz → 0.1.0a20__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/PKG-INFO +1 -1
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/molcraft/__init__.py +1 -1
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/molcraft/applications/proteomics.py +2 -12
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/molcraft/featurizers.py +2 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/molcraft/layers.py +11 -11
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/molcraft/tensors.py +16 -10
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/molcraft.egg-info/PKG-INFO +1 -1
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/tests/test_featurizers.py +14 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/LICENSE +0 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/README.md +0 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/molcraft/applications/__init__.py +0 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/molcraft/applications/chromatography.py +0 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/molcraft/callbacks.py +0 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/molcraft/chem.py +0 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/molcraft/datasets.py +0 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/molcraft/descriptors.py +0 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/molcraft/features.py +0 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/molcraft/losses.py +0 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/molcraft/models.py +0 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/molcraft/ops.py +0 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/molcraft/records.py +0 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/molcraft.egg-info/SOURCES.txt +0 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/molcraft.egg-info/dependency_links.txt +0 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/molcraft.egg-info/requires.txt +0 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/molcraft.egg-info/top_level.txt +0 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/pyproject.toml +0 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/setup.cfg +0 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/tests/test_chem.py +0 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/tests/test_layers.py +0 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/tests/test_losses.py +0 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/tests/test_models.py +0 -0
- {molcraft-0.1.0a18 → molcraft-0.1.0a20}/tests/test_tensors.py +0 -0
|
@@ -165,18 +165,8 @@ class ResidueEmbedding(keras.layers.Layer):
|
|
|
165
165
|
residues = {**default_residues, **residues}
|
|
166
166
|
self._residues = {}
|
|
167
167
|
for residue, smiles in residues.items():
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
# It seems that the returned smiles ends with carboxyl group,
|
|
171
|
-
# though we do another check just in case.
|
|
172
|
-
if not has_c_terminal_mod(residue):
|
|
173
|
-
carboxyl_group = 'C(=O)O'
|
|
174
|
-
if not permuted_smiles.endswith(carboxyl_group):
|
|
175
|
-
raise ValueError(
|
|
176
|
-
f'Unsupported permutation of {residue!r} smiles: {permuted_smiles!r}.'
|
|
177
|
-
)
|
|
178
|
-
self._residues[residue] = permuted_smiles
|
|
179
|
-
self._residues[residue + '*'] = permuted_smiles.rstrip('O')
|
|
168
|
+
self._residues[residue] = smiles
|
|
169
|
+
self._residues[residue + '*'] = smiles.rstrip('O')
|
|
180
170
|
|
|
181
171
|
residue_keys = sorted(self._residues.keys())
|
|
182
172
|
residue_values = range(len(residue_keys))
|
|
@@ -323,7 +323,9 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
323
323
|
include_hydrogens: bool = False,
|
|
324
324
|
radius: int | float | None = 6.0,
|
|
325
325
|
random_seed: int | None = None,
|
|
326
|
+
**kwargs,
|
|
326
327
|
) -> None:
|
|
328
|
+
kwargs.pop('bond_features', None)
|
|
327
329
|
super().__init__(
|
|
328
330
|
atom_features=atom_features,
|
|
329
331
|
bond_features=None,
|
|
@@ -279,7 +279,7 @@ class GraphConv(GraphLayer):
|
|
|
279
279
|
use_bias (bool):
|
|
280
280
|
Whether bias should be used in the dense layers. Default to `True`.
|
|
281
281
|
normalize (bool, str):
|
|
282
|
-
Whether normalization should be
|
|
282
|
+
Whether a normalization layer should be obtained by `get_norm()`. Default to `False`.
|
|
283
283
|
skip_connect (bool):
|
|
284
284
|
Whether node feature input should be added to the node feature output. Default to `True`.
|
|
285
285
|
kernel_initializer (keras.initializers.Initializer, str):
|
|
@@ -366,6 +366,7 @@ class GraphConv(GraphLayer):
|
|
|
366
366
|
has_overridden_message = self.__class__.message != GraphConv.message
|
|
367
367
|
if not has_overridden_message:
|
|
368
368
|
self._message_intermediate_dense = self.get_dense(self.units)
|
|
369
|
+
self._message_norm = self.get_norm()
|
|
369
370
|
self._message_intermediate_activation = self.activation
|
|
370
371
|
self._message_final_dense = self.get_dense(self.units)
|
|
371
372
|
|
|
@@ -376,19 +377,10 @@ class GraphConv(GraphLayer):
|
|
|
376
377
|
has_overridden_update = self.__class__.update != GraphConv.update
|
|
377
378
|
if not has_overridden_update:
|
|
378
379
|
self._update_intermediate_dense = self.get_dense(self.units)
|
|
380
|
+
self._update_norm = self.get_norm()
|
|
379
381
|
self._update_intermediate_activation = self.activation
|
|
380
382
|
self._update_final_dense = self.get_dense(self.units)
|
|
381
383
|
|
|
382
|
-
if not self._normalize:
|
|
383
|
-
self._message_norm = keras.layers.Identity()
|
|
384
|
-
self._update_norm = keras.layers.Identity()
|
|
385
|
-
elif str(self._normalize).lower().startswith('layer'):
|
|
386
|
-
self._message_norm = keras.layers.LayerNormalization()
|
|
387
|
-
self._update_norm = keras.layers.LayerNormalization()
|
|
388
|
-
else:
|
|
389
|
-
self._message_norm = keras.layers.BatchNormalization()
|
|
390
|
-
self._update_norm = keras.layers.BatchNormalization()
|
|
391
|
-
|
|
392
384
|
def propagate(self, tensor: tensors.GraphTensor) -> tensors.GraphTensor:
|
|
393
385
|
"""Forward pass.
|
|
394
386
|
|
|
@@ -533,6 +525,14 @@ class GraphConv(GraphLayer):
|
|
|
533
525
|
}
|
|
534
526
|
)
|
|
535
527
|
|
|
528
|
+
def get_norm(self, **kwargs):
|
|
529
|
+
if not self._normalize:
|
|
530
|
+
return keras.layers.Identity()
|
|
531
|
+
elif str(self._normalize).lower().startswith('layer'):
|
|
532
|
+
return keras.layers.LayerNormalization(**kwargs)
|
|
533
|
+
else:
|
|
534
|
+
return keras.layers.BatchNormalization(**kwargs)
|
|
535
|
+
|
|
536
536
|
def get_config(self) -> dict:
|
|
537
537
|
config = super().get_config()
|
|
538
538
|
config.update({
|
|
@@ -16,13 +16,15 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
|
|
|
16
16
|
if isinstance(f, tf.TensorSpec):
|
|
17
17
|
return tf.TensorSpec(
|
|
18
18
|
shape=[None] + f.shape[1:],
|
|
19
|
-
dtype=f.dtype
|
|
19
|
+
dtype=f.dtype
|
|
20
|
+
)
|
|
20
21
|
elif isinstance(f, tf.RaggedTensorSpec):
|
|
21
22
|
return tf.RaggedTensorSpec(
|
|
22
23
|
shape=[batch_size, None] + f.shape[1:],
|
|
23
24
|
dtype=f.dtype,
|
|
24
25
|
ragged_rank=1,
|
|
25
|
-
row_splits_dtype=f.row_splits_dtype
|
|
26
|
+
row_splits_dtype=f.row_splits_dtype
|
|
27
|
+
)
|
|
26
28
|
elif isinstance(f, tf.TypeSpec):
|
|
27
29
|
return f.__batch_encoder__.batch(f, batch_size)
|
|
28
30
|
return f
|
|
@@ -33,7 +35,8 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
|
|
|
33
35
|
batched_spec = object.__new__(type(spec))
|
|
34
36
|
batched_context_fields = tf.nest.map_structure(
|
|
35
37
|
lambda spec: tf.TensorSpec([batch_size] + spec.shape, spec.dtype),
|
|
36
|
-
context_fields
|
|
38
|
+
context_fields
|
|
39
|
+
)
|
|
37
40
|
batched_spec.__dict__.update({'context': batched_context_fields})
|
|
38
41
|
batched_spec.__dict__.update(batched_fields)
|
|
39
42
|
return batched_spec
|
|
@@ -46,13 +49,15 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
|
|
|
46
49
|
if isinstance(f, tf.TensorSpec):
|
|
47
50
|
return tf.TensorSpec(
|
|
48
51
|
shape=[None] + f.shape[1:],
|
|
49
|
-
dtype=f.dtype
|
|
52
|
+
dtype=f.dtype
|
|
53
|
+
)
|
|
50
54
|
elif isinstance(f, tf.RaggedTensorSpec):
|
|
51
55
|
return tf.RaggedTensorSpec(
|
|
52
56
|
shape=[None] + f.shape[2:],
|
|
53
57
|
dtype=f.dtype,
|
|
54
58
|
ragged_rank=0,
|
|
55
|
-
row_splits_dtype=f.row_splits_dtype
|
|
59
|
+
row_splits_dtype=f.row_splits_dtype
|
|
60
|
+
)
|
|
56
61
|
elif isinstance(f, tf.TypeSpec):
|
|
57
62
|
return f.__batch_encoder__.unbatch(f)
|
|
58
63
|
return f
|
|
@@ -62,7 +67,8 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
|
|
|
62
67
|
unbatched_fields = tf.nest.map_structure(unbatch_field, fields)
|
|
63
68
|
unbatched_context_fields = tf.nest.map_structure(
|
|
64
69
|
lambda spec: tf.TensorSpec(spec.shape[1:], spec.dtype),
|
|
65
|
-
context_fields
|
|
70
|
+
context_fields
|
|
71
|
+
)
|
|
66
72
|
unbatched_spec = object.__new__(type(spec))
|
|
67
73
|
unbatched_spec.__dict__.update({'context': unbatched_context_fields})
|
|
68
74
|
unbatched_spec.__dict__.update(unbatched_fields)
|
|
@@ -91,7 +97,8 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
|
|
|
91
97
|
shape=([None] if scalar else [None, None]) + f.shape[1:],
|
|
92
98
|
dtype=f.dtype,
|
|
93
99
|
ragged_rank=(0 if scalar else 1),
|
|
94
|
-
row_splits_dtype=spec.context['size'].dtype
|
|
100
|
+
row_splits_dtype=spec.context['size'].dtype
|
|
101
|
+
)
|
|
95
102
|
return f
|
|
96
103
|
fields = dict(spec.__dict__)
|
|
97
104
|
context_fields = fields.pop('context')
|
|
@@ -99,7 +106,7 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
|
|
|
99
106
|
encoded_fields = {**{'context': context_fields}, **encoded_fields}
|
|
100
107
|
spec_components = tuple(encoded_fields.values())
|
|
101
108
|
spec_components = tuple(
|
|
102
|
-
x for x in tf.nest.flatten(spec_components)
|
|
109
|
+
x for x in tf.nest.flatten(spec_components)
|
|
103
110
|
if isinstance(x, tf.TypeSpec)
|
|
104
111
|
)
|
|
105
112
|
return spec_components
|
|
@@ -117,7 +124,6 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
|
|
|
117
124
|
fields = dict(zip(spec.__dict__.keys(), value_tuple))
|
|
118
125
|
value = object.__new__(spec.value_type)
|
|
119
126
|
value.__dict__.update(fields)
|
|
120
|
-
|
|
121
127
|
flatten = is_ragged(value) and not is_ragged(spec)
|
|
122
128
|
if flatten:
|
|
123
129
|
value = value.flatten()
|
|
@@ -125,7 +131,7 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
|
|
|
125
131
|
|
|
126
132
|
|
|
127
133
|
class GraphTensor(tf.experimental.BatchableExtensionType):
|
|
128
|
-
context: typing.Mapping[str,
|
|
134
|
+
context: typing.Mapping[str, tf.Tensor]
|
|
129
135
|
node: typing.Mapping[str, typing.Union[tf.Tensor, tf.RaggedTensor]]
|
|
130
136
|
edge: typing.Mapping[str, typing.Union[tf.Tensor, tf.RaggedTensor]]
|
|
131
137
|
|
|
@@ -1,4 +1,6 @@
|
|
|
1
1
|
import unittest
|
|
2
|
+
import tempfile
|
|
3
|
+
import shutil
|
|
2
4
|
|
|
3
5
|
from molcraft import features
|
|
4
6
|
from molcraft import featurizers
|
|
@@ -42,6 +44,12 @@ class TestFeaturizer(unittest.TestCase):
|
|
|
42
44
|
include_hydrogens=False,
|
|
43
45
|
)
|
|
44
46
|
|
|
47
|
+
tmp_dir = tempfile.mkdtemp()
|
|
48
|
+
tmp_file = tmp_dir + '/featurizer.json'
|
|
49
|
+
featurizers.save_featurizer(featurizer, tmp_file)
|
|
50
|
+
_ = featurizers.load_featurizer(tmp_file)
|
|
51
|
+
shutil.rmtree(tmp_dir)
|
|
52
|
+
|
|
45
53
|
node_dim = 9
|
|
46
54
|
edge_dim = 4
|
|
47
55
|
|
|
@@ -127,6 +135,12 @@ class TestFeaturizer(unittest.TestCase):
|
|
|
127
135
|
radius=5.0,
|
|
128
136
|
)
|
|
129
137
|
|
|
138
|
+
tmp_dir = tempfile.mkdtemp()
|
|
139
|
+
tmp_file = tmp_dir + '/featurizer.json'
|
|
140
|
+
featurizers.save_featurizer(featurizer, tmp_file)
|
|
141
|
+
_ = featurizers.load_featurizer(tmp_file)
|
|
142
|
+
shutil.rmtree(tmp_dir)
|
|
143
|
+
|
|
130
144
|
node_dim = 10
|
|
131
145
|
edge_dim = 22
|
|
132
146
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|