molcraft 0.1.0a2__py3-none-any.whl → 0.1.0a4__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic. More details are available on the original registry page (the link was not preserved in this copy).

molcraft/models.py CHANGED
@@ -270,6 +270,19 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
270
270
  """
271
271
  super().load_weights(filepath, *args, **kwargs)
272
272
 
273
+ def embedding(self) -> 'FunctionalGraphModel':
274
+ model = self
275
+ if not isinstance(model, FunctionalGraphModel):
276
+ raise ValueError(
277
+ 'Currently, to extract the embedding part of the model, '
278
+ 'it needs to be a `FunctionalGraphModel`. '
279
+ )
280
+ inputs = model.input
281
+ for layer in model.layers:
282
+ if isinstance(layer, layers.Readout):
283
+ outputs = layer.output
284
+ return self.__class__(inputs, outputs, name=f'{self.name}_embedding')
285
+
273
286
  def train_step(self, tensor: tensors.GraphTensor) -> dict[str, float]:
274
287
  y = tensor.context.get('label')
275
288
  sample_weight = tensor.context.get('weight')
@@ -315,10 +328,16 @@ class FunctionalGraphModel(functional.Functional, GraphModel):
315
328
  ]
316
329
 
317
330
 
318
- def save_model(model: keras.Model, filepath: str | Path, *args, **kwargs) -> None:
331
+ def save_model(model: GraphModel, filepath: str | Path, *args, **kwargs) -> None:
332
+ if not model.built:
333
+ raise ValueError(
334
+ 'Model and its layers have not yet been (fully) built. '
335
+ 'Build the model before saving it: `model.build(graph_spec)` '
336
+ 'or `model(graph)`.'
337
+ )
319
338
  keras.models.save_model(model, filepath, *args, **kwargs)
320
339
 
321
- def load_model(filepath: str | Path, inputs=None, *args, **kwargs) -> None:
340
+ def load_model(filepath: str | Path, inputs=None, *args, **kwargs) -> GraphModel:
322
341
  return keras.models.load_model(filepath, *args, **kwargs)
323
342
 
324
343
  def create(
@@ -333,7 +352,7 @@ def create(
333
352
  def interpret(
334
353
  model: GraphModel,
335
354
  graph_tensor: tensors.GraphTensor,
336
- ) -> tuple[tf.Tensor | tf.RaggedTensor | np.ndarray, tf.Tensor | np.ndarray]:
355
+ ) -> tensors.GraphTensor:
337
356
  x = graph_tensor
338
357
  if tensors.is_ragged(x):
339
358
  x = x.flatten()
@@ -373,6 +392,31 @@ def interpret(
373
392
  }
374
393
  )
375
394
 
395
+ def saliency(
396
+ model: GraphModel,
397
+ graph_tensor: tensors.GraphTensor,
398
+ ) -> tensors.GraphTensor:
399
+ x = graph_tensor
400
+ if tensors.is_ragged(x):
401
+ x = x.flatten()
402
+ y_true = x.context.get('label')
403
+ with tf.GradientTape(watch_accessed_variables=False) as tape:
404
+ tape.watch(x.node['feature'])
405
+ y_pred = model(x, training=False)
406
+ if y_true is not None and len(y_true.shape) > 1:
407
+ target = tf.gather_nd(y_pred, tf.where(y_true != 0))
408
+ else:
409
+ target = y_pred
410
+ gradients = tape.gradient(target, x.node['feature'])
411
+ gradients = keras.ops.absolute(gradients)
412
+ return graph_tensor.update(
413
+ {
414
+ 'node': {
415
+ 'feature_saliency': gradients
416
+ }
417
+ }
418
+ )
419
+
376
420
  def predict(
377
421
  model: GraphModel,
378
422
  x: tensors.GraphTensor | tf.data.Dataset,
molcraft/ops.py CHANGED
@@ -19,9 +19,16 @@ def gather(
19
19
  def aggregate(
20
20
  node_feature: tf.Tensor,
21
21
  edge: tf.Tensor,
22
- num_nodes: tf.Tensor
22
+ num_nodes: tf.Tensor,
23
+ mode: str = 'sum',
23
24
  ) -> tf.Tensor:
24
- return keras.ops.segment_sum(node_feature, edge, num_nodes)
25
+ if mode == 'mean':
26
+ return segment_mean(
27
+ node_feature, edge, num_nodes, sorted=False
28
+ )
29
+ return keras.ops.segment_sum(
30
+ node_feature, edge, num_nodes, sorted=False
31
+ )
25
32
 
26
33
  def propagate(
27
34
  node_feature: tf.Tensor,
@@ -82,7 +89,11 @@ def segment_mean(
82
89
  sorted: bool = False,
83
90
  ) -> tf.Tensor:
84
91
  if num_segments is None:
85
- num_segments = keras.ops.max(segment_ids) + 1
92
+ num_segments = keras.ops.cond(
93
+ keras.ops.shape(segment_ids)[0] > 0,
94
+ lambda: keras.ops.max(segment_ids) + 1,
95
+ lambda: 0
96
+ )
86
97
  if backend.backend() == 'tensorflow':
87
98
  return tf.math.unsorted_segment_mean(
88
99
  data=data,
molcraft/tensors.py CHANGED
@@ -219,13 +219,13 @@ class GraphTensor(tf.experimental.BatchableExtensionType):
219
219
  raise ValueError
220
220
  return ops.gather(self.node[node_attr], self.edge[edge_type])
221
221
 
222
- def aggregate(self, edge_attr: str, edge_type: str = 'target') -> tf.Tensor:
222
+ def aggregate(self, edge_attr: str, edge_type: str = 'target', mode: str = 'sum') -> tf.Tensor:
223
223
  if edge_type != 'source' and edge_type != 'target':
224
- raise ValueError
224
+ raise ValueError('`edge_attr` needs to be `source` or `target`.')
225
225
  edge_attr = self.edge[edge_attr]
226
226
  if 'weight' in self.edge:
227
227
  edge_attr = edge_attr * self.edge['weight']
228
- return ops.aggregate(edge_attr, self.edge[edge_type], self.num_nodes)
228
+ return ops.aggregate(edge_attr, self.edge[edge_type], self.num_nodes, mode=mode)
229
229
 
230
230
  def propagate(self, add_edge_feature: bool = False):
231
231
  updated_feature = ops.propagate(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: molcraft
3
- Version: 0.1.0a2
3
+ Version: 0.1.0a4
4
4
  Summary: Graph Neural Networks for Molecular Machine Learning
5
5
  Author-email: Alexander Kensert <alexander.kensert@gmail.com>
6
6
  License: MIT License
@@ -51,11 +51,11 @@ Dynamic: license-file
51
51
 
52
52
  ## Highlights
53
53
  - Compatible with **Keras 3**
54
- - Simplified API
55
- - Fast featurization
56
- - Modular graph **layers**
57
- - Serializable graph **featurizers** and **models**
58
- - Flexible **GraphTensor**
54
+ - Customizable and serializable **featurizers**
55
+ - Customizable and serializable **layers** and **models**
56
+ - Customizable **GraphTensor**
57
+ - Fast and efficient featurization of molecular graphs
58
+ - Efficient and easy-to-use input pipelines using TF **records**
59
59
 
60
60
  ## Examples
61
61
 
@@ -0,0 +1,20 @@
1
+ molcraft/__init__.py,sha256=FQyasgy1kEz2v9sKdr3am6ap7Cm1oHEuCKhHwH-CQpM,435
2
+ molcraft/callbacks.py,sha256=mkz4ALjJFPy8nHd2nCAuMbKceKnq4tIpZhUuUOvie2Y,1209
3
+ molcraft/chem.py,sha256=_UO5O-I7KUtGf3vRrFEYoAUGlW5xi2x8ylu5f-Ybumo,18696
4
+ molcraft/conformers.py,sha256=p09gOQOdxLSj3yohZOMkxxLriHsZ1ZqOoiWLi73OpIg,4325
5
+ molcraft/datasets.py,sha256=rFgXTC1ZheLhfgQgcCspP_wEE54a33PIneH7OplbS-8,4047
6
+ molcraft/descriptors.py,sha256=gKqlJ3BqJLTeR2ft8isftSEaJDC8cv64eTq5IYhy4XM,3032
7
+ molcraft/features.py,sha256=69oV_GHNdBKPA4sp6Tpo6brvNmaauk_IVIzNjX7VDmg,13648
8
+ molcraft/featurizers.py,sha256=kV5RN_Z2pELjDcwE65KYy_JagbDUueXoClpsIOFsI9I,27073
9
+ molcraft/layers.py,sha256=y-sBLXWttr-fkGZ-acL1srMB8QqeXnHotYK9KCcyJNU,70581
10
+ molcraft/models.py,sha256=0MN4PAlsacni7RfIcYm_imxuzBVL2K8w3MnaUM24DeI,18021
11
+ molcraft/ops.py,sha256=uSnBYQwxYJ1ATdDpr290bxiyQZkrSCVxlB7btlh_n2I,4112
12
+ molcraft/records.py,sha256=w4-bcWZEC0oVInrE1e0kQBroIaSCA0PN1JBPOtO6VUY,5251
13
+ molcraft/tensors.py,sha256=8hwlad000wQ5pNLSdzd3rCXVbaUHBxUq2MbBx27dKzU,22391
14
+ molcraft/experimental/__init__.py,sha256=x5h6LOO8bo3NPjkKKM9M1H-Kz6R3yxYhRSePoxHCdRE,42
15
+ molcraft/experimental/peptides.py,sha256=82Bzw9FEnlymOUgTIIKha-ELNbqEFkv9T4hspDGRetw,9266
16
+ molcraft-0.1.0a4.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
17
+ molcraft-0.1.0a4.dist-info/METADATA,sha256=bhsytRfa6BIbfmph0Cm2NfubmZJPumsMQt4lbch33kQ,4201
18
+ molcraft-0.1.0a4.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
19
+ molcraft-0.1.0a4.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
20
+ molcraft-0.1.0a4.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.1.0)
2
+ Generator: setuptools (80.0.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,20 +0,0 @@
1
- molcraft/__init__.py,sha256=lE7_mCo7lLcP1AopGZtGyWqzAN1qgjZnH5juymdjrJc,406
2
- molcraft/callbacks.py,sha256=6gwCwdsHGb-fVB4m1QGmtBwQwZ9mFq9QUkmPKSMn05U,849
3
- molcraft/chem.py,sha256=_UO5O-I7KUtGf3vRrFEYoAUGlW5xi2x8ylu5f-Ybumo,18696
4
- molcraft/conformers.py,sha256=p09gOQOdxLSj3yohZOMkxxLriHsZ1ZqOoiWLi73OpIg,4325
5
- molcraft/datasets.py,sha256=rFgXTC1ZheLhfgQgcCspP_wEE54a33PIneH7OplbS-8,4047
6
- molcraft/descriptors.py,sha256=x6RfZ-gK7D_WSvmK6sh6yHyEjQqovPnRU0xwC3dAKfg,2880
7
- molcraft/features.py,sha256=nZDfX9fsWWjhUbUbrWSUI0ny1QIDbxb4MO8umjcdQqw,13572
8
- molcraft/featurizers.py,sha256=gAUe7Ui8gF32aotuiDAUoRUuw8bTbkMgB2C2BO1VWDM,26176
9
- molcraft/layers.py,sha256=zs6Ae6p7ASeAy3eF113f35d55yQmyk2Z7vUUfkfJUmY,49677
10
- molcraft/models.py,sha256=Nvm5LKCtH-xj395f1OvIEmYVTTrnutoSthL2DxGicnY,16519
11
- molcraft/ops.py,sha256=iiE6zgA2P7cmjKO1RHmL9GE_Tv7Tyuo_xDoxB_ELZQM,3824
12
- molcraft/records.py,sha256=w4-bcWZEC0oVInrE1e0kQBroIaSCA0PN1JBPOtO6VUY,5251
13
- molcraft/tensors.py,sha256=b7PO-YOvV72s9g057ILJACKS2n2fn10VkO35gHXpssI,22312
14
- molcraft/experimental/__init__.py,sha256=x5h6LOO8bo3NPjkKKM9M1H-Kz6R3yxYhRSePoxHCdRE,42
15
- molcraft/experimental/peptides.py,sha256=RCuOTSwoYHGSdeYi6TWHdPIv2WC3avCZjKLdhEZQeXw,8997
16
- molcraft-0.1.0a2.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
17
- molcraft-0.1.0a2.dist-info/METADATA,sha256=TYf32YHTSrK9OaaGKCCk89uPlf_REWsK-LKf93c6V4M,4088
18
- molcraft-0.1.0a2.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
19
- molcraft-0.1.0a2.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
20
- molcraft-0.1.0a2.dist-info/RECORD,,