molcraft 0.1.0a6__py3-none-any.whl → 0.1.0a8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic. Click here for more details.

molcraft/models.py CHANGED
@@ -415,6 +415,10 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
415
415
  def predict_step(self, tensor: tensors.GraphTensor) -> np.ndarray:
416
416
  return self(tensor, training=False)
417
417
 
418
+ def compute_loss(self, x, y, y_pred, sample_weight=None):
419
+ y, y_pred, sample_weight = _maybe_reshape(y, y_pred, sample_weight)
420
+ return super().compute_loss(x, y, y_pred, sample_weight)
421
+
418
422
  def compute_metrics(self, x, y, y_pred, sample_weight=None) -> dict[str, float]:
419
423
  loss = self.compute_loss(x, y, y_pred, sample_weight)
420
424
  metric_results = {}
@@ -423,7 +427,7 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
423
427
  metric.update_state(loss)
424
428
  metric_results[metric.name] = metric.result()
425
429
  else:
426
- metric.update_state(y, y_pred)
430
+ metric.update_state(y, y_pred, sample_weight=sample_weight)
427
431
  metric_results.update(metric.result())
428
432
  return metric_results
429
433
 
@@ -593,3 +597,14 @@ def _make_dataset(x: tensors.GraphTensor, batch_size: int):
593
597
  .batch(batch_size)
594
598
  .prefetch(-1)
595
599
  )
600
+
601
+ def _maybe_reshape(y, y_pred, sample_weight):
602
+ if (
603
+ sample_weight is not None and
604
+ len(keras.ops.shape(sample_weight)) == 2 and
605
+ sample_weight.shape == y_pred.shape
606
+ ):
607
+ y = keras.ops.reshape(y, [-1])
608
+ y_pred = keras.ops.reshape(y_pred, [-1])
609
+ sample_weight = keras.ops.reshape(sample_weight, [-1])
610
+ return y, y_pred, sample_weight
molcraft/ops.py CHANGED
@@ -67,7 +67,7 @@ def edge_softmax(
67
67
  edge_target: tf.Tensor
68
68
  ) -> tf.Tensor:
69
69
  num_segments = keras.ops.cond(
70
- keras.ops.shape(edge_target)[0] > 0,
70
+ keras.ops.greater(keras.ops.shape(edge_target)[0], 0),
71
71
  lambda: keras.ops.maximum(keras.ops.max(edge_target) + 1, 1),
72
72
  lambda: 0
73
73
  )
@@ -100,7 +100,7 @@ def segment_mean(
100
100
  ) -> tf.Tensor:
101
101
  if num_segments is None:
102
102
  num_segments = keras.ops.cond(
103
- keras.ops.shape(segment_ids)[0] > 0,
103
+ keras.ops.greater(keras.ops.shape(segment_ids)[0], 0),
104
104
  lambda: keras.ops.max(segment_ids) + 1,
105
105
  lambda: 0
106
106
  )
@@ -147,4 +147,15 @@ def euclidean_distance(
147
147
  axis=axis,
148
148
  keepdims=True
149
149
  )
150
- )
150
+ )
151
+
152
+ def displacement(
153
+ x1: tf.Tensor,
154
+ x2: tf.Tensor,
155
+ normalize: bool = False,
156
+ axis=-1,
157
+ ) -> tf.Tensor:
158
+ displacement = keras.ops.subtract(x1, x2)
159
+ if not normalize:
160
+ return displacement
161
+ return displacement / euclidean_distance(x1, x2, axis=axis)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: molcraft
3
- Version: 0.1.0a6
3
+ Version: 0.1.0a8
4
4
  Summary: Graph Neural Networks for Molecular Machine Learning
5
5
  Author-email: Alexander Kensert <alexander.kensert@gmail.com>
6
6
  License: MIT License
@@ -78,7 +78,7 @@ featurizer = featurizers.MolGraphFeaturizer(
78
78
  features.IsRotatable(),
79
79
  ],
80
80
  super_atom=True,
81
- self_loops=False,
81
+ self_loops=True,
82
82
  )
83
83
 
84
84
  graph = featurizer([('N[C@@H](C)C(=O)O', 2.0), ('N[C@@H](CS)C(=O)O', 1.0)])
