molcraft 0.1.0rc9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- molcraft/__init__.py +18 -0
- molcraft/callbacks.py +100 -0
- molcraft/chem.py +714 -0
- molcraft/datasets.py +132 -0
- molcraft/descriptors.py +149 -0
- molcraft/features.py +379 -0
- molcraft/featurizers.py +624 -0
- molcraft/layers.py +1910 -0
- molcraft/losses.py +37 -0
- molcraft/models.py +623 -0
- molcraft/ops.py +195 -0
- molcraft/records.py +187 -0
- molcraft/tensors.py +561 -0
- molcraft/trainers.py +212 -0
- molcraft-0.1.0rc9.dist-info/METADATA +118 -0
- molcraft-0.1.0rc9.dist-info/RECORD +19 -0
- molcraft-0.1.0rc9.dist-info/WHEEL +5 -0
- molcraft-0.1.0rc9.dist-info/licenses/LICENSE +21 -0
- molcraft-0.1.0rc9.dist-info/top_level.txt +1 -0
molcraft/tensors.py
ADDED
|
@@ -0,0 +1,561 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
import tensorflow as tf
|
|
3
|
+
import keras
|
|
4
|
+
import typing
|
|
5
|
+
from tensorflow.python.framework import composite_tensor
|
|
6
|
+
|
|
7
|
+
from molcraft import ops
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):
    """Batch encoder used by `GraphTensor` (set as its `__batch_encoder__`).

    Defines how a `GraphTensor.Spec` changes when graphs are batched or
    unbatched, and how a `GraphTensor` value is turned into (and rebuilt
    from) a flat tuple of tensors. The 'context' field is always treated
    separately from the remaining fields of the spec.
    """

    def batch(self, spec: 'GraphTensor.Spec', batch_size: int | None):
        """Returns a copy of `spec` describing a batch of such graphs.

        Context specs get `batch_size` prepended to their shape. For every
        other field, a `tf.TensorSpec` gets its leading dimension replaced
        by `None`, a `tf.RaggedTensorSpec` becomes a ragged-rank-1 spec with
        leading dimensions `[batch_size, None]`, and any other `tf.TypeSpec`
        is delegated to its own batch encoder.
        """

        def to_batched(field_spec):
            # TensorSpec and RaggedTensorSpec must be tested before the
            # generic TypeSpec fallback.
            if isinstance(field_spec, tf.TensorSpec):
                return tf.TensorSpec(
                    shape=[None] + field_spec.shape[1:],
                    dtype=field_spec.dtype,
                )
            if isinstance(field_spec, tf.RaggedTensorSpec):
                return tf.RaggedTensorSpec(
                    shape=[batch_size, None] + field_spec.shape[1:],
                    dtype=field_spec.dtype,
                    ragged_rank=1,
                    row_splits_dtype=field_spec.row_splits_dtype,
                )
            if isinstance(field_spec, tf.TypeSpec):
                return field_spec.__batch_encoder__.batch(field_spec, batch_size)
            return field_spec

        remaining = dict(spec.__dict__)
        # Context is removed here because it is batched differently.
        context_spec = remaining.pop('context')
        batched_remaining = tf.nest.map_structure(to_batched, remaining)
        result = object.__new__(type(spec))
        batched_context = tf.nest.map_structure(
            lambda s: tf.TensorSpec([batch_size] + s.shape, s.dtype),
            context_spec,
        )
        # 'context' is inserted first so the field order of the new spec
        # is context followed by the remaining fields.
        result.__dict__.update(
            {'context': batched_context, **batched_remaining}
        )
        return result

    def unbatch(self, spec: 'GraphTensor.Spec'):
        """Returns a copy of `spec` describing one element of the batch.

        Context specs lose their leading dimension. For every other field,
        a `tf.TensorSpec` gets its leading dimension replaced by `None`, a
        `tf.RaggedTensorSpec` becomes a ragged-rank-0 spec whose shape is
        `[None]` followed by the dimensions after the first two, and any
        other `tf.TypeSpec` is delegated to its own batch encoder.
        """

        def to_unbatched(field_spec):
            if isinstance(field_spec, tf.TensorSpec):
                return tf.TensorSpec(
                    shape=[None] + field_spec.shape[1:],
                    dtype=field_spec.dtype,
                )
            if isinstance(field_spec, tf.RaggedTensorSpec):
                return tf.RaggedTensorSpec(
                    shape=[None] + field_spec.shape[2:],
                    dtype=field_spec.dtype,
                    ragged_rank=0,
                    row_splits_dtype=field_spec.row_splits_dtype,
                )
            if isinstance(field_spec, tf.TypeSpec):
                return field_spec.__batch_encoder__.unbatch(field_spec)
            return field_spec

        remaining = dict(spec.__dict__)
        # Context is removed here because it is unbatched differently.
        context_spec = remaining.pop('context')
        unbatched_remaining = tf.nest.map_structure(to_unbatched, remaining)
        unbatched_context = tf.nest.map_structure(
            lambda s: tf.TensorSpec(s.shape[1:], s.dtype),
            context_spec,
        )
        result = object.__new__(type(spec))
        result.__dict__.update(
            {'context': unbatched_context, **unbatched_remaining}
        )
        return result

    def encode(self, spec: 'GraphTensor.Spec', value: 'GraphTensor', minimum_rank: int = 0):
        """Encodes `value` as a flat tuple of tensors and composite tensors.

        A value whose spec is neither ragged nor scalar is unflattened
        first. Fields are read in the order of `spec.__dict__`.
        `minimum_rank` is accepted but not used by this implementation.
        """
        if not (is_ragged(spec) or is_scalar(spec)):
            value = value.unflatten()
        per_field = tuple(value.__dict__[name] for name in spec.__dict__)
        encodable = (tf.Tensor, composite_tensor.CompositeTensor)
        return tuple(
            component
            for component in tf.nest.flatten(per_field)
            if isinstance(component, encodable)
        )

    def encoding_specs(self, spec: 'GraphTensor.Spec'):
        """Returns the specs matching the output of `encode(spec, value)`.

        Every non-context `tf.TensorSpec` is replaced by a
        `tf.RaggedTensorSpec`: ragged rank 0 with one leading `None` when
        `spec` is scalar, ragged rank 1 with two leading `None`s otherwise.
        Row splits use the dtype of `spec.context['size']`.
        """

        def to_encoded(field_spec):
            if not isinstance(field_spec, tf.TensorSpec):
                return field_spec
            if is_scalar(spec):
                leading_dims, ragged_rank = [None], 0
            else:
                leading_dims, ragged_rank = [None, None], 1
            return tf.RaggedTensorSpec(
                shape=leading_dims + field_spec.shape[1:],
                dtype=field_spec.dtype,
                ragged_rank=ragged_rank,
                row_splits_dtype=spec.context['size'].dtype,
            )

        remaining = dict(spec.__dict__)
        context_spec = remaining.pop('context')
        encoded = {
            'context': context_spec,
            **tf.nest.map_structure(to_encoded, remaining),
        }
        return tuple(
            component
            for component in tf.nest.flatten(tuple(encoded.values()))
            if isinstance(component, tf.TypeSpec)
        )

    def decode(self, spec, encoded_value):
        """Rebuilds a value of `spec.value_type` from `encoded_value`.

        Encoded components are consumed in order, one for every
        `tf.TypeSpec` leaf of `spec`; non-spec leaves are copied from the
        spec unchanged. If the rebuilt value is ragged while `spec` is
        not, the value is flattened before it is returned.
        """
        field_specs = tuple(spec.__dict__.values())
        components = iter(encoded_value)
        leaves = []
        for leaf in tf.nest.flatten(field_specs):
            if isinstance(leaf, tf.TypeSpec):
                leaves.append(next(components))
            else:
                leaves.append(leaf)
        field_values = tf.nest.pack_sequence_as(field_specs, leaves)
        decoded = object.__new__(spec.value_type)
        decoded.__dict__.update(dict(zip(spec.__dict__.keys(), field_values)))
        if is_ragged(decoded) and not is_ragged(spec):
            decoded = decoded.flatten()
        return decoded
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class GraphTensor(tf.experimental.BatchableExtensionType):
    """Batchable extension type bundling graph data in three mappings.

    An instance is in one of three states:

    * flat: node and edge values are `tf.Tensor`s; `flatten` offsets the
      edge indices so they address the concatenated node axis.
    * ragged: node and edge values are `tf.RaggedTensor`s (see
      `unflatten`).
    * scalar: `context['size']` has rank 0.
    """

    # Per-graph data; must contain an int32 'size' entry.
    context: typing.Mapping[str, tf.Tensor]
    # Per-node data; must contain a 'feature' entry.
    node: typing.Mapping[str, typing.Union[tf.Tensor, tf.RaggedTensor]]
    # Per-edge data; must contain int32 'source' and 'target' entries.
    edge: typing.Mapping[str, typing.Union[tf.Tensor, tf.RaggedTensor]]

    __batch_encoder__ = GraphTensorBatchEncoder()

    __name__ = 'GraphTensor'

    def __validate__(self):
        """Checks required fields and dtypes; only active in eager mode.

        Raises:
            AssertionError: If a required field is missing, an index or
                size field is not int32, or the sum of `context['size']`
                differs from the number of nodes.
        """
        # NOTE(review): `assert` statements are stripped under `python -O`.
        # They are kept so the raised exception type stays AssertionError.
        if tf.executing_eagerly():
            assert 'size' in self.context, "graph.context['size'] is required."
            assert self.context['size'].dtype == tf.int32, (
                "dtype of graph.context['size'] needs to be int32."
            )
            assert 'feature' in self.node, "graph.node['feature'] is required."
            assert 'source' in self.edge, "graph.edge['source'] is required."
            assert 'target' in self.edge, "graph.edge['target'] is required."
            assert self.edge['source'].dtype == tf.int32, (
                "dtype of graph.edge['source'] needs to be int32."
            )
            assert self.edge['target'].dtype == tf.int32, (
                "dtype of graph.edge['target'] needs to be int32."
            )
            if isinstance(self.node['feature'], tf.Tensor):
                num_nodes = keras.ops.shape(self.node['feature'])[0]
            else:
                num_nodes = keras.ops.sum(self.node['feature'].row_lengths())
            assert keras.ops.sum(self.context['size']) == num_nodes, (
                "graph.node['feature'] tensor is incompatible with graph.context['size']"
            )

    @property
    def spec(self):
        """Spec of this instance with relaxed leading dimensions.

        Every non-context `tf.TensorSpec` gets its first dimension
        replaced by `None`; context specs are kept as they are.
        """
        def unspecify_spec(s):
            if isinstance(s, tf.TensorSpec):
                return tf.TensorSpec([None] + s.shape[1:], s.dtype)
            return s
        orig_spec = tf.type_spec_from_value(self)
        fields = dict(orig_spec.__dict__)
        context_fields = fields.pop('context')
        new_spec_components = tf.nest.map_structure(unspecify_spec, fields)
        new_spec_components['context'] = context_fields
        return orig_spec.__class__(**new_spec_components)

    @property
    def shape(self):
        """Shape of `node['feature']`, as if it were ragged.

        For a ragged instance this is the feature's own shape. Otherwise
        it is the shape of `context['size']`, followed by `None`, followed
        by the feature dimensions after the first.
        """
        if is_ragged(self):
            return self.node['feature'].shape
        return self.context['size'].shape + [None] + self.node['feature'].shape[1:]

    @property
    def dtype(self):
        """Dtype of `node['feature']`."""
        return self.node['feature'].dtype

    @property
    def graph_indicator(self):
        """Graph index tensor, in the dtype of `context['size']`.

        * scalar: zeros, one per row of `node['feature']`.
        * ragged: `arange(num_graphs)`.
        * flat: each graph index repeated `context['size']` times.
        """
        dtype = self.context['size'].dtype
        if is_scalar(self):
            return tf.zeros(tf.shape(self.node['feature'])[:1], dtype=dtype)
        num_graphs = keras.ops.shape(self.context['size'])[0]
        if is_ragged(self):
            return keras.ops.arange(num_graphs, dtype=dtype)
        return keras.ops.repeat(keras.ops.arange(num_graphs, dtype=dtype), self.context['size'])

    @property
    def num_subgraphs(self) -> tf.Tensor:
        """Number of graphs: 1 if scalar, else the length of `context['size']`.

        Returned as a Python `int` in eager mode, otherwise as a tensor.
        """
        dtype = self.context['size'].dtype
        if is_scalar(self):
            num_subgraphs = tf.constant(1, dtype=dtype)
        else:
            num_subgraphs = tf.shape(self.context['size'], out_type=dtype)[0]
        if tf.executing_eagerly():
            return int(num_subgraphs)
        return num_subgraphs

    @property
    def num_nodes(self):
        """Size of the first axis of `node['feature']`.

        Returned as a Python `int` in eager mode.
        """
        num_nodes = keras.ops.shape(self.node['feature'])[0]
        if tf.executing_eagerly():
            return int(num_nodes)
        return num_nodes

    @property
    def num_edges(self):
        """Size of the first axis of `edge['source']`.

        Returned as a Python `int` in eager mode.
        """
        num_edges = keras.ops.shape(self.edge['source'])[0]
        if tf.executing_eagerly():
            return int(num_edges)
        return num_edges

    def gather(self, node_attr: str, edge_type: str) -> tf.Tensor:
        """Gathers `node[node_attr]` at the indices in `edge[edge_type]`.

        Raises:
            ValueError: If `edge_type` is neither 'source' nor 'target'.
        """
        if edge_type not in ('source', 'target'):
            # Previously raised without a message; now matches `aggregate`.
            raise ValueError('`edge_type` needs to be `source` or `target`.')
        return ops.gather(self.node[node_attr], self.edge[edge_type])

    def aggregate(self, edge_attr: str, edge_type: str = 'target', mode: str = 'sum') -> tf.Tensor:
        """Aggregates `edge[edge_attr]` by the indices in `edge[edge_type]`.

        If the graph has an `edge['weight']` entry, `ops.edge_weight` is
        applied to the edge attribute first. The aggregation itself is
        delegated to `ops.aggregate` with `self.num_nodes` and `mode`.

        Raises:
            ValueError: If `edge_type` is neither 'source' nor 'target'.
        """
        if edge_type not in ('source', 'target'):
            # The message used to name `edge_attr`, which is not the
            # argument being validated here.
            raise ValueError('`edge_type` needs to be `source` or `target`.')
        edge_attr = self.edge[edge_attr]
        if 'weight' in self.edge:
            edge_attr = ops.edge_weight(edge_attr, self.edge['weight'])
        return ops.aggregate(edge_attr, self.edge[edge_type], self.num_nodes, mode=mode)

    def propagate(self, add_edge_feature: bool = False):
        """Returns a copy whose `node['feature']` is the `ops.propagate` result.

        `edge['feature']` is passed on only when `add_edge_feature` is
        true; `edge['weight']` is passed on whenever it exists.
        """
        updated_feature = ops.propagate(
            node_feature=self.node['feature'],
            edge_source=self.edge['source'],
            edge_target=self.edge['target'],
            edge_feature=self.edge.get('feature') if add_edge_feature else None,
            edge_weight=self.edge.get('weight'),
        )
        return self.update({'node': {'feature': updated_feature}})

    def flatten(self):
        """Converts a ragged instance into a flat one.

        Every ragged value is replaced by its `flat_values`. The edge
        'source' and 'target' indices are shifted by the row start of the
        graph each edge belongs to.

        Raises:
            ValueError: If the instance is not ragged.
        """
        if not is_ragged(self):
            raise ValueError(
                f"{self.__class__.__qualname__} instance is already flat.")

        def flatten_fn(x):
            if isinstance(x, tf.RaggedTensor):
                return x.flat_values
            return x

        # Per-edge offset: the node row start of the edge's graph.
        edge_increment = ops.gather(
            self.node['feature'].row_starts(), self.edge['source'].value_rowids())
        edge_increment = tf.cast(
            edge_increment, dtype=self.edge['source'].dtype)
        data = to_dict(self)
        flat_values = tf.nest.map_structure(flatten_fn, data)
        flat_values['edge']['source'] += edge_increment
        flat_values['edge']['target'] += edge_increment
        return from_dict(flat_values)

    def unflatten(self, *, force: bool = False):
        """Converts a flat, non-scalar instance into a ragged one.

        Node values are split into rows by `graph_indicator`; edge values
        are split by the graph indicator gathered at each edge's source.
        Edge 'source' and 'target' indices are shifted back by the row
        start of their graph.

        Args:
            force: If true, edges are first reordered by an argsort of
                their graph indicator.

        Raises:
            ValueError: If the instance is scalar or already ragged.
        """
        if is_scalar(self):
            raise ValueError(
                f"{self.__class__.__qualname__} instance is a scalar "
                "and cannot be unflattened.")
        if is_ragged(self):
            raise ValueError(
                f"{self.__class__.__qualname__} instance is already unflat.")

        def unflatten_fn(x, value_rowids, nrows) -> tf.RaggedTensor:
            if isinstance(x, tf.Tensor):
                return tf.RaggedTensor.from_value_rowids(x, value_rowids, nrows)
            return x

        graph_indicator_node = self.graph_indicator
        graph_indicator_edge = ops.gather(graph_indicator_node, self.edge['source'])
        if force:
            sorted_indices = keras.ops.argsort(graph_indicator_edge)
        num_subgraphs = self.num_subgraphs
        unflat_values = {}
        data = to_dict(self)
        # Fields are visited in spec order. The 'edge' branch relies on
        # `edge_decrement`, which is computed in the 'node' branch.
        for key in tf.type_spec_from_value(self).__dict__:
            value = data[key]
            if key == 'context':
                unflat_values[key] = value
            elif key == 'node':
                unflat_values[key] = tf.nest.map_structure(
                    lambda x: unflatten_fn(x, graph_indicator_node, num_subgraphs),
                    value)
                row_starts = unflat_values[key]['feature'].row_starts()
                edge_decrement = ops.gather(row_starts, graph_indicator_edge)
                if force:
                    edge_decrement = ops.gather(edge_decrement, sorted_indices)
                    graph_indicator_edge = ops.gather(graph_indicator_edge, sorted_indices)
                edge_decrement = tf.cast(edge_decrement, dtype=self.edge['source'].dtype)
            elif key == 'edge':
                if force:
                    value = tf.nest.map_structure(lambda x: ops.gather(x, sorted_indices), value)
                value['source'] -= edge_decrement
                value['target'] -= edge_decrement
                unflat_values[key] = tf.nest.map_structure(
                    lambda x: unflatten_fn(x, graph_indicator_edge, num_subgraphs),
                    value)
        return from_dict(unflat_values)

    def update(self, values):
        """Returns a copy with the given inner fields added, replaced or removed.

        Args:
            values: Mapping of outer field ('context', 'node' or 'edge')
                to a mapping of inner field name to new value. A value of
                `None` removes the inner field. Other values are matched
                to the raggedness of the outer field's reference entry
                ('size', 'feature' or 'source').
        """
        data = to_dict(self)
        for outer_field, mapping in values.items():
            if outer_field == 'context':
                reference_value = data[outer_field]['size']
            elif outer_field == 'edge':
                reference_value = data[outer_field]['source']
            else:
                reference_value = data[outer_field]['feature']
            for inner_field, value in mapping.items():
                if value is None:
                    data[outer_field].pop(inner_field, None)
                    continue
                data[outer_field][inner_field] = _maybe_convert_new_value(
                    value, reference_value
                )
        return self.__class__(**data)

    def replace(self, values):
        """Returns a new instance built only from `values`.

        Unlike `update`, existing inner fields are not carried over. They
        only provide the reference entries ('size', 'feature' or 'source')
        that new values are matched against for raggedness.

        The converted values are collected in a new dict, so the `values`
        argument is left unmodified.
        """
        data = to_dict(self)
        converted = {}
        for outer_field, mapping in values.items():
            if outer_field == 'context':
                reference_value = data[outer_field]['size']
            elif outer_field == 'edge':
                reference_value = data[outer_field]['source']
            else:
                reference_value = data[outer_field]['feature']
            converted[outer_field] = {
                inner_field: _maybe_convert_new_value(value, reference_value)
                for inner_field, value in mapping.items()
            }
        return self.__class__(**converted)

    def with_context(self, *args, **kwargs):
        """Returns a copy with updated context fields.

        The context is taken from the first positional argument if it is
        a dict, otherwise from the keyword arguments.
        """
        context = (
            args[0] if args and isinstance(args[0], dict) else kwargs
        )
        return self.update({'context': context})

    def __getitem__(self, index):
        """Indexes the graphs of this instance.

        `None` on a scalar instance adds a leading axis to every context
        value. A slice or int is applied to every field directly; any
        other index goes through `ops.gather`. A flat instance is
        unflattened for indexing and flattened again afterwards, unless
        the result is scalar.
        """
        if index is None and is_scalar(self):
            return self.__class__(
                context={key: value[None] for (key, value) in self.context.items()},
                node=self.node,
                edge=self.edge,
            )
        if not is_ragged(self):
            is_flat = True
            tensor = self.unflatten()
        else:
            tensor = self
            is_flat = False
        data = to_dict(tensor)
        if isinstance(index, (slice, int)):
            data = tf.nest.map_structure(lambda x: x[index], data)
        else:
            data = tf.nest.map_structure(lambda x: ops.gather(x, index), data)
        tensor = from_dict(data)
        if is_flat and not is_scalar(tensor):
            return tensor.flatten()
        return tensor

    def __repr__(self):
        return _repr(self)

    def numpy(self):
        """For now added to work with `keras.Model.predict`"""
        return self

    class Spec:
        """Type spec companion of `GraphTensor`.

        Holds one spec per inner field of `context`, `node` and `edge`.
        """

        def __init__(
            self,
            context: typing.Mapping[str, tf.TensorSpec | tf.RaggedTensorSpec],
            node: typing.Mapping[str, tf.TensorSpec | tf.RaggedTensorSpec],
            edge: typing.Mapping[str, tf.TensorSpec | tf.RaggedTensorSpec],
        ) -> None:
            self.context = context
            self.node = node
            self.edge = edge

        @property
        def shape(self):
            """Shape of the `node['feature']` spec, as if it were ragged.

            Computed the same way as `GraphTensor.shape`.
            """
            if is_ragged(self):
                return self.node['feature'].shape
            return self.context['size'].shape + [None] + self.node['feature'].shape[1:]

        @classmethod
        def from_input_shape_dict(cls, input_shape: dict[str, tf.TensorShape]) -> 'GraphTensor.Spec':
            """Builds a spec of `tf.variant` `tf.TensorSpec`s from nested shapes.

            Args:
                input_shape: Mapping of outer field to a mapping of inner
                    field name to shape. It is not modified; the shapes
                    are converted into a new dict.
            """
            shapes = {
                key: {k: tf.TensorShape(v) for k, v in value.items()}
                for key, value in input_shape.items()
            }
            return cls(**tf.nest.map_structure(lambda s: tf.TensorSpec(s, dtype=tf.variant), shapes))

        def __repr__(self):
            return _repr(self)
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
@tf.experimental.dispatch_for_api(tf.concat, {'values': typing.List[GraphTensor]})
def graph_tensor_concat(
    values: typing.List[GraphTensor],
    axis: int = 0,
    name: str = 'concat'
) -> GraphTensor:
    """Dispatch target of `tf.concat` for lists of `GraphTensor`s.

    All inputs must be ragged or all must be non-ragged. Ragged inputs are
    flattened, concatenated, and the result is unflattened again. Each
    input's edge indices are shifted by the node count of the inputs
    before it.

    `axis` and `name` are accepted but not read; components are always
    concatenated along axis 0.

    Raises:
        ValueError: If ragged and non-ragged inputs are mixed.
    """
    ragged_flags = [is_ragged(v) for v in values]
    num_ragged = sum(ragged_flags)
    if 0 < num_ragged < len(ragged_flags):
        raise ValueError(
            'Nested data of the GraphTensor instances do not have consistent '
            'types: found both tf.RaggedTensor values and tf.Tensor values.')
    inputs_are_ragged = ragged_flags[0]

    if inputs_are_ragged:
        values = [v.flatten() for v in values]

    # Concatenate the i-th component of every input with each other.
    components = [tf.nest.flatten(v, expand_composites=True) for v in values]
    merged = [tf.concat(group, axis=0) for group in zip(*components)]

    edge_counts = [keras.ops.shape(v.edge['source'])[0] for v in values]
    node_counts = [keras.ops.shape(v.node['feature'])[0] for v in values]
    # Offset of input i is the node count of inputs 0..i-1, repeated once
    # per edge of input i.
    offsets = tf.concat([[0], tf.cumsum(node_counts)[:-1]], axis=0)
    offsets = tf.repeat(offsets, edge_counts)

    result = tf.nest.pack_sequence_as(values[0], merged, expand_composites=True)
    shifted_edges = {
        'source': result.edge['source'] + offsets,
        'target': result.edge['target'] + offsets,
    }
    result = result.update({'edge': shifted_edges})

    if inputs_are_ragged:
        return result.unflatten()
    return result
