molcraft 0.1.0a2__py3-none-any.whl → 0.1.0a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- molcraft/__init__.py +2 -1
- molcraft/callbacks.py +12 -0
- molcraft/descriptors.py +24 -23
- molcraft/experimental/peptides.py +96 -79
- molcraft/features.py +5 -3
- molcraft/featurizers.py +61 -38
- molcraft/layers.py +1004 -425
- molcraft/models.py +47 -3
- molcraft/ops.py +14 -3
- molcraft/tensors.py +3 -3
- {molcraft-0.1.0a2.dist-info → molcraft-0.1.0a4.dist-info}/METADATA +6 -6
- molcraft-0.1.0a4.dist-info/RECORD +20 -0
- {molcraft-0.1.0a2.dist-info → molcraft-0.1.0a4.dist-info}/WHEEL +1 -1
- molcraft-0.1.0a2.dist-info/RECORD +0 -20
- {molcraft-0.1.0a2.dist-info → molcraft-0.1.0a4.dist-info}/licenses/LICENSE +0 -0
- {molcraft-0.1.0a2.dist-info → molcraft-0.1.0a4.dist-info}/top_level.txt +0 -0
molcraft/models.py
CHANGED
|
@@ -270,6 +270,19 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
|
|
|
270
270
|
"""
|
|
271
271
|
super().load_weights(filepath, *args, **kwargs)
|
|
272
272
|
|
|
273
|
+
def embedding(self) -> 'FunctionalGraphModel':
    """Return a sub-model mapping this model's inputs to its readout output.

    The returned model is constructed from ``self.input`` and the output of
    the last ``layers.Readout`` layer encountered in ``self.layers``.

    Returns:
        A new model of the same class as ``self``, named
        ``'<self.name>_embedding'``.

    Raises:
        ValueError: If ``self`` is not a ``FunctionalGraphModel``, or if it
            contains no ``layers.Readout`` layer.
    """
    if not isinstance(self, FunctionalGraphModel):
        raise ValueError(
            'Currently, to extract the embedding part of the model, '
            'it needs to be a `FunctionalGraphModel`. '
        )
    outputs = None
    for layer in self.layers:
        # No early break: when several Readout layers exist, the last one
        # in `self.layers` order is used.
        if isinstance(layer, layers.Readout):
            outputs = layer.output
    if outputs is None:
        # Without this guard a model lacking a Readout layer would fail
        # with an UnboundLocalError on `outputs`.
        raise ValueError(
            'Could not extract the embedding part of the model: '
            'no `Readout` layer was found in `model.layers`.'
        )
    return self.__class__(self.input, outputs, name=f'{self.name}_embedding')
|
|
285
|
+
|
|
273
286
|
def train_step(self, tensor: tensors.GraphTensor) -> dict[str, float]:
|
|
274
287
|
y = tensor.context.get('label')
|
|
275
288
|
sample_weight = tensor.context.get('weight')
|
|
@@ -315,10 +328,16 @@ class FunctionalGraphModel(functional.Functional, GraphModel):
|
|
|
315
328
|
]
|
|
316
329
|
|
|
317
330
|
|
|
318
|
-
def save_model(model: GraphModel, filepath: str | Path, *args, **kwargs) -> None:
    """Save ``model`` to ``filepath`` through ``keras.models.save_model``.

    Extra positional and keyword arguments are forwarded unchanged.

    Raises:
        ValueError: If ``model.built`` is false, i.e. the model has not been
            built (or called) yet.
    """
    if model.built:
        keras.models.save_model(model, filepath, *args, **kwargs)
        return
    raise ValueError(
        'Model and its layers have not yet been (fully) built. '
        'Build the model before saving it: `model.build(graph_spec)` '
        'or `model(graph)`.'
    )
|
|
320
339
|
|
|
321
|
-
def load_model(filepath: str | Path, inputs=None, *args, **kwargs) -> GraphModel:
    """Load a model from ``filepath`` via ``keras.models.load_model``.

    Any extra positional/keyword arguments are forwarded to keras.
    """
    # NOTE(review): `inputs` is accepted but never used. Because it precedes
    # *args, a second positional argument is captured here and dropped
    # instead of being forwarded to keras — confirm this is intended.
    del inputs
    loaded = keras.models.load_model(filepath, *args, **kwargs)
    return loaded
|
|
323
342
|
|
|
324
343
|
def create(
|
|
@@ -333,7 +352,7 @@ def create(
|
|
|
333
352
|
def interpret(
|
|
334
353
|
model: GraphModel,
|
|
335
354
|
graph_tensor: tensors.GraphTensor,
|
|
336
|
-
) ->
|
|
355
|
+
) -> tensors.GraphTensor:
|
|
337
356
|
x = graph_tensor
|
|
338
357
|
if tensors.is_ragged(x):
|
|
339
358
|
x = x.flatten()
|
|
@@ -373,6 +392,31 @@ def interpret(
|
|
|
373
392
|
}
|
|
374
393
|
)
|
|
375
394
|
|
|
395
|
+
def saliency(
    model: GraphModel,
    graph_tensor: tensors.GraphTensor,
) -> tensors.GraphTensor:
    """Attach absolute input gradients of the model output to the node features.

    The model is run with ``training=False`` under a ``tf.GradientTape`` that
    watches only ``node['feature']``. If the graph carries a ``label`` context
    of rank > 1, only the predictions at positions where the label is non-zero
    are differentiated; otherwise the full prediction is used.

    Returns:
        ``graph_tensor`` updated with a ``node['feature_saliency']`` entry that
        holds the absolute gradients.

    Raises:
        ValueError: If no gradient flows from the model output back to
            ``node['feature']``.
    """
    x = graph_tensor
    if tensors.is_ragged(x):
        x = x.flatten()
    y_true = x.context.get('label')
    node_feature = x.node['feature']
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(node_feature)
        y_pred = model(x, training=False)
        if y_true is not None and len(y_true.shape) > 1:
            # Select predictions at the non-zero label positions.
            target = tf.gather_nd(y_pred, tf.where(y_true != 0))
        else:
            target = y_pred
    gradients = tape.gradient(target, node_feature)
    if gradients is None:
        # tape.gradient returns None when the target does not depend on the
        # watched tensor; keras.ops.absolute(None) would fail obscurely.
        raise ValueError(
            'Could not compute saliency: no gradient flows from the model '
            'output to `node["feature"]`.'
        )
    gradients = keras.ops.absolute(gradients)
    # NOTE(review): for ragged input the gradients come from the flattened
    # copy but are written onto the original `graph_tensor` — confirm that
    # `update` accepts flat node values in that case.
    return graph_tensor.update(
        {
            'node': {
                'feature_saliency': gradients
            }
        }
    )
|
|
419
|
+
|
|
376
420
|
def predict(
|
|
377
421
|
model: GraphModel,
|
|
378
422
|
x: tensors.GraphTensor | tf.data.Dataset,
|
molcraft/ops.py
CHANGED
|
@@ -19,9 +19,16 @@ def gather(
|
|
|
19
19
|
def aggregate(
    node_feature: tf.Tensor,
    edge: tf.Tensor,
    num_nodes: tf.Tensor,
    mode: str = 'sum',
) -> tf.Tensor:
    """Reduce per-edge values into ``num_nodes`` segments indexed by ``edge``.

    Args:
        node_feature: Values to reduce, one row per entry of ``edge``
            (``GraphTensor.aggregate`` passes edge-level features here).
        edge: Segment id (node index) for each row of ``node_feature``.
        num_nodes: Number of output segments.
        mode: ``'sum'`` (default) or ``'mean'``.

    Returns:
        The segment-wise sum or mean, computed with unsorted segment ids.

    Raises:
        ValueError: If ``mode`` is neither ``'sum'`` nor ``'mean'``.
    """
    if mode == 'mean':
        return segment_mean(
            node_feature, edge, num_nodes, sorted=False
        )
    if mode == 'sum':
        return keras.ops.segment_sum(
            node_feature, edge, num_nodes, sorted=False
        )
    # Previously any unknown mode silently fell back to summation.
    raise ValueError(f"`mode` needs to be 'sum' or 'mean', got {mode!r}.")
|
|
25
32
|
|
|
26
33
|
def propagate(
|
|
27
34
|
node_feature: tf.Tensor,
|
|
@@ -82,7 +89,11 @@ def segment_mean(
|
|
|
82
89
|
sorted: bool = False,
|
|
83
90
|
) -> tf.Tensor:
|
|
84
91
|
if num_segments is None:
|
|
85
|
-
num_segments = keras.ops.
|
|
92
|
+
num_segments = keras.ops.cond(
|
|
93
|
+
keras.ops.shape(segment_ids)[0] > 0,
|
|
94
|
+
lambda: keras.ops.max(segment_ids) + 1,
|
|
95
|
+
lambda: 0
|
|
96
|
+
)
|
|
86
97
|
if backend.backend() == 'tensorflow':
|
|
87
98
|
return tf.math.unsorted_segment_mean(
|
|
88
99
|
data=data,
|
molcraft/tensors.py
CHANGED
|
@@ -219,13 +219,13 @@ class GraphTensor(tf.experimental.BatchableExtensionType):
|
|
|
219
219
|
raise ValueError
|
|
220
220
|
return ops.gather(self.node[node_attr], self.edge[edge_type])
|
|
221
221
|
|
|
222
|
-
def aggregate(self, edge_attr: str, edge_type: str = 'target', mode: str = 'sum') -> tf.Tensor:
    """Aggregate the edge attribute ``edge_attr`` onto nodes.

    If the graph has an ``edge['weight']`` entry, the attribute is multiplied
    by it before aggregation.

    Args:
        edge_attr: Key into ``self.edge`` of the values to aggregate.
        edge_type: ``'source'`` or ``'target'``; selects which edge index
            column defines the receiving node.
        mode: Reduction passed through to ``ops.aggregate``.

    Raises:
        ValueError: If ``edge_type`` is not ``'source'`` or ``'target'``.
    """
    if edge_type not in ('source', 'target'):
        # The message previously named `edge_attr`, which is not the
        # argument being validated.
        raise ValueError(
            f'`edge_type` needs to be `source` or `target`, got {edge_type!r}.'
        )
    edge_feature = self.edge[edge_attr]
    if 'weight' in self.edge:
        edge_feature = edge_feature * self.edge['weight']
    return ops.aggregate(edge_feature, self.edge[edge_type], self.num_nodes, mode=mode)
|
|
229
229
|
|
|
230
230
|
def propagate(self, add_edge_feature: bool = False):
|
|
231
231
|
updated_feature = ops.propagate(
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: molcraft
|
|
3
|
-
Version: 0.1.0a2
|
|
3
|
+
Version: 0.1.0a4
|
|
4
4
|
Summary: Graph Neural Networks for Molecular Machine Learning
|
|
5
5
|
Author-email: Alexander Kensert <alexander.kensert@gmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -51,11 +51,11 @@ Dynamic: license-file
|
|
|
51
51
|
|
|
52
52
|
## Highlights
|
|
53
53
|
- Compatible with **Keras 3**
|
|
54
|
-
-
|
|
55
|
-
-
|
|
56
|
-
-
|
|
57
|
-
-
|
|
58
|
-
-
|
|
54
|
+
- Customizable and serializable **featurizers**
|
|
55
|
+
- Customizable and serializable **layers** and **models**
|
|
56
|
+
- Customizable **GraphTensor**
|
|
57
|
+
- Fast and efficient featurization of molecular graphs
|
|
58
|
+
- Efficient and easy-to-use input pipelines using TF **records**
|
|
59
59
|
|
|
60
60
|
## Examples
|
|
61
61
|
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
molcraft/__init__.py,sha256=FQyasgy1kEz2v9sKdr3am6ap7Cm1oHEuCKhHwH-CQpM,435
|
|
2
|
+
molcraft/callbacks.py,sha256=mkz4ALjJFPy8nHd2nCAuMbKceKnq4tIpZhUuUOvie2Y,1209
|
|
3
|
+
molcraft/chem.py,sha256=_UO5O-I7KUtGf3vRrFEYoAUGlW5xi2x8ylu5f-Ybumo,18696
|
|
4
|
+
molcraft/conformers.py,sha256=p09gOQOdxLSj3yohZOMkxxLriHsZ1ZqOoiWLi73OpIg,4325
|
|
5
|
+
molcraft/datasets.py,sha256=rFgXTC1ZheLhfgQgcCspP_wEE54a33PIneH7OplbS-8,4047
|
|
6
|
+
molcraft/descriptors.py,sha256=gKqlJ3BqJLTeR2ft8isftSEaJDC8cv64eTq5IYhy4XM,3032
|
|
7
|
+
molcraft/features.py,sha256=69oV_GHNdBKPA4sp6Tpo6brvNmaauk_IVIzNjX7VDmg,13648
|
|
8
|
+
molcraft/featurizers.py,sha256=kV5RN_Z2pELjDcwE65KYy_JagbDUueXoClpsIOFsI9I,27073
|
|
9
|
+
molcraft/layers.py,sha256=y-sBLXWttr-fkGZ-acL1srMB8QqeXnHotYK9KCcyJNU,70581
|
|
10
|
+
molcraft/models.py,sha256=0MN4PAlsacni7RfIcYm_imxuzBVL2K8w3MnaUM24DeI,18021
|
|
11
|
+
molcraft/ops.py,sha256=uSnBYQwxYJ1ATdDpr290bxiyQZkrSCVxlB7btlh_n2I,4112
|
|
12
|
+
molcraft/records.py,sha256=w4-bcWZEC0oVInrE1e0kQBroIaSCA0PN1JBPOtO6VUY,5251
|
|
13
|
+
molcraft/tensors.py,sha256=8hwlad000wQ5pNLSdzd3rCXVbaUHBxUq2MbBx27dKzU,22391
|
|
14
|
+
molcraft/experimental/__init__.py,sha256=x5h6LOO8bo3NPjkKKM9M1H-Kz6R3yxYhRSePoxHCdRE,42
|
|
15
|
+
molcraft/experimental/peptides.py,sha256=82Bzw9FEnlymOUgTIIKha-ELNbqEFkv9T4hspDGRetw,9266
|
|
16
|
+
molcraft-0.1.0a4.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
|
|
17
|
+
molcraft-0.1.0a4.dist-info/METADATA,sha256=bhsytRfa6BIbfmph0Cm2NfubmZJPumsMQt4lbch33kQ,4201
|
|
18
|
+
molcraft-0.1.0a4.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
|
|
19
|
+
molcraft-0.1.0a4.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
|
|
20
|
+
molcraft-0.1.0a4.dist-info/RECORD,,
|
|
@@ -1,20 +0,0 @@
|
|
|
1
|
-
molcraft/__init__.py,sha256=lE7_mCo7lLcP1AopGZtGyWqzAN1qgjZnH5juymdjrJc,406
|
|
2
|
-
molcraft/callbacks.py,sha256=6gwCwdsHGb-fVB4m1QGmtBwQwZ9mFq9QUkmPKSMn05U,849
|
|
3
|
-
molcraft/chem.py,sha256=_UO5O-I7KUtGf3vRrFEYoAUGlW5xi2x8ylu5f-Ybumo,18696
|
|
4
|
-
molcraft/conformers.py,sha256=p09gOQOdxLSj3yohZOMkxxLriHsZ1ZqOoiWLi73OpIg,4325
|
|
5
|
-
molcraft/datasets.py,sha256=rFgXTC1ZheLhfgQgcCspP_wEE54a33PIneH7OplbS-8,4047
|
|
6
|
-
molcraft/descriptors.py,sha256=x6RfZ-gK7D_WSvmK6sh6yHyEjQqovPnRU0xwC3dAKfg,2880
|
|
7
|
-
molcraft/features.py,sha256=nZDfX9fsWWjhUbUbrWSUI0ny1QIDbxb4MO8umjcdQqw,13572
|
|
8
|
-
molcraft/featurizers.py,sha256=gAUe7Ui8gF32aotuiDAUoRUuw8bTbkMgB2C2BO1VWDM,26176
|
|
9
|
-
molcraft/layers.py,sha256=zs6Ae6p7ASeAy3eF113f35d55yQmyk2Z7vUUfkfJUmY,49677
|
|
10
|
-
molcraft/models.py,sha256=Nvm5LKCtH-xj395f1OvIEmYVTTrnutoSthL2DxGicnY,16519
|
|
11
|
-
molcraft/ops.py,sha256=iiE6zgA2P7cmjKO1RHmL9GE_Tv7Tyuo_xDoxB_ELZQM,3824
|
|
12
|
-
molcraft/records.py,sha256=w4-bcWZEC0oVInrE1e0kQBroIaSCA0PN1JBPOtO6VUY,5251
|
|
13
|
-
molcraft/tensors.py,sha256=b7PO-YOvV72s9g057ILJACKS2n2fn10VkO35gHXpssI,22312
|
|
14
|
-
molcraft/experimental/__init__.py,sha256=x5h6LOO8bo3NPjkKKM9M1H-Kz6R3yxYhRSePoxHCdRE,42
|
|
15
|
-
molcraft/experimental/peptides.py,sha256=RCuOTSwoYHGSdeYi6TWHdPIv2WC3avCZjKLdhEZQeXw,8997
|
|
16
|
-
molcraft-0.1.0a2.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
|
|
17
|
-
molcraft-0.1.0a2.dist-info/METADATA,sha256=TYf32YHTSrK9OaaGKCCk89uPlf_REWsK-LKf93c6V4M,4088
|
|
18
|
-
molcraft-0.1.0a2.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
|
19
|
-
molcraft-0.1.0a2.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
|
|
20
|
-
molcraft-0.1.0a2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|