@@ -0,0 +1,19 @@
1
+ molcraft/__init__.py,sha256=s8dUh6Fjq34j2aNgF13Y2NUkDwBWmsOAuIJVgY3gwCE,463
2
+ molcraft/callbacks.py,sha256=x5HnkZhqcFRrW6xdApt_jZ4X08A-0fxcnFKfdmRKa0c,3571
3
+ molcraft/chem.py,sha256=zHH7iX0ZJ7QmP-YqR_IXCpylTwCXHXptWf1DsblnZR4,21496
4
+ molcraft/conformers.py,sha256=K6ZtiSUNDN_fwqGP9JrPcwALLFFvlMlF_XejEJH3Sr4,4205
5
+ molcraft/datasets.py,sha256=rFgXTC1ZheLhfgQgcCspP_wEE54a33PIneH7OplbS-8,4047
6
+ molcraft/descriptors.py,sha256=gKqlJ3BqJLTeR2ft8isftSEaJDC8cv64eTq5IYhy4XM,3032
7
+ molcraft/features.py,sha256=aBYxDfQqQsVuyjKaPUlwEgvCjbNZ-FJhuKo2Cg5ajrA,13554
8
+ molcraft/featurizers.py,sha256=qNmXSOAeplICN3j-nzvWACVuKoJ_ZBzhYP9LterKVH8,27042
9
+ molcraft/layers.py,sha256=KKaH58zuov5aARj72BS_xK3ZQEwSFJrIPkoXQAAcqz8,62285
10
+ molcraft/losses.py,sha256=JEKZEX2f8vDgky_fUocsF8vZjy9VMzRjZUBa20Uf9Qw,1065
11
+ molcraft/models.py,sha256=FLXpO3OUmRxLmxG3MjBK4ZwcVFlea1gqEgs1ibKly2w,23263
12
+ molcraft/ops.py,sha256=dLIUq-KG8nOzEcphJqNbF_f82VZRDNrB1UKrcPt5JNM,4752
13
+ molcraft/records.py,sha256=0sjOdcr266ZER4F-aTBQ3AVPNAwflKWNiNJVsSc1-PQ,5370
14
+ molcraft/tensors.py,sha256=EOUKx496KUZsjA1zA2ABc7tU_TW3Jv7AXDsug_QsLbA,22407
15
+ molcraft-0.1.0a8.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
16
+ molcraft-0.1.0a8.dist-info/METADATA,sha256=CtHK0DVlQECWUdlhg0KzvvpPyUD150BSyfzkdNF3fT8,4062
17
+ molcraft-0.1.0a8.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
18
+ molcraft-0.1.0a8.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
19
+ molcraft-0.1.0a8.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.3.1)
2
+ Generator: setuptools (80.7.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,19 +0,0 @@
1
- molcraft/__init__.py,sha256=eKMk4e5Wki4Ay7_BUuY7B-j3Po2l5FDDefPjkFQw3OM,463
2
- molcraft/callbacks.py,sha256=mkz4ALjJFPy8nHd2nCAuMbKceKnq4tIpZhUuUOvie2Y,1209
3
- molcraft/chem.py,sha256=apaECcQSuAMs3Tm12yc6ne4x0BGx5JzfoRhTC1WMhlI,20695
4
- molcraft/conformers.py,sha256=rojo8OaZrKAesx0JA5kf-JVNEpmsQyLSpcxbWhV9cd4,4324
5
- molcraft/datasets.py,sha256=rFgXTC1ZheLhfgQgcCspP_wEE54a33PIneH7OplbS-8,4047
6
- molcraft/descriptors.py,sha256=gKqlJ3BqJLTeR2ft8isftSEaJDC8cv64eTq5IYhy4XM,3032
7
- molcraft/features.py,sha256=69oV_GHNdBKPA4sp6Tpo6brvNmaauk_IVIzNjX7VDmg,13648
8
- molcraft/featurizers.py,sha256=aJJibnHCxvSu3bNbE2xQk34QvFb47Mnm__0MxlRLA0w,27323
9
- molcraft/layers.py,sha256=RyKmdHmHlYJJL15LvHH32daTKsChJ_pHmHUnpUcwS1U,73437
10
- molcraft/losses.py,sha256=JEKZEX2f8vDgky_fUocsF8vZjy9VMzRjZUBa20Uf9Qw,1065
11
- molcraft/models.py,sha256=Rl9CkQlOVkj20TLjGlwI8vaQwX07EqqWz22bFYtJlpk,22636
12
- molcraft/ops.py,sha256=eAi79aawJwxuIVVamjA1kPRHGlUm0PsvN-7d2CYu15I,4441
13
- molcraft/records.py,sha256=0sjOdcr266ZER4F-aTBQ3AVPNAwflKWNiNJVsSc1-PQ,5370
14
- molcraft/tensors.py,sha256=EOUKx496KUZsjA1zA2ABc7tU_TW3Jv7AXDsug_QsLbA,22407
15
- molcraft-0.1.0a6.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
16
- molcraft-0.1.0a6.dist-info/METADATA,sha256=Zzl1K3WleDp056zbLChy5B1AQ3U26t22oMkIKRUpbMY,4063
17
- molcraft-0.1.0a6.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
18
- molcraft-0.1.0a6.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
19
- molcraft-0.1.0a6.dist-info/RECORD,,