molcraft 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic. Click here for more details.

molcraft/records.py ADDED
@@ -0,0 +1,169 @@
1
+ import os
2
+ import math
3
+ import glob
4
+ import time
5
+ import typing
6
+ import tensorflow as tf
7
+ import numpy as np
8
+ import pandas as pd
9
+ import multiprocessing as mp
10
+
11
+ from molcraft import tensors
12
+ from molcraft import featurizers
13
+
14
+
15
def write(
    inputs: list[str | tuple],
    featurizer: featurizers.Featurizer,
    path: str,
    overwrite: bool = True,
    num_files: typing.Optional[int] = None,
    num_processes: typing.Optional[int] = None,
    multiprocessing: bool = False,
    device: str = '/cpu:0'
) -> None:
    """Featurize `inputs` and write the graphs as TFRecord shards under `path`.

    The spec of the first featurized graph is saved to `path/spec.pb` (read
    back by `load`). The inputs are split into contiguous chunks, and each
    chunk is written to its own `tfrecord-XXXX.tfrecord` file.

    Args:
        inputs: Inputs passed one at a time to `featurizer`. A pandas
            DataFrame/Series is converted via `.values`; list/ndarray items
            are converted to tuples before featurization.
        featurizer: Callable returning a `GraphTensor`, or an indexable
            iterable of `GraphTensor`s, for each input.
        path: Output directory; created if missing.
        overwrite: If False and `path` already exists, nothing is written.
            Otherwise any existing `tfrecord-*.tfrecord` shards in `path` are
            removed before the new ones are written.
        num_files: Number of shards. Defaults to
            `min(len(inputs), num_processes)`. It may be lowered so that no
            shard is empty.
        num_processes: Maximum number of concurrent worker processes when
            `multiprocessing` is True. Defaults to `mp.cpu_count()`.
        multiprocessing: If True, write each shard in a separate process.
        device: TensorFlow device scope used in this process.

    Raises:
        ValueError: If `inputs` is empty, or the number of files or processes
            is smaller than 1.
        RuntimeError: If a worker process exits with a non-zero exit code.
    """
    if os.path.isdir(path) and not overwrite:
        return

    os.makedirs(path, exist_ok=True)

    with tf.device(device):

        if isinstance(inputs, (pd.DataFrame, pd.Series)):
            inputs = inputs.values

        if not isinstance(inputs, list):
            inputs = list(inputs)

        if not inputs:
            raise ValueError('`inputs` is empty; there is nothing to write.')

        example = _featurize_input(inputs[0], featurizer)
        if not isinstance(example, tensors.GraphTensor):
            # The featurizer produced several graphs for one input; the
            # stored spec is taken from the first of them.
            example = example[0]

        save_spec(os.path.join(path, 'spec.pb'), example.spec)

        # `load` reads every `*.tfrecord` file in `path`. Shards left over from
        # an earlier write with more files would otherwise leak into the data.
        stale_pattern = os.path.join(glob.escape(path), 'tfrecord-*.tfrecord')
        for stale_file in glob.glob(stale_pattern):
            os.remove(stale_file)

        if num_processes is None:
            num_processes = mp.cpu_count()

        if num_files is None:
            num_files = min(len(inputs), num_processes)

        if num_files < 1:
            raise ValueError(
                f'`num_files` must be at least 1, got {num_files} '
                '(derived from `num_processes` when `num_files` is None).')

        # Derive the file count back from the chunk size so no file is empty.
        chunk_size = math.ceil(len(inputs) / num_files)
        num_files = math.ceil(len(inputs) / chunk_size)

        shard_paths = [
            os.path.join(path, f'tfrecord-{i:04d}.tfrecord')
            for i in range(num_files)
        ]
        input_chunks = [
            inputs[i * chunk_size: (i + 1) * chunk_size]
            for i in range(num_files)
        ]

        if not multiprocessing:
            for shard_path, input_chunk in zip(shard_paths, input_chunks):
                _write_tfrecord(input_chunk, shard_path, featurizer)
            return

        if num_processes < 1:
            raise ValueError(
                f'`num_processes` must be at least 1, got {num_processes}.')

        workers = []
        for shard_path, input_chunk in zip(shard_paths, input_chunks):
            # Wait for a free slot. Live workers are counted instead of being
            # removed from the list while it is being iterated over.
            while sum(w.is_alive() for w in workers) >= num_processes:
                time.sleep(0.1)
            worker = mp.Process(
                target=_write_tfrecord,
                args=(input_chunk, shard_path, featurizer)
            )
            workers.append(worker)
            worker.start()

        for worker in workers:
            worker.join()

        # A crashed worker leaves a missing or truncated shard; surface it.
        failed = [
            shard_path
            for worker, shard_path in zip(workers, shard_paths)
            if worker.exitcode != 0
        ]
        if failed:
            raise RuntimeError(
                f'{len(failed)} worker process(es) failed while writing: {failed}')