|
|
442
|
+
|
|
443
|
+
# TODO: Clean this up.
|
|
444
|
+
@tf.experimental.dispatch_for_api(tf.stack, {'values': typing.List[GraphTensor]})
def graph_tensor_stack(
    values: typing.List[GraphTensor],
    axis: int = 0,
    name: str = 'stack'
) -> GraphTensor:
    """Dispatch target of `tf.stack` for lists of `GraphTensor`s.

    Only scalar, non-ragged inputs are supported. Context components are
    stacked, all other components are concatenated, and each input's edge
    indices are shifted by the node count of the inputs before it.

    `axis` and `name` are accepted but not read; axis 0 is always used.

    Raises:
        ValueError: If the first input is not scalar, or if any input is
            ragged.
    """
    ragged_flags = [is_ragged(v) for v in values]
    if not is_scalar(values[0]):
        raise ValueError(
            'tf.stack on a list of `GraphTensor`s is currently '
            'only supported for scalar `GraphTensor`s. '
        )
    if any(ragged_flags):
        raise ValueError(
            'tf.stack on a list of `GraphTensor`s is currently '
            'only supported for flattened `GraphTensor`s. '
        )

    def merge(outer_key, parts):
        # Context entries get a new leading axis; the rest are joined.
        if outer_key.startswith('context'):
            return tf.stack(parts, axis=0)
        return tf.concat(parts, axis=0)

    first = values[0]
    outer_fields = tuple(tf.type_spec_from_value(first).__dict__)
    # One outer-field label per inner field, lined up with the flattened
    # components below.
    component_keys = []
    for field in outer_fields:
        component_keys.extend([field] * len(first.__dict__[field]))

    components = [tf.nest.flatten(v, expand_composites=True) for v in values]
    merged = [
        merge(key, group)
        for key, group in zip(component_keys, zip(*components))
    ]
    result = tf.nest.pack_sequence_as(first, merged, expand_composites=True)

    edge_counts = [keras.ops.shape(v.edge['source'])[0] for v in values]
    node_counts = [keras.ops.shape(v.node['feature'])[0] for v in values]
    offsets = tf.concat([[0], tf.cumsum(node_counts)[:-1]], axis=0)
    offsets = tf.repeat(offsets, edge_counts)
    shifted_edges = {
        'source': result.edge['source'] + offsets,
        'target': result.edge['target'] + offsets,
    }
    return result.update({'edge': shifted_edges})
