molcraft 0.1.0rc10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- molcraft/__init__.py +18 -0
- molcraft/callbacks.py +100 -0
- molcraft/chem.py +714 -0
- molcraft/datasets.py +132 -0
- molcraft/descriptors.py +149 -0
- molcraft/features.py +379 -0
- molcraft/featurizers.py +727 -0
- molcraft/layers.py +2034 -0
- molcraft/losses.py +37 -0
- molcraft/models.py +627 -0
- molcraft/ops.py +195 -0
- molcraft/records.py +187 -0
- molcraft/tensors.py +561 -0
- molcraft/trainers.py +212 -0
- molcraft-0.1.0rc10.dist-info/METADATA +118 -0
- molcraft-0.1.0rc10.dist-info/RECORD +19 -0
- molcraft-0.1.0rc10.dist-info/WHEEL +5 -0
- molcraft-0.1.0rc10.dist-info/licenses/LICENSE +21 -0
- molcraft-0.1.0rc10.dist-info/top_level.txt +1 -0
molcraft/ops.py
ADDED
@@ -0,0 +1,195 @@
import warnings
import keras
import numpy as np
import tensorflow as tf
from keras import backend


@keras.saving.register_keras_serializable(package='molcraft')
def gather(
    node_feature: tf.Tensor,
    edge: tf.Tensor
) -> tf.Tensor:
    if backend.backend() == 'tensorflow':
        return tf.gather(node_feature, edge)
    expected_rank = len(keras.ops.shape(node_feature))
    current_rank = len(keras.ops.shape(edge))
    for _ in range(expected_rank - current_rank):
        edge = keras.ops.expand_dims(edge, axis=-1)
    return keras.ops.take_along_axis(node_feature, edge, axis=0)

@keras.saving.register_keras_serializable(package='molcraft')
def aggregate(
    node_feature: tf.Tensor,
    edge: tf.Tensor,
    num_nodes: tf.Tensor,
    mode: str = 'sum',
) -> tf.Tensor:
    if mode == 'mean':
        return segment_mean(
            node_feature, edge, num_nodes, sorted=False
        )
    return keras.ops.segment_sum(
        node_feature, edge, num_nodes, sorted=False
    )

@keras.saving.register_keras_serializable(package='molcraft')
def propagate(
    node_feature: tf.Tensor,
    edge_source: tf.Tensor,
    edge_target: tf.Tensor,
    edge_feature: tf.Tensor | None = None,
    edge_weight: tf.Tensor | None = None,
) -> tf.Tensor:
    num_nodes = keras.ops.shape(node_feature)[0]

    node_feature_source = gather(node_feature, edge_source)

    if edge_weight is not None:
        node_feature_source *= edge_weight

    if edge_feature is not None:
        node_feature_source += edge_feature

    # Aggregate the gathered per-edge source features. (Passing
    # `node_feature` here, as the original line did, would both discard the
    # edge weighting/features applied above and feed a [num_nodes, ...]
    # tensor where a [num_edges, ...] tensor is required.)
    return aggregate(node_feature_source, edge_target, num_nodes)

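Example (illustrative, not part of the package): one round of message passing built from the three primitives above. The toy graph and shapes are assumptions made for the demo.

# Toy graph: 3 nodes, edges 0->1, 0->2, 1->2.
node_feature = tf.constant([[1.0], [2.0], [3.0]])  # [num_nodes, dim]
edge_source = tf.constant([0, 0, 1])               # [num_edges]
edge_target = tf.constant([1, 2, 2])               # [num_edges]

# gather() pulls source-node features onto edges -> [num_edges, dim].
messages = gather(node_feature, edge_source)       # [[1.], [1.], [2.]]

# aggregate() sum-reduces messages at each target node -> [num_nodes, dim].
pooled = aggregate(messages, edge_target, num_nodes=3)  # [[0.], [1.], [3.]]

# propagate() fuses both steps and also applies the optional
# edge_weight / edge_feature before reducing.
updated = propagate(node_feature, edge_source, edge_target)  # equals `pooled`
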
@keras.saving.register_keras_serializable(package='molcraft')
def scatter_update(
    inputs: tf.Tensor,
    indices: tf.Tensor,
    updates: tf.Tensor,
) -> tf.Tensor:
    if indices.dtype == tf.bool:
        indices = keras.ops.stack(keras.ops.where(indices), axis=-1)
    expected_rank = len(keras.ops.shape(inputs))
    current_rank = len(keras.ops.shape(indices))
    for _ in range(expected_rank - current_rank):
        indices = keras.ops.expand_dims(indices, axis=-1)
    return keras.ops.scatter_update(inputs, indices, updates)

@keras.saving.register_keras_serializable(package='molcraft')
def scatter_add(
    inputs: tf.Tensor,
    indices: tf.Tensor,
    updates: tf.Tensor,
) -> tf.Tensor:
    if indices.dtype == tf.bool:
        indices = keras.ops.stack(keras.ops.where(indices), axis=-1)
    expected_rank = len(keras.ops.shape(inputs))
    current_rank = len(keras.ops.shape(indices))
    for _ in range(expected_rank - current_rank):
        indices = keras.ops.expand_dims(indices, axis=-1)
    if backend.backend() == 'tensorflow':
        return tf.tensor_scatter_nd_add(inputs, indices, updates)
    # Fallback for other backends. Note that, unlike tensor_scatter_nd_add,
    # scatter_update does not accumulate duplicate indices; for duplicates,
    # the last update wins before the addition.
    updates = scatter_update(keras.ops.zeros_like(inputs), indices, updates)
    return inputs + updates

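A quick sanity check of the scatter helpers (toy values and shapes are assumptions). Boolean masks are first converted to coordinate indices, and 1-D integer indices are expanded to index the leading axis:

inputs = tf.zeros((4, 2))
indices = tf.constant([0, 2])                    # expanded internally to [[0], [2]]
updates = tf.constant([[1.0, 1.0], [2.0, 2.0]])

replaced = scatter_update(inputs, indices, updates)  # rows 0 and 2 overwritten
added = scatter_add(inputs, indices, updates)        # rows 0 and 2 incremented
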
@keras.saving.register_keras_serializable(package='molcraft')
def edge_softmax(
    score: tf.Tensor,
    edge_target: tf.Tensor
) -> tf.Tensor:
    num_segments = keras.ops.cond(
        keras.ops.greater(keras.ops.shape(edge_target)[0], 0),
        lambda: keras.ops.maximum(keras.ops.max(edge_target) + 1, 1),
        lambda: 0
    )
    score_max = keras.ops.segment_max(
        score, edge_target, num_segments, sorted=False
    )
    score_max = gather(score_max, edge_target)
    numerator = keras.ops.exp(score - score_max)
    denominator = keras.ops.segment_sum(
        numerator, edge_target, num_segments, sorted=False
    )
    denominator = gather(denominator, edge_target)
    return numerator / denominator

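edge_softmax is a numerically stable softmax taken over the edges that share a target node, i.e. the normalization used for attention coefficients. A toy example (values illustrative):

score = tf.constant([[1.0], [2.0], [0.5]])  # one logit per edge
edge_target = tf.constant([1, 1, 2])        # edges 0 and 1 share target node 1

attention = edge_softmax(score, edge_target)
# softmax([1.0, 2.0]) ~= [[0.269], [0.731]] for edges 0-1; edge 2 gets [[1.0]].
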
@keras.saving.register_keras_serializable(package='molcraft')
def edge_weight(
    edge: tf.Tensor,
    edge_weight: tf.Tensor,
) -> tf.Tensor:
    expected_rank = len(keras.ops.shape(edge))
    current_rank = len(keras.ops.shape(edge_weight))
    for _ in range(expected_rank - current_rank):
        edge_weight = keras.ops.expand_dims(edge_weight, axis=-1)
    return edge * edge_weight

@keras.saving.register_keras_serializable(package='molcraft')
def segment_mean(
    data: tf.Tensor,
    segment_ids: tf.Tensor,
    num_segments: int | None = None,
    sorted: bool = False,
) -> tf.Tensor:
    if num_segments is None:
        num_segments = keras.ops.cond(
            keras.ops.greater(keras.ops.shape(segment_ids)[0], 0),
            lambda: keras.ops.max(segment_ids) + 1,
            lambda: 0
        )
    if backend.backend() == 'tensorflow':
        if sorted:
            # tf.math.segment_mean expects sorted ids and infers the number
            # of segments itself; it takes no `num_segments` argument
            # (passing one, as the original line did, raises a TypeError).
            return tf.math.segment_mean(data=data, segment_ids=segment_ids)
        return tf.math.unsorted_segment_mean(
            data=data,
            segment_ids=segment_ids,
            num_segments=num_segments
        )
    x = keras.ops.segment_sum(
        data=data,
        segment_ids=segment_ids,
        num_segments=num_segments,
        sorted=sorted
    )
    sizes = keras.ops.cast(
        keras.ops.bincount(segment_ids, minlength=num_segments),
        dtype=x.dtype
    )
    return x / sizes[:, None]

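Beyond backing aggregate(..., mode='mean'), segment_mean is the natural way to mean-pool a per-graph readout from batched node features. The graph_indicator below is an assumption about how graphs are batched, for illustration only:

node_feature = tf.constant([[1.0], [3.0], [10.0]])
graph_indicator = tf.constant([0, 0, 1])  # nodes 0-1 in graph 0, node 2 in graph 1

readout = segment_mean(node_feature, graph_indicator, num_segments=2)
# [[2.0], [10.0]]
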
@keras.saving.register_keras_serializable(package='molcraft')
def gaussian(
    x: tf.Tensor,
    mean: tf.Tensor,
    std: tf.Tensor
) -> tf.Tensor:
    expected_rank = len(keras.ops.shape(x))
    current_rank = len(keras.ops.shape(mean))
    for _ in range(expected_rank - current_rank):
        mean = keras.ops.expand_dims(mean, axis=0)
        std = keras.ops.expand_dims(std, axis=0)
    a = (2 * np.pi) ** 0.5
    return keras.ops.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std)

@keras.saving.register_keras_serializable(package='molcraft')
def euclidean_distance(
    x1: tf.Tensor,
    x2: tf.Tensor,
    axis: int = -1,
    keepdims: bool = True,
) -> tf.Tensor:
    relative_distance = keras.ops.subtract(x1, x2)
    return keras.ops.sqrt(
        keras.ops.sum(
            keras.ops.square(relative_distance),
            axis=axis,
            keepdims=keepdims
        )
    )

@keras.saving.register_keras_serializable(package='molcraft')
def displacement(
    x1: tf.Tensor,
    x2: tf.Tensor,
    normalize: bool = False,
    axis: int = -1,
    keepdims: bool = True,
) -> tf.Tensor:
    displacement = keras.ops.subtract(x1, x2)
    if not normalize:
        return displacement
    return displacement / euclidean_distance(x1, x2, axis=axis, keepdims=keepdims)

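A typical pairing of the geometry helpers above: expand edge distances in a Gaussian radial basis. The positions, centers, and width below are arbitrary illustration choices, not package defaults:

pos_source = tf.constant([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])  # [num_edges, 3]
pos_target = tf.constant([[1.0, 0.0, 0.0], [1.0, 2.0, 0.0]])

dist = euclidean_distance(pos_source, pos_target)        # [[1.0], [2.0]]
unit = displacement(pos_source, pos_target, normalize=True)

centers = tf.linspace(0.0, 4.0, 8)    # arbitrary RBF centers
width = tf.constant(0.5)              # arbitrary shared standard deviation
rbf = gaussian(dist, centers, width)  # [num_edges, 8] basis expansion
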
molcraft/records.py
ADDED
@@ -0,0 +1,187 @@
import warnings
import os
import math
import glob
import time
import typing
import tensorflow as tf
import numpy as np
import pandas as pd
import multiprocessing as mp

from molcraft import tensors

if typing.TYPE_CHECKING:
    from molcraft import featurizers


def write(
    inputs: list[str | tuple],
    featurizer: 'featurizers.GraphFeaturizer',
    path: str,
    exist_ok: bool = False,
    overwrite: bool = False,
    num_files: typing.Optional[int] = None,
    num_processes: typing.Optional[int] = None,
    multiprocessing: bool = False,
    device: str = '/cpu:0'
) -> None:

    if os.path.isdir(path):
        if not exist_ok:
            raise FileExistsError(f'Records already exist: {path}')
        if not overwrite:
            return
        _remove_files(path)
    else:
        os.makedirs(path)

    with tf.device(device):

        if isinstance(inputs, pd.DataFrame):
            inputs = list(inputs.iterrows())
        elif isinstance(inputs, pd.Series):
            # pd.Series has no iterrows(); iterate (index, value) pairs.
            inputs = list(inputs.items())

        example = featurizer._call(inputs[0])
        save_spec(os.path.join(path, 'spec.pb'), example.spec)

        if num_processes is None:
            num_processes = mp.cpu_count()

        if num_files is None:
            num_files = min(len(inputs), max(1, math.ceil(len(inputs) / 1_000)))

        num_examples = len(inputs)
        chunk_sizes = [0] * num_files
        for i in range(num_examples):
            chunk_sizes[i % num_files] += 1

        input_chunks = []
        start_indices = []
        current_index = 0
        for size in chunk_sizes:
            input_chunks.append(inputs[current_index: current_index + size])
            start_indices.append(current_index)
            current_index += size

        assert current_index == num_examples

        paths = [
            os.path.join(path, f'tfrecord-{i:06d}.tfrecord')
            for i in range(num_files)
        ]

        if not multiprocessing:
            for file_path, input_chunk, start_index in zip(paths, input_chunks, start_indices):
                _write_tfrecord(input_chunk, file_path, featurizer, start_index)
            return

        processes = []

        for file_path, input_chunk, start_index in zip(paths, input_chunks, start_indices):

            # Throttle: wait for a free worker slot before spawning the
            # next process. (Rebuilding the list avoids removing items
            # from a list while iterating over it.)
            while len(processes) >= num_processes:
                processes = [p for p in processes if p.is_alive()]
                if len(processes) >= num_processes:
                    time.sleep(0.1)

            process = mp.Process(
                target=_write_tfrecord,
                args=(input_chunk, file_path, featurizer, start_index)
            )
            processes.append(process)
            process.start()

        for process in processes:
            process.join()

def read(
    path: str,
    shuffle_files: bool = False
) -> tf.data.Dataset:
    spec = load_spec(os.path.join(path, 'spec.pb'))
    filenames = sorted(glob.glob(os.path.join(path, '*.tfrecord')))
    num_files = len(filenames)
    ds = tf.data.Dataset.from_tensor_slices(filenames)
    if shuffle_files:
        ds = ds.shuffle(num_files)
    ds = ds.interleave(
        tf.data.TFRecordDataset, num_parallel_calls=1)
    ds = ds.map(
        lambda x: _parse_example(x, spec),
        num_parallel_calls=tf.data.AUTOTUNE)
    if not tensors.is_scalar(spec):
        ds = ds.unbatch()
    return ds

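An end-to-end sketch of the write/read flow. The featurizer construction and the input format are assumptions (see molcraft/featurizers.py for the actual API); only write() and read() are from this module:

from molcraft import featurizers, records

featurizer = featurizers.GraphFeaturizer()    # hypothetical constructor call
inputs = [('CCO', 1.2), ('c1ccccc1', 0.3)]    # hypothetical (SMILES, label) pairs

# Featurizes and shards the examples into TFRecord files under `path`,
# saving the GraphTensor spec alongside them as 'spec.pb'.
records.write(inputs, featurizer, path='/tmp/molcraft_records')

# Streams the records back as a tf.data.Dataset, restoring shapes and
# dtypes from the saved spec; unbatches unless the spec is scalar.
ds = records.read('/tmp/molcraft_records', shuffle_files=True)
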
def save_spec(path: str, spec: tensors.GraphTensor.Spec) -> None:
    proto = spec.experimental_as_proto()
    with open(path, 'wb') as fh:
        fh.write(proto.SerializeToString())

def load_spec(path: str) -> tensors.GraphTensor.Spec:
    with open(path, 'rb') as fh:
        serialized_proto = fh.read()
    spec = tensors.GraphTensor.Spec.experimental_from_proto(
        tensors.GraphTensor.Spec
        .experimental_type_proto()
        .FromString(serialized_proto)
    )
    return spec

def _write_tfrecord(
    inputs: list,
    path: str,
    featurizer: 'featurizers.GraphFeaturizer',
    start_index: int,
) -> None:
    with tf.io.TFRecordWriter(path) as writer:
        for i, x in enumerate(inputs):
            try:
                tensor = featurizer._call(x)
                serialized = _serialize_example(tensor)
                writer.write(serialized)
            except Exception as e:
                index = getattr(x, 'Index', (i + start_index))
                warnings.warn(
                    f'Could not write record for index {index}, '
                    f'proceeding without it. Exception raised:\n{e}'
                )

def _serialize_example(tensor):
    # Flatten the composite tensor into its component tensors and store
    # each one as a serialized string in a single BytesList feature.
    flat_values = tf.nest.flatten(tensor, expand_composites=True)
    flat_values = [
        tf.io.serialize_tensor(value).numpy() for value in flat_values
    ]
    feature = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=flat_values)
    )
    example_proto = tf.train.Example(
        features=tf.train.Features(feature={'feature': feature})
    )
    return example_proto.SerializeToString()

def _parse_example(
    x: tf.Tensor,
    spec: tensors.GraphTensor.Spec
) -> tf.Tensor:
    # Each record stores one serialized string per component tensor; parse
    # them back and restore shape/dtype from the flattened spec.
    out = tf.io.parse_single_example(
        x, features={'feature': tf.io.RaggedFeature(tf.string)})['feature']
    flat_spec = tf.nest.flatten(spec, expand_composites=True)
    out = [
        tf.ensure_shape(tf.io.parse_tensor(value[0], s.dtype), s.shape)
        for (value, s) in zip(tf.split(out, len(flat_spec)), flat_spec)
    ]
    out = tf.nest.pack_sequence_as(spec, tf.nest.flatten(out), expand_composites=True)
    return out

def _remove_files(path):
    for filename in os.listdir(path):
        if filename.endswith('tfrecord') or filename == 'spec.pb':
            filepath = os.path.join(path, filename)
            os.remove(filepath)
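
The serialize/parse pair above depends only on tf.nest's stable flattening order. The same round-trip in miniature, using a RaggedTensor in place of a GraphTensor (an illustration of the mechanism, not package code):

rt = tf.ragged.constant([[1, 2], [3]])
spec = tf.type_spec_from_value(rt)

# Flatten the composite into component tensors and serialize each one.
flat = tf.nest.flatten(rt, expand_composites=True)
blobs = [tf.io.serialize_tensor(t).numpy() for t in flat]

# Parse each blob back using the dtype recorded in the spec, then repack.
flat_specs = tf.nest.flatten(spec, expand_composites=True)
parsed = [tf.io.parse_tensor(b, s.dtype) for b, s in zip(blobs, flat_specs)]
restored = tf.nest.pack_sequence_as(spec, parsed, expand_composites=True)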