90
+
91
def load(
    path: str,
    shuffle_files: bool = False
) -> tf.data.Dataset:
    """Load a dataset of graphs previously written with `write`.

    Args:
        path: Directory containing `spec.pb` and the `*.tfrecord` shards.
        shuffle_files: Shuffle the order in which shards are read.

    Returns:
        A `tf.data.Dataset` of graphs decoded according to the stored spec.
        If the stored spec is not scalar, examples are unbatched along their
        leading dimension.

    Raises:
        FileNotFoundError: If `spec.pb` is missing or `path` contains no
            `*.tfrecord` files.
    """
    spec = load_spec(os.path.join(path, 'spec.pb'))
    filenames = sorted(glob.glob(os.path.join(glob.escape(path), '*.tfrecord')))
    if not filenames:
        # An empty list would become a float32 dataset and fail obscurely
        # inside TFRecordDataset; report the real problem instead.
        raise FileNotFoundError(f'No *.tfrecord files found in {path!r}.')
    ds = tf.data.Dataset.from_tensor_slices(filenames)
    if shuffle_files:
        ds = ds.shuffle(len(filenames))
    ds = ds.interleave(
        tf.data.TFRecordDataset, num_parallel_calls=1)
    ds = ds.map(
        lambda x: _parse_example(x, spec),
        num_parallel_calls=tf.data.AUTOTUNE)
    if not tensors.is_scalar(spec):
        ds = ds.unbatch()
    return ds
109
+
110
def _write_tfrecord(
    inputs,
    path: str,
    featurizer: featurizers.Featurizer,
) -> None:
    """Featurize every item of `inputs` and write all resulting graphs to `path`.

    Each graph becomes one `tf.train.Example`. The graph is flattened into its
    component tensors (composites expanded), each component is serialized
    with `tf.io.serialize_tensor`, and the serialized components are stored
    in order in the bytes list of the `'feature'` entry.
    """
    with tf.io.TFRecordWriter(path) as record_writer:
        for item in inputs:
            featurized = _featurize_input(item, featurizer)
            if isinstance(featurized, tensors.GraphTensor):
                graphs = [featurized]
            else:
                # The featurizer returned several graphs for this item.
                graphs = featurized
            for graph in graphs:
                components = tf.nest.flatten(graph, expand_composites=True)
                payload = [
                    tf.io.serialize_tensor(component).numpy()
                    for component in components
                ]
                bytes_feature = tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=payload))
                record_writer.write(_serialize_example({'feature': bytes_feature}))
131
+
132
def _featurize_input(x, featurizer):
    """Call `featurizer` on `x`, passing list/ndarray inputs as a tuple."""
    needs_tuple = isinstance(x, (list, np.ndarray))
    return featurizer(tuple(x) if needs_tuple else x)
136
+
137
def _serialize_example(
    feature: dict[str, tf.train.Feature]
) -> bytes:
    """Wrap the `feature` mapping in a `tf.train.Example` and serialize it."""
    features = tf.train.Features(feature=feature)
    example = tf.train.Example(features=features)
    return example.SerializeToString()
143
+
144
def _parse_example(
    x: tf.Tensor,
    spec: tensors.GraphTensor.Spec
) -> tf.Tensor:
    """Decode one serialized record written by `_write_tfrecord`.

    Args:
        x: Scalar string tensor holding one serialized `tf.train.Example`
            whose `'feature'` bytes list contains one serialized tensor per
            flattened component of `spec`, in flatten order.
        spec: Spec of the stored graph, as saved by `save_spec`. Its
            flattened components give the dtype and static shape of each
            serialized tensor.

    Returns:
        The graph reconstructed by packing the decoded components per `spec`.
    """
    flat_specs = tf.nest.flatten(spec, expand_composites=True)
    serialized = tf.io.parse_single_example(
        x, features={'feature': tf.io.RaggedFeature(tf.string)})['feature']
    # One serialized tensor per component; `num` also checks the count at runtime.
    serialized_components = tf.unstack(serialized, num=len(flat_specs))
    components = [
        tf.ensure_shape(tf.io.parse_tensor(component, component_spec.dtype),
                        component_spec.shape)
        for component, component_spec in zip(serialized_components, flat_specs)
    ]
    return tf.nest.pack_sequence_as(spec, components, expand_composites=True)
155
+
156
def save_spec(path: str, spec: tensors.GraphTensor.Spec) -> None:
    """Write `spec` to `path` as a serialized TypeSpec proto (see `load_spec`)."""
    spec_proto = spec.experimental_as_proto()
    with open(path, 'wb') as out_file:
        payload = spec_proto.SerializeToString()
        out_file.write(payload)
160
+
161
def load_spec(path: str) -> tensors.GraphTensor.Spec:
    """Read a `GraphTensor.Spec` previously written to `path` by `save_spec`."""
    with open(path, 'rb') as in_file:
        payload = in_file.read()
    spec_cls = tensors.GraphTensor.Spec
    proto = spec_cls.experimental_type_proto().FromString(payload)
    return spec_cls.experimental_from_proto(proto)
