molcraft 0.1.0a1__py3-none-any.whl → 0.1.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- molcraft/__init__.py +2 -1
- molcraft/datasets.py +123 -0
- molcraft/experimental/peptides.py +28 -67
- molcraft/features.py +5 -3
- molcraft/featurizers.py +68 -27
- molcraft/layers.py +1299 -647
- molcraft/models.py +35 -5
- molcraft/tensors.py +33 -12
- {molcraft-0.1.0a1.dist-info → molcraft-0.1.0a3.dist-info}/METADATA +68 -1
- molcraft-0.1.0a3.dist-info/RECORD +20 -0
- molcraft-0.1.0a1.dist-info/RECORD +0 -19
- {molcraft-0.1.0a1.dist-info → molcraft-0.1.0a3.dist-info}/WHEEL +0 -0
- {molcraft-0.1.0a1.dist-info → molcraft-0.1.0a3.dist-info}/licenses/LICENSE +0 -0
- {molcraft-0.1.0a1.dist-info → molcraft-0.1.0a3.dist-info}/top_level.txt +0 -0
molcraft/models.py
CHANGED
|
@@ -194,8 +194,7 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
|
|
|
194
194
|
Args:
|
|
195
195
|
x:
|
|
196
196
|
A `GraphTensor` instance or a `tf.data.Dataset` constructed from
|
|
197
|
-
a `GraphTensor` instance.
|
|
198
|
-
be encoded and will be ignored.
|
|
197
|
+
a `GraphTensor` instance.
|
|
199
198
|
batch_size:
|
|
200
199
|
Number of samples per batch of computation.
|
|
201
200
|
kwargs:
|
|
@@ -316,10 +315,16 @@ class FunctionalGraphModel(functional.Functional, GraphModel):
|
|
|
316
315
|
]
|
|
317
316
|
|
|
318
317
|
|
|
319
|
-
def save_model(model:
|
|
318
|
+
def save_model(model: GraphModel, filepath: str | Path, *args, **kwargs) -> None:
|
|
319
|
+
if not model.built:
|
|
320
|
+
raise ValueError(
|
|
321
|
+
'Model and its layers have not yet been (fully) built. '
|
|
322
|
+
'Build the model before saving it: `model.build(graph_spec)` '
|
|
323
|
+
'or `model(graph)`.'
|
|
324
|
+
)
|
|
320
325
|
keras.models.save_model(model, filepath, *args, **kwargs)
|
|
321
326
|
|
|
322
|
-
def load_model(filepath: str | Path, inputs=None, *args, **kwargs) ->
|
|
327
|
+
def load_model(filepath: str | Path, inputs=None, *args, **kwargs) -> GraphModel:
|
|
323
328
|
return keras.models.load_model(filepath, *args, **kwargs)
|
|
324
329
|
|
|
325
330
|
def create(
|
|
@@ -334,7 +339,7 @@ def create(
|
|
|
334
339
|
def interpret(
|
|
335
340
|
model: GraphModel,
|
|
336
341
|
graph_tensor: tensors.GraphTensor,
|
|
337
|
-
) ->
|
|
342
|
+
) -> tensors.GraphTensor:
|
|
338
343
|
x = graph_tensor
|
|
339
344
|
if tensors.is_ragged(x):
|
|
340
345
|
x = x.flatten()
|
|
@@ -374,6 +379,31 @@ def interpret(
|
|
|
374
379
|
}
|
|
375
380
|
)
|
|
376
381
|
|
|
382
|
+
def saliency(
|
|
383
|
+
model: GraphModel,
|
|
384
|
+
graph_tensor: tensors.GraphTensor
|
|
385
|
+
) -> tensors.GraphTensor:
|
|
386
|
+
x = graph_tensor
|
|
387
|
+
if tensors.is_ragged(x):
|
|
388
|
+
x = x.flatten()
|
|
389
|
+
y_true = x.context.get('label')
|
|
390
|
+
with tf.GradientTape(watch_accessed_variables=False) as tape:
|
|
391
|
+
tape.watch(x.node['feature'])
|
|
392
|
+
y_pred = model(x, training=False)
|
|
393
|
+
if y_true is not None and len(y_true.shape) > 1:
|
|
394
|
+
target = tf.gather_nd(y_pred, tf.where(y_true != 0))
|
|
395
|
+
else:
|
|
396
|
+
target = y_pred
|
|
397
|
+
gradients = tape.gradient(target, x.node['feature'])
|
|
398
|
+
gradients = keras.ops.absolute(gradients)
|
|
399
|
+
return graph_tensor.update(
|
|
400
|
+
{
|
|
401
|
+
'node': {
|
|
402
|
+
'feature_saliency': gradients
|
|
403
|
+
}
|
|
404
|
+
}
|
|
405
|
+
)
|
|
406
|
+
|
|
377
407
|
def predict(
|
|
378
408
|
model: GraphModel,
|
|
379
409
|
x: tensors.GraphTensor | tf.data.Dataset,
|
molcraft/tensors.py
CHANGED
|
@@ -9,6 +9,9 @@ from molcraft import ops
|
|
|
9
9
|
class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
|
|
10
10
|
|
|
11
11
|
def batch(self, spec: 'GraphTensor.Spec', batch_size: int | None):
|
|
12
|
+
"""Batches spec.
|
|
13
|
+
"""
|
|
14
|
+
|
|
12
15
|
def batch_field(f):
|
|
13
16
|
if isinstance(f, tf.TensorSpec):
|
|
14
17
|
return tf.TensorSpec(
|
|
@@ -36,6 +39,9 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
|
|
|
36
39
|
return batched_spec
|
|
37
40
|
|
|
38
41
|
def unbatch(self, spec: 'GraphTensor.Spec'):
|
|
42
|
+
"""Unbatches spec.
|
|
43
|
+
"""
|
|
44
|
+
|
|
39
45
|
def unbatch_field(f):
|
|
40
46
|
if isinstance(f, tf.TensorSpec):
|
|
41
47
|
return tf.TensorSpec(
|
|
@@ -63,6 +69,8 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
|
|
|
63
69
|
return unbatched_spec
|
|
64
70
|
|
|
65
71
|
def encode(self, spec: 'GraphTensor.Spec', value: 'GraphTensor', minimum_rank: int = 0):
|
|
72
|
+
"""Encodes value.
|
|
73
|
+
"""
|
|
66
74
|
unflatten = False if (is_ragged(spec) or is_scalar(spec)) else True
|
|
67
75
|
if unflatten:
|
|
68
76
|
value = value.unflatten()
|
|
@@ -74,6 +82,8 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
|
|
|
74
82
|
return value_components
|
|
75
83
|
|
|
76
84
|
def encoding_specs(self, spec: 'GraphTensor.Spec'):
|
|
85
|
+
"""Matches spec and encoded value of `encode(spec, value)`.
|
|
86
|
+
"""
|
|
77
87
|
def encode_fields(f):
|
|
78
88
|
if isinstance(f, tf.TensorSpec):
|
|
79
89
|
scalar = is_scalar(spec)
|
|
@@ -95,6 +105,8 @@ class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
|
|
|
95
105
|
return spec_components
|
|
96
106
|
|
|
97
107
|
def decode(self, spec, encoded_value):
|
|
108
|
+
"""Decodes encoded value.
|
|
109
|
+
"""
|
|
98
110
|
spec_tuple = tuple(spec.__dict__.values())
|
|
99
111
|
encoded_value = iter(encoded_value)
|
|
100
112
|
value_tuple = [
|
|
@@ -122,18 +134,27 @@ class GraphTensor(tf.experimental.BatchableExtensionType):
|
|
|
122
134
|
__name__ = 'GraphTensor'
|
|
123
135
|
|
|
124
136
|
def __validate__(self):
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
+
if tf.executing_eagerly():
|
|
138
|
+
assert 'size' in self.context, "graph.context['size'] is required."
|
|
139
|
+
assert self.context['size'].dtype == tf.int32, (
|
|
140
|
+
"dtype of graph.context['size'] needs to be int32."
|
|
141
|
+
)
|
|
142
|
+
assert 'feature' in self.node, "graph.node['feature'] is required."
|
|
143
|
+
assert 'source' in self.edge, "graph.edge['source'] is required."
|
|
144
|
+
assert 'target' in self.edge, "graph.edge['target'] is required."
|
|
145
|
+
assert self.edge['source'].dtype == tf.int32, (
|
|
146
|
+
"dtype of graph.edge['source'] needs to be int32."
|
|
147
|
+
)
|
|
148
|
+
assert self.edge['target'].dtype == tf.int32, (
|
|
149
|
+
"dtype of graph.edge['target'] needs to be int32."
|
|
150
|
+
)
|
|
151
|
+
if isinstance(self.node['feature'], tf.Tensor):
|
|
152
|
+
num_nodes = keras.ops.shape(self.node['feature'])[0]
|
|
153
|
+
else:
|
|
154
|
+
num_nodes = keras.ops.sum(self.node['feature'].row_lengths())
|
|
155
|
+
assert keras.ops.sum(self.context['size']) == num_nodes, (
|
|
156
|
+
"graph.node['feature'] tensor is incompatible with graph.context['size']"
|
|
157
|
+
)
|
|
137
158
|
|
|
138
159
|
@property
|
|
139
160
|
def spec(self):
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: molcraft
|
|
3
|
-
Version: 0.1.0a1
|
|
3
|
+
Version: 0.1.0a3
|
|
4
4
|
Summary: Graph Neural Networks for Molecular Machine Learning
|
|
5
5
|
Author-email: Alexander Kensert <alexander.kensert@gmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -56,3 +56,70 @@ Dynamic: license-file
|
|
|
56
56
|
- Modular graph **layers**
|
|
57
57
|
- Serializable graph **featurizers** and **models**
|
|
58
58
|
- Flexible **GraphTensor**
|
|
59
|
+
|
|
60
|
+
## Examples
|
|
61
|
+
|
|
62
|
+
```python
|
|
63
|
+
from molcraft import features
|
|
64
|
+
from molcraft import descriptors
|
|
65
|
+
from molcraft import featurizers
|
|
66
|
+
from molcraft import layers
|
|
67
|
+
from molcraft import models
|
|
68
|
+
import keras
|
|
69
|
+
|
|
70
|
+
featurizer = featurizers.MolGraphFeaturizer(
|
|
71
|
+
atom_features=[
|
|
72
|
+
features.AtomType(),
|
|
73
|
+
features.TotalNumHs(),
|
|
74
|
+
features.Degree(),
|
|
75
|
+
],
|
|
76
|
+
bond_features=[
|
|
77
|
+
features.BondType(),
|
|
78
|
+
features.IsRotatable(),
|
|
79
|
+
],
|
|
80
|
+
super_atom=True,
|
|
81
|
+
self_loops=False,
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
graph = featurizer([('N[C@@H](C)C(=O)O', 2.0), ('N[C@@H](CS)C(=O)O', 1.0)])
|
|
85
|
+
print(graph)
|
|
86
|
+
|
|
87
|
+
model = models.GraphModel.from_layers(
|
|
88
|
+
[
|
|
89
|
+
layers.Input(graph.spec),
|
|
90
|
+
layers.NodeEmbedding(dim=128),
|
|
91
|
+
layers.EdgeEmbedding(dim=128),
|
|
92
|
+
layers.GraphTransformer(units=128),
|
|
93
|
+
layers.GraphTransformer(units=128),
|
|
94
|
+
layers.GraphTransformer(units=128),
|
|
95
|
+
layers.GraphTransformer(units=128),
|
|
96
|
+
layers.Readout(mode='mean'),
|
|
97
|
+
keras.layers.Dense(units=1024, activation='relu'),
|
|
98
|
+
keras.layers.Dense(units=1024, activation='relu'),
|
|
99
|
+
keras.layers.Dense(1)
|
|
100
|
+
]
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
pred = model(graph)
|
|
104
|
+
print(pred)
|
|
105
|
+
|
|
106
|
+
# featurizers.save_featurizer(featurizer, '/tmp/featurizer.json')
|
|
107
|
+
# models.save_model(model, '/tmp/model.keras')
|
|
108
|
+
|
|
109
|
+
# featurizers.load_featurizer('/tmp/featurizer.json')
|
|
110
|
+
# models.load_model('/tmp/model.keras')
|
|
111
|
+
```
|
|
112
|
+
|
|
113
|
+
## Installation
|
|
114
|
+
|
|
115
|
+
Install the pre-release of molcraft via pip:
|
|
116
|
+
|
|
117
|
+
```bash
|
|
118
|
+
pip install molcraft --pre
|
|
119
|
+
```
|
|
120
|
+
|
|
121
|
+
with GPU support:
|
|
122
|
+
|
|
123
|
+
```bash
|
|
124
|
+
pip install molcraft[gpu] --pre
|
|
125
|
+
```
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
molcraft/__init__.py,sha256=2ZNfWBjGl8DscOwjdDiRkgIsuPnKit29Q3MhZyP336Q,435
|
|
2
|
+
molcraft/callbacks.py,sha256=6gwCwdsHGb-fVB4m1QGmtBwQwZ9mFq9QUkmPKSMn05U,849
|
|
3
|
+
molcraft/chem.py,sha256=_UO5O-I7KUtGf3vRrFEYoAUGlW5xi2x8ylu5f-Ybumo,18696
|
|
4
|
+
molcraft/conformers.py,sha256=p09gOQOdxLSj3yohZOMkxxLriHsZ1ZqOoiWLi73OpIg,4325
|
|
5
|
+
molcraft/datasets.py,sha256=rFgXTC1ZheLhfgQgcCspP_wEE54a33PIneH7OplbS-8,4047
|
|
6
|
+
molcraft/descriptors.py,sha256=x6RfZ-gK7D_WSvmK6sh6yHyEjQqovPnRU0xwC3dAKfg,2880
|
|
7
|
+
molcraft/features.py,sha256=69oV_GHNdBKPA4sp6Tpo6brvNmaauk_IVIzNjX7VDmg,13648
|
|
8
|
+
molcraft/featurizers.py,sha256=Yu8I6I_zkzB__WYSiqz-FDGjvKFOmyWFxojRBr39Aw8,26236
|
|
9
|
+
molcraft/layers.py,sha256=HjnAtqhuP0uZ5yP4L33k3xT4IUdLavWBrjd3wO9_Rmw,64915
|
|
10
|
+
molcraft/models.py,sha256=DXqWR_XnMVXQseVR91XnDLXvmHa1hv-6_Y_wvpQZBFI,17476
|
|
11
|
+
molcraft/ops.py,sha256=iiE6zgA2P7cmjKO1RHmL9GE_Tv7Tyuo_xDoxB_ELZQM,3824
|
|
12
|
+
molcraft/records.py,sha256=w4-bcWZEC0oVInrE1e0kQBroIaSCA0PN1JBPOtO6VUY,5251
|
|
13
|
+
molcraft/tensors.py,sha256=b7PO-YOvV72s9g057ILJACKS2n2fn10VkO35gHXpssI,22312
|
|
14
|
+
molcraft/experimental/__init__.py,sha256=x5h6LOO8bo3NPjkKKM9M1H-Kz6R3yxYhRSePoxHCdRE,42
|
|
15
|
+
molcraft/experimental/peptides.py,sha256=RCuOTSwoYHGSdeYi6TWHdPIv2WC3avCZjKLdhEZQeXw,8997
|
|
16
|
+
molcraft-0.1.0a3.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
|
|
17
|
+
molcraft-0.1.0a3.dist-info/METADATA,sha256=f_5sBinpFcGSqKLaSqGkZJ83gGQtZw1Pb3fkgq9aCBM,4088
|
|
18
|
+
molcraft-0.1.0a3.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
|
19
|
+
molcraft-0.1.0a3.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
|
|
20
|
+
molcraft-0.1.0a3.dist-info/RECORD,,
|
|
@@ -1,19 +0,0 @@
|
|
|
1
|
-
molcraft/__init__.py,sha256=lx_frHVGfTmGUmzLhCt5LJVy47W_d_mTrKtRsw9RTQ8,406
|
|
2
|
-
molcraft/callbacks.py,sha256=6gwCwdsHGb-fVB4m1QGmtBwQwZ9mFq9QUkmPKSMn05U,849
|
|
3
|
-
molcraft/chem.py,sha256=_UO5O-I7KUtGf3vRrFEYoAUGlW5xi2x8ylu5f-Ybumo,18696
|
|
4
|
-
molcraft/conformers.py,sha256=p09gOQOdxLSj3yohZOMkxxLriHsZ1ZqOoiWLi73OpIg,4325
|
|
5
|
-
molcraft/descriptors.py,sha256=x6RfZ-gK7D_WSvmK6sh6yHyEjQqovPnRU0xwC3dAKfg,2880
|
|
6
|
-
molcraft/features.py,sha256=nZDfX9fsWWjhUbUbrWSUI0ny1QIDbxb4MO8umjcdQqw,13572
|
|
7
|
-
molcraft/featurizers.py,sha256=epv-K5ah9MVFEDZ7c2UzT7f9Vxglr28sbQoTYk1fev8,24583
|
|
8
|
-
molcraft/layers.py,sha256=XRxAUXnYOQ_fD8OegkuJTb1zMHHKm_4lAjxnIr2fdl4,43119
|
|
9
|
-
molcraft/models.py,sha256=z2V9_I_cnJr3VNqfY_CWkZKojlLoD_8r9MHkv3pKOh8,16605
|
|
10
|
-
molcraft/ops.py,sha256=iiE6zgA2P7cmjKO1RHmL9GE_Tv7Tyuo_xDoxB_ELZQM,3824
|
|
11
|
-
molcraft/records.py,sha256=w4-bcWZEC0oVInrE1e0kQBroIaSCA0PN1JBPOtO6VUY,5251
|
|
12
|
-
molcraft/tensors.py,sha256=_zlcTHH47icUe_S_wYzRp7J-4M2U_4lLkAZOIInea0w,21677
|
|
13
|
-
molcraft/experimental/__init__.py,sha256=x5h6LOO8bo3NPjkKKM9M1H-Kz6R3yxYhRSePoxHCdRE,42
|
|
14
|
-
molcraft/experimental/peptides.py,sha256=AcjuilPUPzn6l2ui5YPYE00VpXnGj8H1osTjz58biGw,10442
|
|
15
|
-
molcraft-0.1.0a1.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
|
|
16
|
-
molcraft-0.1.0a1.dist-info/METADATA,sha256=0fczlpAEjq5hLHqGCQPdocQXeoCRfF7L4cyoz8WzYeM,2573
|
|
17
|
-
molcraft-0.1.0a1.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
|
18
|
-
molcraft-0.1.0a1.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
|
|
19
|
-
molcraft-0.1.0a1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|