molcraft 0.1.0rc10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
molcraft/ops.py ADDED
@@ -0,0 +1,195 @@
1
+ import warnings
2
+ import keras
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ from keras import backend
6
+
7
+
8
@keras.saving.register_keras_serializable(package='molcraft')
def gather(
    node_feature: tf.Tensor,
    edge: tf.Tensor
) -> tf.Tensor:
    """Look up entries of `node_feature` at the positions given by `edge`.

    With the TensorFlow backend this is a plain `tf.gather`. For any other
    backend, `edge` receives trailing singleton axes until its rank equals
    the rank of `node_feature`, and the lookup is performed with
    `keras.ops.take_along_axis` along axis 0.

    Args:
        node_feature: Tensor to read from; indexed along its first axis.
        edge: Indices into the first axis of `node_feature`.

    Returns:
        The gathered tensor.
    """
    if backend.backend() != 'tensorflow':
        rank_gap = (
            len(keras.ops.shape(node_feature)) - len(keras.ops.shape(edge))
        )
        indices = edge
        # Pad the index tensor on the right so it lines up with the data rank.
        for _ in range(rank_gap):
            indices = keras.ops.expand_dims(indices, axis=-1)
        return keras.ops.take_along_axis(node_feature, indices, axis=0)
    return tf.gather(node_feature, edge)
20
+
21
@keras.saving.register_keras_serializable(package='molcraft')
def aggregate(
    node_feature: tf.Tensor,
    edge: tf.Tensor,
    num_nodes: tf.Tensor,
    mode: str = 'sum',
) -> tf.Tensor:
    """Reduce `node_feature` into `num_nodes` segments identified by `edge`.

    Args:
        node_feature: Values to reduce.
        edge: Segment id for each entry of `node_feature`; ids are not
            required to be sorted.
        num_nodes: Number of output segments.
        mode: `'mean'` averages each segment; every other value results in
            a segment sum.

    Returns:
        The reduced tensor.
    """
    # Only 'mean' is special-cased; anything else falls back to summation.
    reduce_fn = segment_mean if mode == 'mean' else keras.ops.segment_sum
    return reduce_fn(node_feature, edge, num_nodes, sorted=False)
35
+
36
@keras.saving.register_keras_serializable(package='molcraft')
def propagate(
    node_feature: tf.Tensor,
    edge_source: tf.Tensor,
    edge_target: tf.Tensor,
    edge_feature: tf.Tensor | None = None,
    edge_weight: tf.Tensor | None = None,
) -> tf.Tensor:
    """Send node features along edges and sum them at the target nodes.

    The features of each edge's source node are gathered, optionally scaled
    by `edge_weight`, optionally offset by `edge_feature`, and then summed
    per target node.

    Args:
        node_feature: Node features; the size of the first axis is taken as
            the number of nodes.
        edge_source: Index of the source node of each edge.
        edge_target: Index of the target node of each edge.
        edge_feature: Optional tensor added to the gathered source features.
        edge_weight: Optional tensor multiplied with the gathered source
            features (applied before `edge_feature` is added).

    Returns:
        A tensor with one aggregated message per node.
    """
    num_nodes = keras.ops.shape(node_feature)[0]

    # Out-of-place arithmetic keeps the gathered tensor itself untouched.
    message = gather(node_feature, edge_source)

    if edge_weight is not None:
        message = message * edge_weight

    if edge_feature is not None:
        message = message + edge_feature

    # Aggregate the per-edge messages (not the raw node features): the
    # segment ids in `edge_target` are per edge, and the weighting/offset
    # above would otherwise be thrown away.
    return aggregate(message, edge_target, num_nodes)
55
+
56
@keras.saving.register_keras_serializable(package='molcraft')
def scatter_update(
    inputs: tf.Tensor,
    indices: tf.Tensor,
    updates: tf.Tensor,
) -> tf.Tensor:
    """Write `updates` into `inputs` at `indices`.

    Args:
        inputs: Tensor to update.
        indices: Either a boolean mask, which is converted to the positions
            of its true entries, or an index tensor. Trailing singleton axes
            are appended until its rank matches the rank of `inputs`.
        updates: Values to write.

    Returns:
        The result of `keras.ops.scatter_update`.
    """
    if indices.dtype == tf.bool:
        # Turn the mask into explicit coordinates of its true entries.
        indices = keras.ops.stack(keras.ops.where(indices), axis=-1)
    missing_axes = (
        len(keras.ops.shape(inputs)) - len(keras.ops.shape(indices))
    )
    for _ in range(missing_axes):
        indices = keras.ops.expand_dims(indices, axis=-1)
    return keras.ops.scatter_update(inputs, indices, updates)