|
|
491
|
+
|
|
492
|
+
def is_scalar(value_or_spec: GraphTensor | GraphTensor.Spec) -> bool:
    """Returns whether `context['size']` of the graph or spec has rank 0."""
    size = value_or_spec.context['size']
    return size.shape.rank == 0
|
|
494
|
+
|
|
495
|
+
def is_ragged(value_or_spec: GraphTensor | GraphTensor.Spec) -> bool:
    """Returns whether the graph or spec stores its node features raggedly.

    A graph value is ragged when `node['feature']` is a `tf.RaggedTensor`.
    A spec is ragged when `node['feature']` is a `tf.RaggedTensorSpec`
    with `ragged_rank == 1`.

    The ragged-rank check used to be guarded by
    `isinstance(value_or_spec, tf.RaggedTensorSpec)`. That is never true
    for a graph or a graph spec, so the check never ran. The guard now
    inspects the feature itself.
    """
    feature = value_or_spec.node['feature']
    if isinstance(feature, tf.RaggedTensorSpec):
        return feature.ragged_rank == 1
    return isinstance(feature, tf.RaggedTensor)
|
|
502
|
+
|
|
503
|
+
def to_dict(tensor: GraphTensor) -> dict:
    """Returns the fields of `tensor` as a dict of new inner dicts.

    Outer keys follow the field order of the tensor's type spec. Each
    inner mapping is copied into a plain `dict`; the tensors themselves
    are not copied.
    """
    field_names = tf.type_spec_from_value(tensor).__dict__
    result = {}
    for field_name in field_names:
        result[field_name] = dict(tensor.__dict__[field_name])
    return result