169
+
molcraft/tensors.py ADDED
@@ -0,0 +1,527 @@
1
+ import tensorflow as tf
2
+ import keras
3
+ import typing
4
+ from tensorflow.python.framework import composite_tensor
5
+
6
+ from molcraft import ops
7
+
8
+
9
class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
    """Batch encoder attached to `GraphTensor` through `__batch_encoder__`.

    `context` specs are (un)batched by adding or removing a leading
    dimension. `node`/`edge` specs depend on their type:
    - `TensorSpec`s keep an unknown first axis.
    - `RaggedTensorSpec`s gain or lose a ragged axis.
    - Any other `TypeSpec` defers to its own batch encoder.
    """

    @staticmethod
    def _rebuild_spec(spec, field_fn, context_fn):
        """Build a new spec of `type(spec)` with transformed field specs.

        `context_fn` is mapped over the context specs and `field_fn` over the
        specs of every other field. The result is created with
        `object.__new__` and its `__dict__` is filled directly, context first,
        so the spec constructor is bypassed.
        """
        remaining = dict(spec.__dict__)
        context_specs = remaining.pop('context')
        new_fields = tf.nest.map_structure(field_fn, remaining)
        new_context = tf.nest.map_structure(context_fn, context_specs)
        rebuilt = object.__new__(type(spec))
        rebuilt.__dict__.update({'context': new_context})
        rebuilt.__dict__.update(new_fields)
        return rebuilt

    def batch(self, spec: 'GraphTensor.Spec', batch_size: int | None):
        """Return the spec of a batch of `batch_size` values described by `spec`."""

        def batched_field(field_spec):
            if isinstance(field_spec, tf.TensorSpec):
                return tf.TensorSpec(
                    shape=[None] + field_spec.shape[1:],
                    dtype=field_spec.dtype)
            if isinstance(field_spec, tf.RaggedTensorSpec):
                return tf.RaggedTensorSpec(
                    shape=[batch_size, None] + field_spec.shape[1:],
                    dtype=field_spec.dtype,
                    ragged_rank=1,
                    row_splits_dtype=field_spec.row_splits_dtype)
            if isinstance(field_spec, tf.TypeSpec):
                return field_spec.__batch_encoder__.batch(field_spec, batch_size)
            return field_spec

        def batched_context(context_spec):
            return tf.TensorSpec([batch_size] + context_spec.shape, context_spec.dtype)

        return self._rebuild_spec(spec, batched_field, batched_context)

    def unbatch(self, spec: 'GraphTensor.Spec'):
        """Return the spec of a single element of a batch described by `spec`."""

        def unbatched_field(field_spec):
            if isinstance(field_spec, tf.TensorSpec):
                return tf.TensorSpec(
                    shape=[None] + field_spec.shape[1:],
                    dtype=field_spec.dtype)
            if isinstance(field_spec, tf.RaggedTensorSpec):
                return tf.RaggedTensorSpec(
                    shape=[None] + field_spec.shape[2:],
                    dtype=field_spec.dtype,
                    ragged_rank=0,
                    row_splits_dtype=field_spec.row_splits_dtype)
            if isinstance(field_spec, tf.TypeSpec):
                return field_spec.__batch_encoder__.unbatch(field_spec)
            return field_spec

        def unbatched_context(context_spec):
            return tf.TensorSpec(context_spec.shape[1:], context_spec.dtype)

        return self._rebuild_spec(spec, unbatched_field, unbatched_context)

    def encode(self, spec: 'GraphTensor.Spec', value: 'GraphTensor', minimum_rank: int = 0):
        """Encode `value` as a flat tuple of its tensor/composite components.

        Flat, non-scalar values are unflattened first, so node/edge tensors are
        ragged with one row per graph.
        """
        if not (is_ragged(spec) or is_scalar(spec)):
            value = value.unflatten()
        ordered_fields = tuple(value.__dict__[name] for name in spec.__dict__)
        return tuple(
            leaf for leaf in tf.nest.flatten(ordered_fields)
            if isinstance(leaf, (tf.Tensor, composite_tensor.CompositeTensor))
        )

    def encoding_specs(self, spec: 'GraphTensor.Spec'):
        """Return the specs of the components produced by `encode`.

        Node/edge `TensorSpec`s become `RaggedTensorSpec`s: rank-0 ragged for
        scalar graphs, ragged_rank 1 otherwise. Context specs are unchanged
        and come first.
        """
        scalar = is_scalar(spec)
        row_splits_dtype = spec.context['size'].dtype

        def encoded_field(field_spec):
            if not isinstance(field_spec, tf.TensorSpec):
                return field_spec
            leading = [None] if scalar else [None, None]
            return tf.RaggedTensorSpec(
                shape=leading + field_spec.shape[1:],
                dtype=field_spec.dtype,
                ragged_rank=0 if scalar else 1,
                row_splits_dtype=row_splits_dtype)

        remaining = dict(spec.__dict__)
        context_specs = remaining.pop('context')
        encoded = {'context': context_specs, **tf.nest.map_structure(encoded_field, remaining)}
        return tuple(
            leaf for leaf in tf.nest.flatten(tuple(encoded.values()))
            if isinstance(leaf, tf.TypeSpec)
        )

    def decode(self, spec, encoded_value):
        """Rebuild a `GraphTensor` from components produced by `encode`.

        Values decoded as ragged are flattened again when `spec` is not ragged.
        """
        spec_fields = tuple(spec.__dict__.values())
        components = iter(encoded_value)
        leaves = [
            next(components) if isinstance(leaf, tf.TypeSpec) else leaf
            for leaf in tf.nest.flatten(spec_fields)
        ]
        restored = tf.nest.pack_sequence_as(spec_fields, leaves)
        value = object.__new__(spec.value_type)
        value.__dict__.update(zip(spec.__dict__.keys(), restored))
        if is_ragged(value) and not is_ragged(spec):
            value = value.flatten()
        return value