69
+
70
@keras.saving.register_keras_serializable(package='molcraft')
def scatter_add(
    inputs: tf.Tensor,
    indices: tf.Tensor,
    updates: tf.Tensor,
) -> tf.Tensor:
    """Add `updates` to `inputs` at `indices`.

    Args:
        inputs: Tensor to add to.
        indices: Either a boolean mask, which is converted to the positions
            of its true entries, or an index tensor. Trailing singleton axes
            are appended until its rank matches the rank of `inputs`.
        updates: Values to add.

    Returns:
        `inputs` with `updates` added at the indexed positions.
    """
    if indices.dtype == tf.bool:
        # Turn the mask into explicit coordinates of its true entries.
        indices = keras.ops.stack(keras.ops.where(indices), axis=-1)
    missing_axes = (
        len(keras.ops.shape(inputs)) - len(keras.ops.shape(indices))
    )
    for _ in range(missing_axes):
        indices = keras.ops.expand_dims(indices, axis=-1)
    if backend.backend() != 'tensorflow':
        # Fallback: place the updates into a zero tensor, then add it.
        # NOTE(review): this goes through scatter_update, so repeated
        # indices may not accumulate the way tf.tensor_scatter_nd_add
        # does -- confirm before relying on it with duplicate indices.
        placed = scatter_update(keras.ops.zeros_like(inputs), indices, updates)
        return inputs + placed
    return tf.tensor_scatter_nd_add(inputs, indices, updates)
86
+
87
@keras.saving.register_keras_serializable(package='molcraft')
def edge_softmax(
    score: tf.Tensor,
    edge_target: tf.Tensor
) -> tf.Tensor:
    """Softmax over `score`, normalized separately per target node.

    Entries of `score` that share the same value in `edge_target` form one
    softmax group.

    Args:
        score: Unnormalized scores, one entry per edge.
        edge_target: Segment (target node) id of each edge.

    Returns:
        The normalized scores, same shape as `score`.
    """
    has_edges = keras.ops.greater(keras.ops.shape(edge_target)[0], 0)
    # Zero segments when there are no edges, otherwise max id + 1.
    num_segments = keras.ops.cond(
        has_edges,
        lambda: keras.ops.maximum(keras.ops.max(edge_target) + 1, 1),
        lambda: 0
    )
    # Subtract the per-group maximum before exponentiating.
    group_max = keras.ops.segment_max(
        score, edge_target, num_segments, sorted=False
    )
    exp_score = keras.ops.exp(score - gather(group_max, edge_target))
    group_total = keras.ops.segment_sum(
        exp_score, edge_target, num_segments, sorted=False
    )
    return exp_score / gather(group_total, edge_target)
107
+
108
@keras.saving.register_keras_serializable(package='molcraft')
def edge_weight(
    edge: tf.Tensor,
    edge_weight: tf.Tensor,
) -> tf.Tensor:
    """Multiply `edge` by `edge_weight`, aligning ranks first.

    Args:
        edge: Tensor to scale.
        edge_weight: Weights; trailing singleton axes are appended until
            the rank matches the rank of `edge`.

    Returns:
        The product `edge * edge_weight`.
    """
    weight = edge_weight
    missing_axes = len(keras.ops.shape(edge)) - len(keras.ops.shape(weight))
    for _ in range(missing_axes):
        weight = keras.ops.expand_dims(weight, axis=-1)
    return edge * weight