|
|
506
|
+
|
|
507
|
+
def from_dict(data: dict) -> GraphTensor:
    """Builds a `GraphTensor` from nested dicts.

    `context['size']`, `edge['source']` and `edge['target']` are cast to
    int32 first. The casts are applied to shallow copies of the 'context'
    and 'edge' mappings, so the caller's `data` is left unmodified.

    Args:
        data: Mapping with 'context', 'node' and 'edge' entries, each a
            mapping of inner field name to tensor.
    """
    context = dict(data['context'])
    edge = dict(data['edge'])
    context['size'] = tf.cast(context['size'], tf.int32)
    edge['source'] = tf.cast(edge['source'], tf.int32)
    edge['target'] = tf.cast(edge['target'], tf.int32)
    return GraphTensor(**{**data, 'context': context, 'edge': edge})
|
|
512
|
+
|
|
513
|
+
def is_graph(data):
    """Returns whether `data` is, or looks like, a graph.

    True for a `GraphTensor`, and for a `dict` whose 'context' entry
    contains a 'size' key.
    """
    if isinstance(data, GraphTensor):
        return True
    return isinstance(data, dict) and 'size' in data.get('context', {})
|
|
519
|
+
|
|
520
|
+
def _maybe_convert_new_value(
    new_value: tf.Tensor | tf.RaggedTensor,
    old_value: tf.Tensor | tf.RaggedTensor | None,
) -> tf.Tensor | tf.RaggedTensor:
    """Matches the raggedness of `new_value` to that of `old_value`.

    * `old_value` is `None`, or both have the same raggedness:
      `new_value` is returned unchanged.
    * Only `old_value` is ragged: `new_value` becomes the flat values of
      a ragged tensor with the row partitioning of `old_value`.
    * Only `new_value` is ragged: its `flat_values` are returned.
    """
    if old_value is None:
        return new_value
    old_is_ragged = isinstance(old_value, tf.RaggedTensor)
    new_is_ragged = isinstance(new_value, tf.RaggedTensor)
    if old_is_ragged == new_is_ragged:
        return new_value
    if old_is_ragged:
        return old_value.with_flat_values(new_value)
    return new_value.flat_values