113
+
114
+
115
class GraphTensor(tf.experimental.BatchableExtensionType):
    """A graph, or a batch of graphs, stored as three mappings of tensors.

    Fields:
        context: Graph-level tensors. `context['size']` (int32) is required.
            Its rank decides whether this is a single ("scalar") graph
            (rank 0) or a batch of graphs.
        node: Node-level tensors; `node['feature']` is required.
        edge: Edge-level tensors; `edge['source']` and `edge['target']`
            (int32 node indices) are required.

    Node/edge tensors of a batch come in one of two layouts:
    - "Flat": `tf.Tensor`s with all graphs concatenated along the first
      axis, and edge indices that address the whole batch.
    - "Ragged": `tf.RaggedTensor`s with one row per graph, and edge indices
      local to each graph.
    `flatten` and `unflatten` convert between the two layouts.
    """
    context: typing.Mapping[str, typing.Union[tf.Tensor, tf.RaggedTensor]]
    node: typing.Mapping[str, typing.Union[tf.Tensor, tf.RaggedTensor]]
    edge: typing.Mapping[str, typing.Union[tf.Tensor, tf.RaggedTensor]]

    __batch_encoder__ = GraphTensorBatchEncoder()

    __name__ = 'GraphTensor'

    def __validate__(self):
        """Check that the required fields exist and have the required dtypes.

        Raises:
            ValueError: If a required field is missing.
            TypeError: If `context['size']`, `edge['source']` or
                `edge['target']` is not int32.
        """
        # Explicit raises rather than `assert`, which is stripped under `python -O`.
        if 'size' not in self.context:
            raise ValueError("graph.context['size'] is required.")
        if self.context['size'].dtype != tf.int32:
            raise TypeError("dtype of graph.context['size'] needs to be int32.")
        if 'feature' not in self.node:
            raise ValueError("graph.node['feature'] is required.")
        if 'source' not in self.edge:
            raise ValueError("graph.edge['source'] is required.")
        if 'target' not in self.edge:
            raise ValueError("graph.edge['target'] is required.")
        if self.edge['source'].dtype != tf.int32:
            raise TypeError("dtype of graph.edge['source'] needs to be int32.")
        if self.edge['target'].dtype != tf.int32:
            raise TypeError("dtype of graph.edge['target'] needs to be int32.")
        # TODO: Assert node sizes (based on context['size'])
        # Assert edge sizes (based on edge['source'])

    @property
    def spec(self):
        """Type spec of this value, with unknown leading node/edge dimensions.

        Node and edge `TensorSpec`s get their first dimension set to `None`,
        so the spec matches graphs with any number of nodes or edges. Context
        specs are kept as they are.
        """
        def unspecify_spec(s):
            if isinstance(s, tf.TensorSpec):
                return tf.TensorSpec([None] + s.shape[1:], s.dtype)
            return s
        orig_spec = tf.type_spec_from_value(self)
        fields = dict(orig_spec.__dict__)
        context_fields = fields.pop('context')
        new_spec_components = tf.nest.map_structure(unspecify_spec, fields)
        new_spec_components['context'] = context_fields
        return orig_spec.__class__(**new_spec_components)

    @property
    def shape(self):
        """Node-feature shape, with a graph axis for flat batches.

        For ragged graphs this is `node['feature'].shape`. Otherwise it is
        `context['size'].shape`, then an unknown node dimension, then the
        feature dimensions.
        """
        if is_ragged(self):
            return self.node['feature'].shape
        return self.context['size'].shape + [None] + self.node['feature'].shape[1:]

    @property
    def dtype(self):
        """dtype of `node['feature']`."""
        return self.node['feature'].dtype

    @property
    def graph_indicator(self):
        """Graph index tensor, in `context['size']`'s dtype.

        - Scalar graph: zeros, one per node.
        - Flat batch: graph `i` repeated `context['size'][i]` times (one
          entry per node).
        - Ragged batch: `arange(number of graphs)`.
        """
        dtype = self.context['size'].dtype
        if is_scalar(self):
            return tf.zeros(tf.shape(self.node['feature'])[:1], dtype=dtype)
        num_graphs = keras.ops.shape(self.context['size'])[0]
        if is_ragged(self):
            return keras.ops.arange(num_graphs, dtype=dtype)
        return keras.ops.repeat(keras.ops.arange(num_graphs, dtype=dtype), self.context['size'])

    @property
    def num_subgraphs(self) -> tf.Tensor:
        """Number of graphs: 1 for a scalar graph, else `len(context['size'])`.

        Returns a Python int when executing eagerly.
        """
        dtype = self.context['size'].dtype
        if is_scalar(self):
            num_subgraphs = tf.constant(1, dtype=dtype)
        else:
            num_subgraphs = tf.shape(self.context['size'], out_type=dtype)[0]
        if tf.executing_eagerly():
            return int(num_subgraphs)
        return num_subgraphs

    @property
    def num_nodes(self):
        """Size of the first axis of `node['feature']`.

        This is the total node count for flat or scalar graphs. Returns a
        Python int when executing eagerly.
        """
        num_nodes = keras.ops.shape(self.node['feature'])[0]
        if tf.executing_eagerly():
            return int(num_nodes)
        return num_nodes

    @property
    def num_edges(self):
        """Size of the first axis of `edge['source']`.

        Returns a Python int when executing eagerly.
        """
        num_edges = keras.ops.shape(self.edge['source'])[0]
        if tf.executing_eagerly():
            return int(num_edges)
        return num_edges

    def gather(self, node_attr: str, edge_type: str) -> tf.Tensor:
        """Gather `node[node_attr]` at the node indices in `edge[edge_type]`.

        Raises:
            ValueError: If `edge_type` is not 'source' or 'target'.
        """
        if edge_type not in ('source', 'target'):
            raise ValueError(
                f"`edge_type` must be 'source' or 'target', got {edge_type!r}.")
        return ops.gather(self.node[node_attr], self.edge[edge_type])

    def aggregate(self, edge_attr: str, edge_type: str = 'target') -> tf.Tensor:
        """Aggregate `edge[edge_attr]` onto the nodes indexed by `edge[edge_type]`.

        If `edge['weight']` exists, the edge values are multiplied by it first.
        `self.num_nodes` is passed to `ops.aggregate` (presumably as the
        number of output rows; see `ops.aggregate`).

        Raises:
            ValueError: If `edge_type` is not 'source' or 'target'.
        """
        if edge_type not in ('source', 'target'):
            raise ValueError(
                f"`edge_type` must be 'source' or 'target', got {edge_type!r}.")
        edge_values = self.edge[edge_attr]
        if 'weight' in self.edge:
            edge_values = edge_values * self.edge['weight']
        return ops.aggregate(edge_values, self.edge[edge_type], self.num_nodes)

    def propagate(self, add_edge_feature: bool = False):
        """Return a copy whose `node['feature']` is the result of `ops.propagate`.

        `edge['weight']` is passed if present. `edge['feature']` is passed only
        when `add_edge_feature` is True.
        """
        updated_feature = ops.propagate(
            node_feature=self.node['feature'],
            edge_source=self.edge['source'],
            edge_target=self.edge['target'],
            edge_feature=self.edge.get('feature') if add_edge_feature else None,
            edge_weight=self.edge.get('weight'),
        )
        return self.update({'node': {'feature': updated_feature}})

    def flatten(self):
        """Convert a ragged batch into the flat layout.

        Ragged node/edge tensors are replaced by their flat values. Edge
        indices are shifted from per-graph to batch-wide node indices.

        Raises:
            ValueError: If the graph is already flat.
        """
        if not is_ragged(self):
            raise ValueError(
                f"{self.__class__.__qualname__} instance is already flat.")
        def flatten_fn(x):
            if isinstance(x, tf.RaggedTensor):
                return x.flat_values
            return x
        # For each edge: offset of its graph's first node in the flat node array.
        edge_increment = ops.gather(
            self.node['feature'].row_starts(), self.edge['source'].value_rowids())
        edge_increment = tf.cast(
            edge_increment, dtype=self.edge['source'].dtype)
        data = to_dict(self)
        flat_values = tf.nest.map_structure(flatten_fn, data)
        flat_values['edge']['source'] += edge_increment
        flat_values['edge']['target'] += edge_increment
        return from_dict(flat_values)

    def unflatten(self, *, force: bool = False):
        """Convert a flat batch into the ragged layout (one row per graph).

        Edge indices are shifted from batch-wide to per-graph node indices.
        Each edge is assigned to the graph of its source node. With
        `force=True`, edges are first reordered by that graph (argsort of
        their graph index).

        Raises:
            ValueError: If the graph is scalar or already ragged.
        """
        if is_scalar(self):
            raise ValueError(
                f"{self.__class__.__qualname__} instance is a scalar "
                "and cannot be unflattened.")
        if is_ragged(self):
            raise ValueError(
                f"{self.__class__.__qualname__} instance is already unflat.")
        def unflatten_fn(x, value_rowids, nrows) -> tf.RaggedTensor:
            if isinstance(x, tf.Tensor):
                return tf.RaggedTensor.from_value_rowids(x, value_rowids, nrows)
            return x
        graph_indicator_node = self.graph_indicator
        graph_indicator_edge = ops.gather(graph_indicator_node, self.edge['source'])
        if force:
            sorted_indices = keras.ops.argsort(graph_indicator_edge)
        num_subgraphs = self.num_subgraphs
        unflat_values = {}
        data = to_dict(self)
        # NOTE: the 'edge' branch uses `edge_decrement` from the 'node' branch,
        # so this relies on the spec listing 'node' before 'edge'.
        for key in tf.type_spec_from_value(self).__dict__:
            value = data[key]
            if key == 'context':
                unflat_values[key] = value
            elif key == 'node':
                unflat_values[key] = tf.nest.map_structure(
                    lambda x: unflatten_fn(x, graph_indicator_node, num_subgraphs),
                    value)
                row_starts = unflat_values[key]['feature'].row_starts()
                # For each edge: offset of its graph's first node.
                edge_decrement = ops.gather(row_starts, graph_indicator_edge)
                if force:
                    edge_decrement = ops.gather(edge_decrement, sorted_indices)
                    graph_indicator_edge = ops.gather(graph_indicator_edge, sorted_indices)
                edge_decrement = tf.cast(edge_decrement, dtype=self.edge['source'].dtype)
            elif key == 'edge':
                if force:
                    value = tf.nest.map_structure(lambda x: ops.gather(x, sorted_indices), value)
                value['source'] -= edge_decrement
                value['target'] -= edge_decrement
                unflat_values[key] = tf.nest.map_structure(
                    lambda x: unflatten_fn(x, graph_indicator_edge, num_subgraphs),
                    value)
        return from_dict(unflat_values)

    def update(self, values):
        """Return a copy with individual entries added, replaced or removed.

        Args:
            values: `{'context'|'node'|'edge': {name: value}}`. A `None` value
                removes the entry. Other values are converted to match the
                raggedness of the field's reference tensor (`context['size']`,
                `edge['source']` or `node['feature']`).
        """
        data = to_dict(self)
        for outer_field, mapping in values.items():
            if outer_field == 'context':
                reference_value = data[outer_field]['size']
            elif outer_field == 'edge':
                reference_value = data[outer_field]['source']
            else:
                reference_value = data[outer_field]['feature']
            for inner_field, value in mapping.items():
                if value is None:
                    data[outer_field].pop(inner_field, None)
                    continue
                data[outer_field][inner_field] = _maybe_convert_new_value(
                    value, reference_value
                )
        return self.__class__(**data)

    def replace(self, values):
        """Return a copy in which whole inner mappings are replaced.

        Unlike `update`, which merges individual entries, each outer field
        given in `values` has its mapping replaced entirely by the new one.
        Outer fields missing from `values` are kept from `self`.

        New values are converted to match the raggedness of the field's
        current reference tensor (`context['size']`, `edge['source']` or
        `node['feature']`). Entries whose value is `None` are left out.
        `values` itself is not modified.
        """
        data = to_dict(self)
        for outer_field, mapping in values.items():
            if outer_field == 'context':
                reference_value = data[outer_field]['size']
            elif outer_field == 'edge':
                reference_value = data[outer_field]['source']
            else:
                reference_value = data[outer_field]['feature']
            data[outer_field] = {
                inner_field: _maybe_convert_new_value(value, reference_value)
                for inner_field, value in mapping.items()
                if value is not None
            }
        return self.__class__(**data)

    def __getitem__(self, index):
        """Select graphs of a batch by int, slice or gather indices.

        `graph[None]` on a scalar graph adds a leading dimension to the
        context tensors. In all other cases:
        - A flat input is unflattened before indexing.
        - An int or slice indexes each tensor directly; anything else goes
          through `ops.gather`.
        - A flat input gives a flat result again, unless the result is scalar.
        """
        if index is None and is_scalar(self):
            return self.__class__(
                context={key: value[None] for (key, value) in self.context.items()},
                node=self.node,
                edge=self.edge,
            )
        if not is_ragged(self):
            is_flat = True
            tensor = self.unflatten()
        else:
            tensor = self
            is_flat = False
        data = to_dict(tensor)
        if isinstance(index, (slice, int)):
            data = tf.nest.map_structure(lambda x: x[index], data)
        else:
            data = tf.nest.map_structure(lambda x: ops.gather(x, index), data)
        tensor = from_dict(data)
        if is_flat and not is_scalar(tensor):
            return tensor.flatten()
        return tensor

    def __repr__(self):
        return _repr(self)

    def numpy(self):
        """For now added to work with `keras.Model.predict`"""
        return self

    class Spec:
        """Type spec of `GraphTensor`: mappings of tensor specs per field."""

        def __init__(
            self,
            context: typing.Mapping[str, tf.TensorSpec | tf.RaggedTensorSpec],
            node: typing.Mapping[str, tf.TensorSpec | tf.RaggedTensorSpec],
            edge: typing.Mapping[str, tf.TensorSpec | tf.RaggedTensorSpec],
        ) -> None:
            self.context = context
            self.node = node
            self.edge = edge

        @property
        def shape(self):
            """Node-feature shape; mirrors `GraphTensor.shape`."""
            if is_ragged(self):
                return self.node['feature'].shape
            return self.context['size'].shape + [None] + self.node['feature'].shape[1:]

        @classmethod
        def from_input_shape_dict(cls, input_shape: dict[str, tf.TensorShape]) -> 'GraphTensor.Spec':
            """Build a spec of `tf.variant` `TensorSpec`s from nested shapes.

            Args:
                input_shape: `{'context'|'node'|'edge': {name: shape}}`. Each
                    shape is converted with `tf.TensorShape`, so `tf.nest`
                    treats it as a leaf. The argument is not modified.
            """
            shapes = {
                outer: {inner: tf.TensorShape(shape) for inner, shape in mapping.items()}
                for outer, mapping in input_shape.items()
            }
            return cls(**tf.nest.map_structure(lambda s: tf.TensorSpec(s, dtype=tf.variant), shapes))

        def __repr__(self):
            return _repr(self)