118
+
119
@keras.saving.register_keras_serializable(package='molcraft')
def segment_mean(
    data: tf.Tensor,
    segment_ids: tf.Tensor,
    num_segments: int | None = None,
    sorted: bool = False,
) -> tf.Tensor:
    """Average the entries of `data` that share a segment id.

    Args:
        data: Values to average; segments run along the first axis.
        segment_ids: Segment id of each entry along the first axis of `data`.
        num_segments: Number of output segments. When omitted it is derived
            as `max(segment_ids) + 1`, or 0 if `segment_ids` is empty.
        sorted: Whether `segment_ids` is sorted. It is forwarded to
            `keras.ops.segment_sum` on non-TensorFlow backends; the
            TensorFlow path does not need it.

    Returns:
        A tensor whose first axis has `num_segments` entries holding the
        per-segment means. Segments without any entry yield 0.
    """
    if num_segments is None:
        num_segments = keras.ops.cond(
            keras.ops.greater(keras.ops.shape(segment_ids)[0], 0),
            lambda: keras.ops.max(segment_ids) + 1,
            lambda: 0
        )
    if backend.backend() == 'tensorflow':
        # `tf.math.segment_mean` accepts no `num_segments` argument, so the
        # sorted variant cannot honour the requested output size. The
        # unsorted op is valid for sorted ids as well and is used for both.
        return tf.math.unsorted_segment_mean(
            data=data,
            segment_ids=segment_ids,
            num_segments=num_segments
        )
    totals = keras.ops.segment_sum(
        data=data,
        segment_ids=segment_ids,
        num_segments=num_segments,
        sorted=sorted
    )
    counts = keras.ops.cast(
        keras.ops.bincount(segment_ids, minlength=num_segments),
        dtype=totals.dtype
    )
    # Empty segments have a zero sum; clamping the count to 1 makes them
    # come out as 0 (as on the TensorFlow path) instead of NaN.
    counts = keras.ops.maximum(counts, 1)
    # Append singleton axes so the counts broadcast against data of any rank.
    for _ in range(len(keras.ops.shape(totals)) - 1):
        counts = keras.ops.expand_dims(counts, axis=-1)
    return totals / counts
153
+
154
@keras.saving.register_keras_serializable(package='molcraft')
def gaussian(
    x: tf.Tensor,
    mean: tf.Tensor,
    std: tf.Tensor
) -> tf.Tensor:
    """Evaluate the normal probability density at `x`.

    Args:
        x: Points at which to evaluate the density.
        mean: Mean of the distribution.
        std: Standard deviation of the distribution.

    Leading singleton axes are prepended to `mean` and `std` (the same
    number for both, determined from the rank of `mean`) until the rank of
    `mean` equals the rank of `x`.

    Returns:
        `exp(-0.5 * ((x - mean) / std) ** 2) / (sqrt(2 * pi) * std)`.
    """
    missing_axes = len(keras.ops.shape(x)) - len(keras.ops.shape(mean))
    for _ in range(missing_axes):
        mean = keras.ops.expand_dims(mean, axis=0)
        std = keras.ops.expand_dims(std, axis=0)
    standardized = (x - mean) / std
    normalizer = (2 * np.pi) ** 0.5 * std
    return keras.ops.exp(-0.5 * (standardized ** 2)) / normalizer
167
+
168
@keras.saving.register_keras_serializable(package='molcraft')
def euclidean_distance(
    x1: tf.Tensor,
    x2: tf.Tensor,
    axis: int = -1,
    keepdims: bool = True,
) -> tf.Tensor:
    """Compute the Euclidean (L2) distance between `x1` and `x2`.

    Args:
        x1: First tensor.
        x2: Second tensor.
        axis: Axis along which the squared differences are summed.
        keepdims: Whether the reduced axis is kept with size 1.

    Returns:
        The square root of the summed squared differences.
    """
    delta = keras.ops.subtract(x1, x2)
    squared_norm = keras.ops.sum(
        keras.ops.square(delta), axis=axis, keepdims=keepdims
    )
    return keras.ops.sqrt(squared_norm)
183
+
184
@keras.saving.register_keras_serializable(package='molcraft')
def displacement(
    x1: tf.Tensor,
    x2: tf.Tensor,
    normalize: bool = False,
    axis: int = -1,
    keepdims: bool = True,
) -> tf.Tensor:
    """Compute the displacement `x1 - x2`, optionally scaled to unit length.

    Args:
        x1: First tensor.
        x2: Second tensor.
        normalize: If true, the displacement is divided by its Euclidean
            norm taken along `axis`.
        axis: Axis along which the norm is computed when normalizing.
        keepdims: Retained for interface compatibility. The norm is always
            computed with the reduced axis kept, because dropping it breaks
            broadcasting against the displacement (shape error, or silently
            wrong values when the shapes happen to line up).

    Returns:
        The (optionally normalized) displacement, same shape as `x1 - x2`.
    """
    delta = keras.ops.subtract(x1, x2)
    if not normalize:
        return delta
    # NOTE: coincident points give a zero norm; as before, the division is
    # not guarded against that case.
    norm = euclidean_distance(x1, x2, axis=axis, keepdims=True)
    return delta / norm