|
|
533
|
+
|
|
534
|
+
def _repr(x: GraphTensor | GraphTensor.Spec):
    """Builds the multi-line repr shared by `GraphTensor` and its spec.

    Every inner field is summarised by shape and dtype; ragged entries
    also show their ragged rank. Values are labelled `tf.Tensor` /
    `tf.RaggedTensor`, specs `tf.TensorSpec` / `tf.RaggedTensorSpec`.
    """
    if isinstance(x, GraphTensor):
        dense_type = tf.Tensor
        dense_label, ragged_label = 'tf.Tensor', 'tf.RaggedTensor'
    else:
        dense_type = tf.TensorSpec
        dense_label, ragged_label = 'tf.TensorSpec', 'tf.RaggedTensorSpec'

    def describe(v):
        if isinstance(v, dense_type):
            return f'<{dense_label}: shape={v.shape.as_list()}, dtype={v.dtype.name}>'
        return (
            f'<{ragged_label}: shape={v.shape.as_list()}, '
            f'dtype={v.dtype.name}, ragged_rank={v.ragged_rank}>'
        )

    def section(label, mapping):
        entries = ',\n '.join(
            [f'{k!r}: {describe(v)}' for k, v in mapping.items()]
        )
        return label + '={\n ' + entries + '\n }'

    sections = [
        section('context', x.context),
        section('node', x.node),
        section('edge', x.edge),
    ]
    return x.__class__.__name__ + '(\n ' + ',\n '.join(sections) + '\n)'
|