369
+
370
+
371
@tf.experimental.dispatch_for_api(tf.concat, {'values': typing.List[GraphTensor]})
def graph_tensor_concat(
    values: typing.List[GraphTensor],
    axis: int = 0,
    name: str = 'concat'
) -> GraphTensor:
    """`tf.concat` override for lists of `GraphTensor`s.

    Every component is concatenated along axis 0. The edge indices of each
    input are then shifted by the number of nodes in the inputs before it.
    Ragged inputs are flattened first and the result is unflattened again.
    `axis` and `name` are accepted for API compatibility but are not used.

    Raises:
        ValueError: If some inputs are ragged and others are not.
    """
    ragged_flags = [is_ragged(v) for v in values]
    if any(ragged_flags) and not all(ragged_flags):
        raise ValueError(
            'Nested data of the GraphTensor instances do not have consistent '
            'types: found both tf.RaggedTensor values and tf.Tensor values.')
    ragged = ragged_flags[0]

    if ragged:
        values = [v.flatten() for v in values]

    components_per_value = [tf.nest.flatten(v, expand_composites=True) for v in values]
    merged = [tf.concat(parts, axis=0) for parts in zip(*components_per_value)]

    edge_counts = [keras.ops.shape(v.edge['source'])[0] for v in values]
    node_counts = [keras.ops.shape(v.node['feature'])[0] for v in values]
    # Node offset of each input, repeated once per edge of that input.
    offsets = tf.repeat(
        tf.concat([[0], tf.cumsum(node_counts)[:-1]], axis=0), edge_counts)

    combined = tf.nest.pack_sequence_as(values[0], merged, expand_composites=True)
    combined = combined.update({
        'edge': {
            'source': combined.edge['source'] + offsets,
            'target': combined.edge['target'] + offsets,
        },
    })
    return combined.unflatten() if ragged else combined