molcraft/records.py ADDED
@@ -0,0 +1,187 @@
1
+ import warnings
2
+ import os
3
+ import math
4
+ import glob
5
+ import time
6
+ import typing
7
+ import tensorflow as tf
8
+ import numpy as np
9
+ import pandas as pd
10
+ import multiprocessing as mp
11
+
12
+ from molcraft import tensors
13
+
14
+ if typing.TYPE_CHECKING:
15
+ from molcraft import featurizers
16
+
17
+
18
def write(
    inputs: list[str | tuple],
    featurizer: 'featurizers.GraphFeaturizer',
    path: str,
    exist_ok: bool = False,
    overwrite: bool = False,
    num_files: typing.Optional[int] = None,
    num_processes: typing.Optional[int] = None,
    multiprocessing: bool = False,
    device: str = '/cpu:0'
) -> None:
    """Featurize `inputs` and write them as TFRecord shards under `path`.

    A `spec.pb` file describing the featurized example is stored next to the
    shards, which are named `tfrecord-<i>.tfrecord`.

    Args:
        inputs: Items handed one by one to `featurizer._call`. A
            `pd.DataFrame` is converted to its `(index, row)` pairs and a
            `pd.Series` to its `(index, value)` pairs.
        featurizer: Object whose `_call` converts one input into a tensor
            that exposes a `spec` attribute.
        path: Output directory; created if it does not exist.
        exist_ok: If false and `path` already exists, raise.
        overwrite: Only consulted when `path` exists and `exist_ok` is true.
            If false, return without writing anything; if true, existing
            record files and `spec.pb` are removed first.
        num_files: Number of shards. Defaults to one per 1000 inputs
            (rounded up, at least 1, at most `len(inputs)`).
        num_processes: Maximum number of concurrently running worker
            processes. Defaults to `mp.cpu_count()`.
        multiprocessing: If true, each shard is written in its own process.
        device: Device string passed to `tf.device`.

    Raises:
        FileExistsError: If `path` exists and `exist_ok` is false.
    """
    if os.path.isdir(path):
        if not exist_ok:
            raise FileExistsError(f'Records already exist: {path}')
        if not overwrite:
            return
        _remove_files(path)
    else:
        os.makedirs(path)

    with tf.device(device):

        if isinstance(inputs, pd.DataFrame):
            inputs = list(inputs.iterrows())
        elif isinstance(inputs, pd.Series):
            # Series has no `iterrows`; `items` yields the analogous
            # (index, value) pairs.
            inputs = list(inputs.items())

        # Featurize the first input once to obtain the spec stored on disk.
        example = featurizer._call(inputs[0])
        save_spec(os.path.join(path, 'spec.pb'), example.spec)

        if num_processes is None:
            num_processes = mp.cpu_count()
        # A non-positive limit would make the throttle loop spin forever.
        num_processes = max(1, num_processes)

        num_examples = len(inputs)

        if num_files is None:
            num_files = min(
                num_examples, max(1, math.ceil(num_examples / 1_000))
            )

        # Spread the examples as evenly as possible: the first `remainder`
        # shards receive one extra example.
        base_size, remainder = divmod(num_examples, num_files)
        chunk_sizes = (
            [base_size + 1] * remainder
            + [base_size] * (num_files - remainder)
        )

        input_chunks = []
        start_indices = []
        current_index = 0
        for size in chunk_sizes:
            input_chunks.append(inputs[current_index: current_index + size])
            start_indices.append(current_index)
            current_index += size

        assert current_index == num_examples

        paths = [
            os.path.join(path, f'tfrecord-{i:06d}.tfrecord')
            for i in range(num_files)
        ]
        jobs = list(zip(paths, input_chunks, start_indices))

        if not multiprocessing:
            for file_path, input_chunk, start_index in jobs:
                _write_tfrecord(input_chunk, file_path, featurizer, start_index)
            return

        processes = []

        for file_path, input_chunk, start_index in jobs:

            # Throttle: wait until a worker slot is free. The list is rebuilt
            # rather than mutated while being iterated.
            while len(processes) >= num_processes:
                alive = [p for p in processes if p.is_alive()]
                if len(alive) == len(processes):
                    time.sleep(0.1)
                processes = alive

            process = mp.Process(
                target=_write_tfrecord,
                args=(input_chunk, file_path, featurizer, start_index)
            )
            processes.append(process)
            process.start()

        for process in processes:
            process.join()
100
+
101
def read(
    path: str,
    shuffle_files: bool = False
) -> tf.data.Dataset:
    """Build a `tf.data.Dataset` from the TFRecord shards stored in `path`.

    Args:
        path: Directory containing `spec.pb` and `*.tfrecord` files.
        shuffle_files: If true, the (sorted) file list is shuffled with a
            buffer spanning all files before the records are read.

    Returns:
        A dataset of parsed examples. It is unbatched unless
        `tensors.is_scalar(spec)` is true for the stored spec.
    """
    spec = load_spec(os.path.join(path, 'spec.pb'))
    record_files = sorted(glob.glob(os.path.join(path, '*.tfrecord')))

    def _decode(serialized):
        return _parse_example(serialized, spec)

    dataset = tf.data.Dataset.from_tensor_slices(record_files)
    if shuffle_files:
        dataset = dataset.shuffle(len(record_files))
    dataset = dataset.interleave(
        tf.data.TFRecordDataset, num_parallel_calls=1)
    dataset = dataset.map(_decode, num_parallel_calls=tf.data.AUTOTUNE)
    if tensors.is_scalar(spec):
        return dataset
    return dataset.unbatch()
