molcraft 0.1.0a6__py3-none-any.whl → 0.1.0a8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- molcraft/__init__.py +1 -1
- molcraft/callbacks.py +67 -0
- molcraft/chem.py +45 -30
- molcraft/conformers.py +0 -4
- molcraft/features.py +3 -9
- molcraft/featurizers.py +18 -26
- molcraft/layers.py +466 -801
- molcraft/models.py +16 -1
- molcraft/ops.py +14 -3
- {molcraft-0.1.0a6.dist-info → molcraft-0.1.0a8.dist-info}/METADATA +2 -2
- molcraft-0.1.0a8.dist-info/RECORD +19 -0
- {molcraft-0.1.0a6.dist-info → molcraft-0.1.0a8.dist-info}/WHEEL +1 -1
- molcraft-0.1.0a6.dist-info/RECORD +0 -19
- {molcraft-0.1.0a6.dist-info → molcraft-0.1.0a8.dist-info}/licenses/LICENSE +0 -0
- {molcraft-0.1.0a6.dist-info → molcraft-0.1.0a8.dist-info}/top_level.txt +0 -0
molcraft/models.py
CHANGED
|
@@ -415,6 +415,10 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
|
|
|
415
415
|
def predict_step(self, tensor: tensors.GraphTensor) -> np.ndarray:
    """Run the model forward on a single batch for inference.

    The model is invoked with ``training=False`` so that layers whose
    behavior depends on the training flag run in inference mode.
    """
    predictions = self(tensor, training=False)
    return predictions
|
|
417
417
|
|
|
418
|
+
def compute_loss(self, x, y, y_pred, sample_weight=None):
    """Compute the loss after optionally flattening element-wise weights.

    ``y``, ``y_pred`` and ``sample_weight`` are first passed through
    ``_maybe_reshape``. The resulting triple is forwarded, together with
    ``x``, to the parent class's ``compute_loss``.
    """
    reshaped = _maybe_reshape(y, y_pred, sample_weight)
    return super().compute_loss(x, *reshaped)
|
|
421
|
+
|
|
418
422
|
def compute_metrics(self, x, y, y_pred, sample_weight=None) -> dict[str, float]:
|
|
419
423
|
loss = self.compute_loss(x, y, y_pred, sample_weight)
|
|
420
424
|
metric_results = {}
|
|
@@ -423,7 +427,7 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
|
|
|
423
427
|
metric.update_state(loss)
|
|
424
428
|
metric_results[metric.name] = metric.result()
|
|
425
429
|
else:
|
|
426
|
-
metric.update_state(y, y_pred)
|
|
430
|
+
metric.update_state(y, y_pred, sample_weight=sample_weight)
|
|
427
431
|
metric_results.update(metric.result())
|
|
428
432
|
return metric_results
|
|
429
433
|
|
|
@@ -593,3 +597,14 @@ def _make_dataset(x: tensors.GraphTensor, batch_size: int):
|
|
|
593
597
|
.batch(batch_size)
|
|
594
598
|
.prefetch(-1)
|
|
595
599
|
)
|
|
600
|
+
|
|
601
|
+
def _maybe_reshape(y, y_pred, sample_weight):
    """Flatten ``y``, ``y_pred`` and ``sample_weight`` for element-wise weighting.

    The reshape only happens when ``sample_weight`` is rank 2 and has exactly
    the same shape as ``y_pred``, i.e. there is one weight per predicted
    value. In every other case the three inputs are returned unchanged.

    ``y`` and ``y_pred`` are reshaped to ``(-1, 1)`` and ``sample_weight`` to
    ``(-1,)``. The trailing axis of size 1 is deliberate. Keras losses such as
    mean squared error average over the last axis, so flattening the targets
    to ``(-1,)`` would reduce the loss to a single scalar. The weights could
    then no longer weight individual elements. With ``(-1, 1)`` the loss
    yields one value per element, which lines up with the ``(-1,)`` weights.

    Args:
        y: Targets.
        y_pred: Model predictions.
        sample_weight: Optional weights, or ``None``.

    Returns:
        The tuple ``(y, y_pred, sample_weight)``, reshaped as described above
        when the element-wise condition holds, otherwise unchanged.
    """
    is_elementwise = (
        sample_weight is not None
        and len(keras.ops.shape(sample_weight)) == 2
        and sample_weight.shape == y_pred.shape
    )
    if not is_elementwise:
        return y, y_pred, sample_weight
    # Keep a trailing unit axis so last-axis loss reductions stay per-element.
    y = keras.ops.reshape(y, [-1, 1])
    y_pred = keras.ops.reshape(y_pred, [-1, 1])
    sample_weight = keras.ops.reshape(sample_weight, [-1])
    return y, y_pred, sample_weight
|
molcraft/ops.py
CHANGED
|
@@ -67,7 +67,7 @@ def edge_softmax(
|
|
|
67
67
|
edge_target: tf.Tensor
|
|
68
68
|
) -> tf.Tensor:
|
|
69
69
|
num_segments = keras.ops.cond(
|
|
70
|
-
keras.ops.shape(edge_target)[0]
|
|
70
|
+
keras.ops.greater(keras.ops.shape(edge_target)[0], 0),
|
|
71
71
|
lambda: keras.ops.maximum(keras.ops.max(edge_target) + 1, 1),
|
|
72
72
|
lambda: 0
|
|
73
73
|
)
|
|
@@ -100,7 +100,7 @@ def segment_mean(
|
|
|
100
100
|
) -> tf.Tensor:
|
|
101
101
|
if num_segments is None:
|
|
102
102
|
num_segments = keras.ops.cond(
|
|
103
|
-
keras.ops.shape(segment_ids)[0]
|
|
103
|
+
keras.ops.greater(keras.ops.shape(segment_ids)[0], 0),
|
|
104
104
|
lambda: keras.ops.max(segment_ids) + 1,
|
|
105
105
|
lambda: 0
|
|
106
106
|
)
|
|
@@ -147,4 +147,15 @@ def euclidean_distance(
|
|
|
147
147
|
axis=axis,
|
|
148
148
|
keepdims=True
|
|
149
149
|
)
|
|
150
|
-
)
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
def displacement(
    x1: tf.Tensor,
    x2: tf.Tensor,
    normalize: bool = False,
    axis: int = -1,
) -> tf.Tensor:
    """Compute the displacement ``x1 - x2``, optionally as a unit vector.

    Args:
        x1: First tensor; must be broadcast-compatible with ``x2``.
        x2: Second tensor.
        normalize: If True, divide the displacement by its Euclidean (L2)
            norm along ``axis``. Zero displacements (``x1`` equal to ``x2``
            along ``axis``) are returned as zeros rather than NaN.
        axis: Axis along which the norm is taken when ``normalize`` is True.

    Returns:
        ``x1 - x2``, normalized along ``axis`` when ``normalize`` is True.
    """
    diff = keras.ops.subtract(x1, x2)
    if not normalize:
        return diff
    squared_norm = keras.ops.sum(
        keras.ops.square(diff), axis=axis, keepdims=True
    )
    nonzero = keras.ops.greater(squared_norm, 0)
    # Substitute 1 for zero norms *before* the sqrt. Otherwise the forward
    # pass computes 0 / 0, and the backward pass through sqrt at 0 yields
    # NaN gradients even when the result is masked afterwards.
    safe_norm = keras.ops.sqrt(
        keras.ops.where(nonzero, squared_norm, keras.ops.ones_like(squared_norm))
    )
    return keras.ops.where(nonzero, diff / safe_norm, keras.ops.zeros_like(diff))
|
|
{molcraft-0.1.0a6.dist-info → molcraft-0.1.0a8.dist-info}/METADATA
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: molcraft
|
|
3
|
-
Version: 0.1.0a6
|
|
3
|
+
Version: 0.1.0a8
|
|
4
4
|
Summary: Graph Neural Networks for Molecular Machine Learning
|
|
5
5
|
Author-email: Alexander Kensert <alexander.kensert@gmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -78,7 +78,7 @@ featurizer = featurizers.MolGraphFeaturizer(
|
|
|
78
78
|
features.IsRotatable(),
|
|
79
79
|
],
|
|
80
80
|
super_atom=True,
|
|
81
|
-
self_loops=
|
|
81
|
+
self_loops=True,
|
|
82
82
|
)
|
|
83
83
|
|
|
84
84
|
graph = featurizer([('N[C@@H](C)C(=O)O', 2.0), ('N[C@@H](CS)C(=O)O', 1.0)])
|
|
molcraft-0.1.0a8.dist-info/RECORD
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
molcraft/__init__.py,sha256=s8dUh6Fjq34j2aNgF13Y2NUkDwBWmsOAuIJVgY3gwCE,463
|
|
2
|
+
molcraft/callbacks.py,sha256=x5HnkZhqcFRrW6xdApt_jZ4X08A-0fxcnFKfdmRKa0c,3571
|
|
3
|
+
molcraft/chem.py,sha256=zHH7iX0ZJ7QmP-YqR_IXCpylTwCXHXptWf1DsblnZR4,21496
|
|
4
|
+
molcraft/conformers.py,sha256=K6ZtiSUNDN_fwqGP9JrPcwALLFFvlMlF_XejEJH3Sr4,4205
|
|
5
|
+
molcraft/datasets.py,sha256=rFgXTC1ZheLhfgQgcCspP_wEE54a33PIneH7OplbS-8,4047
|
|
6
|
+
molcraft/descriptors.py,sha256=gKqlJ3BqJLTeR2ft8isftSEaJDC8cv64eTq5IYhy4XM,3032
|
|
7
|
+
molcraft/features.py,sha256=aBYxDfQqQsVuyjKaPUlwEgvCjbNZ-FJhuKo2Cg5ajrA,13554
|
|
8
|
+
molcraft/featurizers.py,sha256=qNmXSOAeplICN3j-nzvWACVuKoJ_ZBzhYP9LterKVH8,27042
|
|
9
|
+
molcraft/layers.py,sha256=KKaH58zuov5aARj72BS_xK3ZQEwSFJrIPkoXQAAcqz8,62285
|
|
10
|
+
molcraft/losses.py,sha256=JEKZEX2f8vDgky_fUocsF8vZjy9VMzRjZUBa20Uf9Qw,1065
|
|
11
|
+
molcraft/models.py,sha256=FLXpO3OUmRxLmxG3MjBK4ZwcVFlea1gqEgs1ibKly2w,23263
|
|
12
|
+
molcraft/ops.py,sha256=dLIUq-KG8nOzEcphJqNbF_f82VZRDNrB1UKrcPt5JNM,4752
|
|
13
|
+
molcraft/records.py,sha256=0sjOdcr266ZER4F-aTBQ3AVPNAwflKWNiNJVsSc1-PQ,5370
|
|
14
|
+
molcraft/tensors.py,sha256=EOUKx496KUZsjA1zA2ABc7tU_TW3Jv7AXDsug_QsLbA,22407
|
|
15
|
+
molcraft-0.1.0a8.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
|
|
16
|
+
molcraft-0.1.0a8.dist-info/METADATA,sha256=CtHK0DVlQECWUdlhg0KzvvpPyUD150BSyfzkdNF3fT8,4062
|
|
17
|
+
molcraft-0.1.0a8.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
|
18
|
+
molcraft-0.1.0a8.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
|
|
19
|
+
molcraft-0.1.0a8.dist-info/RECORD,,
|
|
molcraft-0.1.0a6.dist-info/RECORD
REMOVED
|
@@ -1,19 +0,0 @@
|
|
|
1
|
-
molcraft/__init__.py,sha256=eKMk4e5Wki4Ay7_BUuY7B-j3Po2l5FDDefPjkFQw3OM,463
|
|
2
|
-
molcraft/callbacks.py,sha256=mkz4ALjJFPy8nHd2nCAuMbKceKnq4tIpZhUuUOvie2Y,1209
|
|
3
|
-
molcraft/chem.py,sha256=apaECcQSuAMs3Tm12yc6ne4x0BGx5JzfoRhTC1WMhlI,20695
|
|
4
|
-
molcraft/conformers.py,sha256=rojo8OaZrKAesx0JA5kf-JVNEpmsQyLSpcxbWhV9cd4,4324
|
|
5
|
-
molcraft/datasets.py,sha256=rFgXTC1ZheLhfgQgcCspP_wEE54a33PIneH7OplbS-8,4047
|
|
6
|
-
molcraft/descriptors.py,sha256=gKqlJ3BqJLTeR2ft8isftSEaJDC8cv64eTq5IYhy4XM,3032
|
|
7
|
-
molcraft/features.py,sha256=69oV_GHNdBKPA4sp6Tpo6brvNmaauk_IVIzNjX7VDmg,13648
|
|
8
|
-
molcraft/featurizers.py,sha256=aJJibnHCxvSu3bNbE2xQk34QvFb47Mnm__0MxlRLA0w,27323
|
|
9
|
-
molcraft/layers.py,sha256=RyKmdHmHlYJJL15LvHH32daTKsChJ_pHmHUnpUcwS1U,73437
|
|
10
|
-
molcraft/losses.py,sha256=JEKZEX2f8vDgky_fUocsF8vZjy9VMzRjZUBa20Uf9Qw,1065
|
|
11
|
-
molcraft/models.py,sha256=Rl9CkQlOVkj20TLjGlwI8vaQwX07EqqWz22bFYtJlpk,22636
|
|
12
|
-
molcraft/ops.py,sha256=eAi79aawJwxuIVVamjA1kPRHGlUm0PsvN-7d2CYu15I,4441
|
|
13
|
-
molcraft/records.py,sha256=0sjOdcr266ZER4F-aTBQ3AVPNAwflKWNiNJVsSc1-PQ,5370
|
|
14
|
-
molcraft/tensors.py,sha256=EOUKx496KUZsjA1zA2ABc7tU_TW3Jv7AXDsug_QsLbA,22407
|
|
15
|
-
molcraft-0.1.0a6.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
|
|
16
|
-
molcraft-0.1.0a6.dist-info/METADATA,sha256=Zzl1K3WleDp056zbLChy5B1AQ3U26t22oMkIKRUpbMY,4063
|
|
17
|
-
molcraft-0.1.0a6.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
|
|
18
|
-
molcraft-0.1.0a6.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
|
|
19
|
-
molcraft-0.1.0a6.dist-info/RECORD,,
|
|
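{molcraft-0.1.0a6.dist-info → molcraft-0.1.0a8.dist-info}/licenses/LICENSE
File without changes
|
|
{molcraft-0.1.0a6.dist-info → molcraft-0.1.0a8.dist-info}/top_level.txt
File without changes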
|