molcraft 0.1.0a3__py3-none-any.whl → 0.1.0a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic; see the package's registry page for more details.

molcraft/models.py CHANGED
@@ -270,6 +270,19 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
270
270
  """
271
271
  super().load_weights(filepath, *args, **kwargs)
272
272
 
273
+ def embedding(self) -> 'FunctionalGraphModel':
274
+ model = self
275
+ if not isinstance(model, FunctionalGraphModel):
276
+ raise ValueError(
277
+ 'Currently, to extract the embedding part of the model, '
278
+ 'it needs to be a `FunctionalGraphModel`. '
279
+ )
280
+ inputs = model.input
281
+ for layer in model.layers:
282
+ if isinstance(layer, layers.Readout):
283
+ outputs = layer.output
284
+ return self.__class__(inputs, outputs, name=f'{self.name}_embedding')
285
+
273
286
  def train_step(self, tensor: tensors.GraphTensor) -> dict[str, float]:
274
287
  y = tensor.context.get('label')
275
288
  sample_weight = tensor.context.get('weight')
@@ -381,7 +394,7 @@ def interpret(
381
394
 
382
395
  def saliency(
383
396
  model: GraphModel,
384
- graph_tensor: tensors.GraphTensor
397
+ graph_tensor: tensors.GraphTensor,
385
398
  ) -> tensors.GraphTensor:
386
399
  x = graph_tensor
387
400
  if tensors.is_ragged(x):
molcraft/ops.py CHANGED
@@ -19,9 +19,16 @@ def gather(
19
19
  def aggregate(
20
20
  node_feature: tf.Tensor,
21
21
  edge: tf.Tensor,
22
- num_nodes: tf.Tensor
22
+ num_nodes: tf.Tensor,
23
+ mode: str = 'sum',
23
24
  ) -> tf.Tensor:
24
- return keras.ops.segment_sum(node_feature, edge, num_nodes)
25
+ if mode == 'mean':
26
+ return segment_mean(
27
+ node_feature, edge, num_nodes, sorted=False
28
+ )
29
+ return keras.ops.segment_sum(
30
+ node_feature, edge, num_nodes, sorted=False
31
+ )
25
32
 
26
33
  def propagate(
27
34
  node_feature: tf.Tensor,
@@ -82,7 +89,11 @@ def segment_mean(
82
89
  sorted: bool = False,
83
90
  ) -> tf.Tensor:
84
91
  if num_segments is None:
85
- num_segments = keras.ops.max(segment_ids) + 1
92
+ num_segments = keras.ops.cond(
93
+ keras.ops.shape(segment_ids)[0] > 0,
94
+ lambda: keras.ops.max(segment_ids) + 1,
95
+ lambda: 0
96
+ )
86
97
  if backend.backend() == 'tensorflow':
87
98
  return tf.math.unsorted_segment_mean(
88
99
  data=data,
molcraft/tensors.py CHANGED
@@ -219,13 +219,13 @@ class GraphTensor(tf.experimental.BatchableExtensionType):
219
219
  raise ValueError
220
220
  return ops.gather(self.node[node_attr], self.edge[edge_type])
221
221
 
222
- def aggregate(self, edge_attr: str, edge_type: str = 'target') -> tf.Tensor:
222
+ def aggregate(self, edge_attr: str, edge_type: str = 'target', mode: str = 'sum') -> tf.Tensor:
223
223
  if edge_type != 'source' and edge_type != 'target':
224
- raise ValueError
224
+ raise ValueError('`edge_attr` needs to be `source` or `target`.')
225
225
  edge_attr = self.edge[edge_attr]
226
226
  if 'weight' in self.edge:
227
227
  edge_attr = edge_attr * self.edge['weight']
228
- return ops.aggregate(edge_attr, self.edge[edge_type], self.num_nodes)
228
+ return ops.aggregate(edge_attr, self.edge[edge_type], self.num_nodes, mode=mode)
229
229
 
230
230
  def propagate(self, add_edge_feature: bool = False):
231
231
  updated_feature = ops.propagate(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: molcraft
3
- Version: 0.1.0a3
3
+ Version: 0.1.0a4
4
4
  Summary: Graph Neural Networks for Molecular Machine Learning
5
5
  Author-email: Alexander Kensert <alexander.kensert@gmail.com>
6
6
  License: MIT License
@@ -51,11 +51,11 @@ Dynamic: license-file
51
51
 
52
52
  ## Highlights
53
53
  - Compatible with **Keras 3**
54
- - Simplified API
55
- - Fast featurization
56
- - Modular graph **layers**
57
- - Serializable graph **featurizers** and **models**
58
- - Flexible **GraphTensor**
54
+ - Customizable and serializable **featurizers**
55
+ - Customizable and serializable **layers** and **models**
56
+ - Customizable **GraphTensor**
57
+ - Fast and efficient featurization of molecular graphs
58
+ - Efficient and easy-to-use input pipelines using TF **records**
59
59
 
60
60
  ## Examples
61
61
 
@@ -0,0 +1,20 @@
1
+ molcraft/__init__.py,sha256=FQyasgy1kEz2v9sKdr3am6ap7Cm1oHEuCKhHwH-CQpM,435
2
+ molcraft/callbacks.py,sha256=mkz4ALjJFPy8nHd2nCAuMbKceKnq4tIpZhUuUOvie2Y,1209
3
+ molcraft/chem.py,sha256=_UO5O-I7KUtGf3vRrFEYoAUGlW5xi2x8ylu5f-Ybumo,18696
4
+ molcraft/conformers.py,sha256=p09gOQOdxLSj3yohZOMkxxLriHsZ1ZqOoiWLi73OpIg,4325
5
+ molcraft/datasets.py,sha256=rFgXTC1ZheLhfgQgcCspP_wEE54a33PIneH7OplbS-8,4047
6
+ molcraft/descriptors.py,sha256=gKqlJ3BqJLTeR2ft8isftSEaJDC8cv64eTq5IYhy4XM,3032
7
+ molcraft/features.py,sha256=69oV_GHNdBKPA4sp6Tpo6brvNmaauk_IVIzNjX7VDmg,13648
8
+ molcraft/featurizers.py,sha256=kV5RN_Z2pELjDcwE65KYy_JagbDUueXoClpsIOFsI9I,27073
9
+ molcraft/layers.py,sha256=y-sBLXWttr-fkGZ-acL1srMB8QqeXnHotYK9KCcyJNU,70581
10
+ molcraft/models.py,sha256=0MN4PAlsacni7RfIcYm_imxuzBVL2K8w3MnaUM24DeI,18021
11
+ molcraft/ops.py,sha256=uSnBYQwxYJ1ATdDpr290bxiyQZkrSCVxlB7btlh_n2I,4112
12
+ molcraft/records.py,sha256=w4-bcWZEC0oVInrE1e0kQBroIaSCA0PN1JBPOtO6VUY,5251
13
+ molcraft/tensors.py,sha256=8hwlad000wQ5pNLSdzd3rCXVbaUHBxUq2MbBx27dKzU,22391
14
+ molcraft/experimental/__init__.py,sha256=x5h6LOO8bo3NPjkKKM9M1H-Kz6R3yxYhRSePoxHCdRE,42
15
+ molcraft/experimental/peptides.py,sha256=82Bzw9FEnlymOUgTIIKha-ELNbqEFkv9T4hspDGRetw,9266
16
+ molcraft-0.1.0a4.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
17
+ molcraft-0.1.0a4.dist-info/METADATA,sha256=bhsytRfa6BIbfmph0Cm2NfubmZJPumsMQt4lbch33kQ,4201
18
+ molcraft-0.1.0a4.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
19
+ molcraft-0.1.0a4.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
20
+ molcraft-0.1.0a4.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.1.0)
2
+ Generator: setuptools (80.0.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,20 +0,0 @@
1
- molcraft/__init__.py,sha256=2ZNfWBjGl8DscOwjdDiRkgIsuPnKit29Q3MhZyP336Q,435
2
- molcraft/callbacks.py,sha256=6gwCwdsHGb-fVB4m1QGmtBwQwZ9mFq9QUkmPKSMn05U,849
3
- molcraft/chem.py,sha256=_UO5O-I7KUtGf3vRrFEYoAUGlW5xi2x8ylu5f-Ybumo,18696
4
- molcraft/conformers.py,sha256=p09gOQOdxLSj3yohZOMkxxLriHsZ1ZqOoiWLi73OpIg,4325
5
- molcraft/datasets.py,sha256=rFgXTC1ZheLhfgQgcCspP_wEE54a33PIneH7OplbS-8,4047
6
- molcraft/descriptors.py,sha256=x6RfZ-gK7D_WSvmK6sh6yHyEjQqovPnRU0xwC3dAKfg,2880
7
- molcraft/features.py,sha256=69oV_GHNdBKPA4sp6Tpo6brvNmaauk_IVIzNjX7VDmg,13648
8
- molcraft/featurizers.py,sha256=Yu8I6I_zkzB__WYSiqz-FDGjvKFOmyWFxojRBr39Aw8,26236
9
- molcraft/layers.py,sha256=HjnAtqhuP0uZ5yP4L33k3xT4IUdLavWBrjd3wO9_Rmw,64915
10
- molcraft/models.py,sha256=DXqWR_XnMVXQseVR91XnDLXvmHa1hv-6_Y_wvpQZBFI,17476
11
- molcraft/ops.py,sha256=iiE6zgA2P7cmjKO1RHmL9GE_Tv7Tyuo_xDoxB_ELZQM,3824
12
- molcraft/records.py,sha256=w4-bcWZEC0oVInrE1e0kQBroIaSCA0PN1JBPOtO6VUY,5251
13
- molcraft/tensors.py,sha256=b7PO-YOvV72s9g057ILJACKS2n2fn10VkO35gHXpssI,22312
14
- molcraft/experimental/__init__.py,sha256=x5h6LOO8bo3NPjkKKM9M1H-Kz6R3yxYhRSePoxHCdRE,42
15
- molcraft/experimental/peptides.py,sha256=RCuOTSwoYHGSdeYi6TWHdPIv2WC3avCZjKLdhEZQeXw,8997
16
- molcraft-0.1.0a3.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
17
- molcraft-0.1.0a3.dist-info/METADATA,sha256=f_5sBinpFcGSqKLaSqGkZJ83gGQtZw1Pb3fkgq9aCBM,4088
18
- molcraft-0.1.0a3.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
19
- molcraft-0.1.0a3.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
20
- molcraft-0.1.0a3.dist-info/RECORD,,