119
+
120
def save_spec(path: str, spec: tensors.GraphTensor.Spec) -> None:
    """Serialize `spec` to its proto form and write the bytes to `path`.

    Args:
        path: Destination file; opened in binary write mode.
        spec: Spec to store.
    """
    spec_proto = spec.experimental_as_proto()
    with open(path, 'wb') as stream:
        payload = spec_proto.SerializeToString()
        stream.write(payload)
124
+
125
def load_spec(path: str) -> tensors.GraphTensor.Spec:
    """Read a serialized spec proto from `path` and rebuild the spec.

    Args:
        path: File containing the serialized proto bytes.

    Returns:
        The reconstructed `tensors.GraphTensor.Spec`.
    """
    with open(path, 'rb') as stream:
        payload = stream.read()
    spec_cls = tensors.GraphTensor.Spec
    spec_proto = spec_cls.experimental_type_proto().FromString(payload)
    return spec_cls.experimental_from_proto(spec_proto)
134
+
135
def _write_tfrecord(
    inputs: list,
    path: str,
    featurizer: 'featurizers.GraphFeaturizer',
    start_index: int,
) -> None:
    """Featurize `inputs` and write them to a single TFRecord file.

    An input that fails to featurize, serialize or write is skipped and a
    warning is issued instead of aborting the whole shard.

    Args:
        inputs: Items passed one by one to `featurizer._call`.
        path: Destination TFRecord file.
        featurizer: Object whose `_call` converts one input into a tensor.
        start_index: Position of the first item of `inputs`; used to report
            an index for inputs that have no `Index` attribute.
    """
    with tf.io.TFRecordWriter(path) as writer:
        for position, item in enumerate(inputs, start=start_index):
            try:
                writer.write(_serialize_example(featurizer._call(item)))
            except Exception as e:
                # Best effort: report the failing record and carry on.
                index = getattr(item, 'Index', position)
                warnings.warn(
                    f'Could not write record for index {index}, '
                    f'proceeding without it. Exception raised:\n{e}'
                )
153
+
154
def _serialize_example(tensor):
    """Pack the flattened components of `tensor` into a serialized Example.

    Every component is serialized on its own and all of them are stored, in
    flattening order, as one bytes list under the key `'feature'`.

    Returns:
        The serialized `tf.train.Example` as bytes.
    """
    components = tf.nest.flatten(tensor, expand_composites=True)
    payloads = []
    for component in components:
        payloads.append(tf.io.serialize_tensor(component).numpy())
    bytes_feature = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=payloads)
    )
    features = tf.train.Features(feature={'feature': bytes_feature})
    return tf.train.Example(features=features).SerializeToString()
166
+
167
def _parse_example(
    x: tf.Tensor,
    spec: tensors.GraphTensor.Spec
) -> tf.Tensor:
    """Decode one serialized Example back into the structure given by `spec`.

    Args:
        x: Serialized `tf.train.Example` holding a `'feature'` bytes list.
        spec: Spec whose flattened components give the dtype and shape of
            each stored entry, and the structure to pack the result into.

    Returns:
        The decoded components packed into the structure of `spec`.
    """
    flat_spec = tf.nest.flatten(spec, expand_composites=True)
    serialized = tf.io.parse_single_example(
        x, features={'feature': tf.io.RaggedFeature(tf.string)})['feature']
    # One stored byte string per flattened spec component.
    pieces = tf.split(serialized, len(flat_spec))
    components = []
    for piece, component_spec in zip(pieces, flat_spec):
        decoded = tf.io.parse_tensor(piece[0], component_spec.dtype)
        components.append(tf.ensure_shape(decoded, component_spec.shape))
    return tf.nest.pack_sequence_as(
        spec, tf.nest.flatten(components), expand_composites=True)
182
+
183
+ def _remove_files(path):
184
+ for filename in os.listdir(path):
185
+ if filename.endswith('tfrecord') or filename == 'spec.pb':
186
+ filepath = os.path.join(path, filename)
187
+ os.remove(filepath)