408
+
409
# TODO: Clean this up.
@tf.experimental.dispatch_for_api(tf.stack, {'values': typing.List[GraphTensor]})
def graph_tensor_stack(
    values: typing.List[GraphTensor],
    axis: int = 0,
    name: str = 'stack'
) -> GraphTensor:
    """`tf.stack` override: combine scalar, flat `GraphTensor`s into one batch.

    Context components are stacked, which adds a graph axis. Node and edge
    components are concatenated, and the edge indices of each input are
    shifted by the number of nodes in the inputs before it. `axis` and `name`
    are accepted for API compatibility but are not used.

    Raises:
        ValueError: If the inputs are not scalar or are ragged.
    """
    ragged_flags = [is_ragged(v) for v in values]
    if not is_scalar(values[0]):
        raise ValueError(
            'tf.stack on a list of `GraphTensor`s is currently '
            'only supported for scalar `GraphTensor`s. '
        )
    if any(ragged_flags):
        raise ValueError(
            'tf.stack on a list of `GraphTensor`s is currently '
            'only supported for flattened `GraphTensor`s. '
        )

    first = values[0]
    field_names = tuple(tf.type_spec_from_value(first).__dict__)
    # One outer-field name per flattened component, in flatten order.
    component_fields = [
        field
        for field in field_names
        for _ in range(len(first.__dict__[field]))
    ]

    components_per_value = [tf.nest.flatten(v, expand_composites=True) for v in values]
    merged = []
    for field, parts in zip(component_fields, zip(*components_per_value)):
        if field.startswith('context'):
            merged.append(tf.stack(parts, axis=0))
        else:
            merged.append(tf.concat(parts, axis=0))
    stacked = tf.nest.pack_sequence_as(first, merged, expand_composites=True)

    edge_counts = [keras.ops.shape(v.edge['source'])[0] for v in values]
    node_counts = [keras.ops.shape(v.node['feature'])[0] for v in values]
    # Node offset of each input, repeated once per edge of that input.
    offsets = tf.repeat(
        tf.concat([[0], tf.cumsum(node_counts)[:-1]], axis=0), edge_counts)
    return stacked.update({
        'edge': {
            'source': stacked.edge['source'] + offsets,
            'target': stacked.edge['target'] + offsets,
        },
    })
457
+
458
def is_scalar(value_or_spec: GraphTensor | GraphTensor.Spec) -> bool:
    """Return True when `value_or_spec` describes a single, unbatched graph.

    This is the case when `context['size']` has rank 0.
    """
    size_rank = value_or_spec.context['size'].shape.rank
    return size_rank == 0
460
+
461
def is_ragged(value_or_spec: GraphTensor | GraphTensor.Spec) -> bool:
    """Return whether the node/edge data of `value_or_spec` is ragged.

    For a `GraphTensor`, this is True when `node['feature']` is a
    `tf.RaggedTensor`. For a spec, `node['feature']` must be a
    `tf.RaggedTensorSpec` with `ragged_rank == 1`. The ragged_rank-0
    `RaggedTensorSpec`s produced by `GraphTensorBatchEncoder.unbatch`
    therefore do not count as ragged.
    """
    feature = value_or_spec.node['feature']
    if isinstance(feature, tf.RaggedTensorSpec):
        # ragged_rank lives on the feature spec, not on the graph spec itself.
        return feature.ragged_rank == 1
    return isinstance(feature, tf.RaggedTensor)
468
+
469
def to_dict(tensor: GraphTensor) -> dict:
    """Return `{field: dict(mapping)}` for every field in `tensor`'s type spec.

    The inner mappings are copied into plain dicts, so callers may modify them.
    """
    field_names = tf.type_spec_from_value(tensor).__dict__.keys()
    result = {}
    for field_name in field_names:
        result[field_name] = dict(tensor.__dict__[field_name])
    return result
472
+
473
def from_dict(data: dict) -> GraphTensor:
    """Build a `GraphTensor` from a `{'context', 'node', 'edge'}` dict.

    `context['size']`, `edge['source']` and `edge['target']` are cast to
    int32, as `GraphTensor` requires. `data` and its inner dicts are copied
    first and are not modified.
    """
    data = {field: dict(mapping) for field, mapping in data.items()}
    data['context']['size'] = tf.cast(data['context']['size'], tf.int32)
    data['edge']['source'] = tf.cast(data['edge']['source'], tf.int32)
    data['edge']['target'] = tf.cast(data['edge']['target'], tf.int32)
    return GraphTensor(**data)
478
+
479
def is_graph(data):
    """Return True for a `GraphTensor`, or a dict whose 'context' has a 'size' key."""
    if isinstance(data, GraphTensor):
        return True
    return isinstance(data, dict) and 'size' in data.get('context', {})
485
+
486
def _maybe_convert_new_value(
    new_value: tf.Tensor | tf.RaggedTensor,
    old_value: tf.Tensor | tf.RaggedTensor | None,
) -> tf.Tensor | tf.RaggedTensor:
    """Make `new_value` match the raggedness of `old_value`.

    - `old_value` is None, or both have the same raggedness: `new_value` is
      returned unchanged.
    - `old_value` is ragged and `new_value` dense: `new_value` becomes the
      flat values of a ragged tensor with `old_value`'s row partitions.
    - `old_value` is dense and `new_value` ragged: the flat values of
      `new_value` are returned.
    """
    if old_value is None:
        return new_value
    old_is_ragged = isinstance(old_value, tf.RaggedTensor)
    new_is_ragged = isinstance(new_value, tf.RaggedTensor)
    if old_is_ragged == new_is_ragged:
        return new_value
    if old_is_ragged:
        return old_value.with_flat_values(new_value)
    return new_value.flat_values
499
+
500
def _repr(x: GraphTensor | GraphTensor.Spec):
    """Multi-line repr listing every context/node/edge entry with shape and dtype.

    Works for both `GraphTensor` values (tensor entries) and specs
    (tensor-spec entries).
    """
    is_value = isinstance(x, GraphTensor)
    dense_type = tf.Tensor if is_value else tf.TensorSpec
    dense_label = 'tf.Tensor' if is_value else 'tf.TensorSpec'
    ragged_label = 'tf.RaggedTensor' if is_value else 'tf.RaggedTensorSpec'

    def describe(v):
        if isinstance(v, dense_type):
            return f'<{dense_label}: shape={v.shape.as_list()}, dtype={v.dtype.name}>'
        return (
            f'<{ragged_label}: shape={v.shape.as_list()}, '
            f'dtype={v.dtype.name}, ragged_rank={v.ragged_rank}>'
        )

    def entries(mapping):
        return ',\n '.join(f'{k!r}: {describe(v)}' for k, v in mapping.items())

    sections = [
        'context={\n ' + entries(x.context) + '\n }',
        'node={\n ' + entries(x.node) + '\n }',
        'edge={\n ' + entries(x.edge) + '\n }',
    ]
    return x.__class__.__name__ + '(\n ' + ',\n '.join(sections) + '\n)'