molcraft 0.1.0rc10__py3-none-any.whl

molcraft/tensors.py ADDED
import warnings
import tensorflow as tf
import keras
import typing
from tensorflow.python.framework import composite_tensor

from molcraft import ops

class GraphTensorBatchEncoder(tf.experimental.ExtensionTypeBatchEncoder):

    def batch(self, spec: 'GraphTensor.Spec', batch_size: int | None):
        """Batches the spec."""

        def batch_field(f):
            if isinstance(f, tf.TensorSpec):
                return tf.TensorSpec(
                    shape=[None] + f.shape[1:],
                    dtype=f.dtype
                )
            elif isinstance(f, tf.RaggedTensorSpec):
                return tf.RaggedTensorSpec(
                    shape=[batch_size, None] + f.shape[1:],
                    dtype=f.dtype,
                    ragged_rank=1,
                    row_splits_dtype=f.row_splits_dtype
                )
            elif isinstance(f, tf.TypeSpec):
                return f.__batch_encoder__.batch(f, batch_size)
            return f

        fields = dict(spec.__dict__)
        # Pop context fields as they will be batched differently.
        context_fields = fields.pop('context')
        batched_fields = tf.nest.map_structure(batch_field, fields)
        batched_spec = object.__new__(type(spec))
        batched_context_fields = tf.nest.map_structure(
            lambda spec: tf.TensorSpec([batch_size] + spec.shape, spec.dtype),
            context_fields
        )
        batched_spec.__dict__.update({'context': batched_context_fields})
        batched_spec.__dict__.update(batched_fields)
        return batched_spec

    def unbatch(self, spec: 'GraphTensor.Spec'):
        """Unbatches the spec."""

        def unbatch_field(f):
            if isinstance(f, tf.TensorSpec):
                return tf.TensorSpec(
                    shape=[None] + f.shape[1:],
                    dtype=f.dtype
                )
            elif isinstance(f, tf.RaggedTensorSpec):
                return tf.RaggedTensorSpec(
                    shape=[None] + f.shape[2:],
                    dtype=f.dtype,
                    ragged_rank=0,
                    row_splits_dtype=f.row_splits_dtype
                )
            elif isinstance(f, tf.TypeSpec):
                return f.__batch_encoder__.unbatch(f)
            return f

        fields = dict(spec.__dict__)
        # Pop context fields as they will be unbatched differently.
        context_fields = fields.pop('context')
        unbatched_fields = tf.nest.map_structure(unbatch_field, fields)
        unbatched_context_fields = tf.nest.map_structure(
            lambda spec: tf.TensorSpec(spec.shape[1:], spec.dtype),
            context_fields
        )
        unbatched_spec = object.__new__(type(spec))
        unbatched_spec.__dict__.update({'context': unbatched_context_fields})
        unbatched_spec.__dict__.update(unbatched_fields)
        return unbatched_spec

    def encode(self, spec: 'GraphTensor.Spec', value: 'GraphTensor', minimum_rank: int = 0):
        """Encodes the value as a tuple of tensor components."""
        unflatten = not (is_ragged(spec) or is_scalar(spec))
        if unflatten:
            value = value.unflatten()
        value_components = tuple(value.__dict__[key] for key in spec.__dict__)
        value_components = tuple(
            x for x in tf.nest.flatten(value_components)
            if isinstance(x, (tf.Tensor, composite_tensor.CompositeTensor))
        )
        return value_components

    def encoding_specs(self, spec: 'GraphTensor.Spec'):
        """Returns specs matching the encoded value of `encode(spec, value)`."""

        def encode_fields(f):
            if isinstance(f, tf.TensorSpec):
                scalar = is_scalar(spec)
                return tf.RaggedTensorSpec(
                    shape=([None] if scalar else [None, None]) + f.shape[1:],
                    dtype=f.dtype,
                    ragged_rank=(0 if scalar else 1),
                    row_splits_dtype=spec.context['size'].dtype
                )
            return f

        fields = dict(spec.__dict__)
        context_fields = fields.pop('context')
        encoded_fields = tf.nest.map_structure(encode_fields, fields)
        encoded_fields = {'context': context_fields, **encoded_fields}
        spec_components = tuple(encoded_fields.values())
        spec_components = tuple(
            x for x in tf.nest.flatten(spec_components)
            if isinstance(x, tf.TypeSpec)
        )
        return spec_components

    def decode(self, spec, encoded_value):
        """Decodes the encoded value."""
        spec_tuple = tuple(spec.__dict__.values())
        encoded_value = iter(encoded_value)
        value_tuple = [
            next(encoded_value) if isinstance(x, tf.TypeSpec) else x
            for x in tf.nest.flatten(spec_tuple)
        ]
        value_tuple = tf.nest.pack_sequence_as(spec_tuple, value_tuple)
        fields = dict(zip(spec.__dict__.keys(), value_tuple))
        value = object.__new__(spec.value_type)
        value.__dict__.update(fields)
        flatten = is_ragged(value) and not is_ragged(spec)
        if flatten:
            value = value.flatten()
        return value

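# Editor's note (a sketch, not part of the package): the batch encoder above is
# what lets `GraphTensor` values move through `tf.data` like any other
# batchable extension type. Assuming `graph` is a non-scalar `GraphTensor`
# (see the class below), something like the following should hold:
#
#     ds = tf.data.Dataset.from_tensor_slices(graph)  # unbatch into subgraphs
#     ds = ds.batch(32)                               # re-batch via this encoder
#     spec = ds.element_spec                          # a batched GraphTensor.Spec
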

class GraphTensor(tf.experimental.BatchableExtensionType):
    context: typing.Mapping[str, tf.Tensor]
    node: typing.Mapping[str, typing.Union[tf.Tensor, tf.RaggedTensor]]
    edge: typing.Mapping[str, typing.Union[tf.Tensor, tf.RaggedTensor]]

    __batch_encoder__ = GraphTensorBatchEncoder()

    __name__ = 'GraphTensor'

    def __validate__(self):
        if tf.executing_eagerly():
            assert 'size' in self.context, "graph.context['size'] is required."
            assert self.context['size'].dtype == tf.int32, (
                "dtype of graph.context['size'] needs to be int32."
            )
            assert 'feature' in self.node, "graph.node['feature'] is required."
            assert 'source' in self.edge, "graph.edge['source'] is required."
            assert 'target' in self.edge, "graph.edge['target'] is required."
            assert self.edge['source'].dtype == tf.int32, (
                "dtype of graph.edge['source'] needs to be int32."
            )
            assert self.edge['target'].dtype == tf.int32, (
                "dtype of graph.edge['target'] needs to be int32."
            )
            if isinstance(self.node['feature'], tf.Tensor):
                num_nodes = keras.ops.shape(self.node['feature'])[0]
            else:
                num_nodes = keras.ops.sum(self.node['feature'].row_lengths())
            assert keras.ops.sum(self.context['size']) == num_nodes, (
                "graph.node['feature'] tensor is incompatible with graph.context['size']."
            )

    @property
    def spec(self):
        def unspecify_spec(s):
            if isinstance(s, tf.TensorSpec):
                return tf.TensorSpec([None] + s.shape[1:], s.dtype)
            return s
        orig_spec = tf.type_spec_from_value(self)
        fields = dict(orig_spec.__dict__)
        context_fields = fields.pop('context')
        new_spec_components = tf.nest.map_structure(unspecify_spec, fields)
        new_spec_components['context'] = context_fields
        return orig_spec.__class__(**new_spec_components)

    @property
    def shape(self):
        if is_ragged(self):
            return self.node['feature'].shape
        return self.context['size'].shape + [None] + self.node['feature'].shape[1:]

    @property
    def dtype(self):
        return self.node['feature'].dtype

    @property
    def graph_indicator(self):
        dtype = self.context['size'].dtype
        if is_scalar(self):
            return tf.zeros(tf.shape(self.node['feature'])[:1], dtype=dtype)
        num_graphs = keras.ops.shape(self.context['size'])[0]
        if is_ragged(self):
            return keras.ops.arange(num_graphs, dtype=dtype)
        return keras.ops.repeat(keras.ops.arange(num_graphs, dtype=dtype), self.context['size'])

    @property
    def num_subgraphs(self) -> tf.Tensor:
        dtype = self.context['size'].dtype
        if is_scalar(self):
            num_subgraphs = tf.constant(1, dtype=dtype)
        else:
            num_subgraphs = tf.shape(self.context['size'], out_type=dtype)[0]
        if tf.executing_eagerly():
            return int(num_subgraphs)
        return num_subgraphs

    @property
    def num_nodes(self):
        num_nodes = keras.ops.shape(self.node['feature'])[0]
        if tf.executing_eagerly():
            return int(num_nodes)
        return num_nodes

    @property
    def num_edges(self):
        num_edges = keras.ops.shape(self.edge['source'])[0]
        if tf.executing_eagerly():
            return int(num_edges)
        return num_edges

    def gather(self, node_attr: str, edge_type: str) -> tf.Tensor:
        if edge_type != 'source' and edge_type != 'target':
            raise ValueError("`edge_type` needs to be 'source' or 'target'.")
        return ops.gather(self.node[node_attr], self.edge[edge_type])

    def aggregate(self, edge_attr: str, edge_type: str = 'target', mode: str = 'sum') -> tf.Tensor:
        if edge_type != 'source' and edge_type != 'target':
            raise ValueError("`edge_type` needs to be 'source' or 'target'.")
        edge_attr = self.edge[edge_attr]
        if 'weight' in self.edge:
            edge_attr = ops.edge_weight(edge_attr, self.edge['weight'])
        return ops.aggregate(edge_attr, self.edge[edge_type], self.num_nodes, mode=mode)

    def propagate(self, add_edge_feature: bool = False):
        updated_feature = ops.propagate(
            node_feature=self.node['feature'],
            edge_source=self.edge['source'],
            edge_target=self.edge['target'],
            edge_feature=self.edge.get('feature') if add_edge_feature else None,
            edge_weight=self.edge.get('weight'),
        )
        return self.update({'node': {'feature': updated_feature}})

    def flatten(self):
        if not is_ragged(self):
            raise ValueError(
                f"{self.__class__.__qualname__} instance is already flat.")

        def flatten_fn(x):
            if isinstance(x, tf.RaggedTensor):
                return x.flat_values
            return x

        edge_increment = ops.gather(
            self.node['feature'].row_starts(), self.edge['source'].value_rowids())
        edge_increment = keras.ops.cast(
            edge_increment, dtype=self.edge['source'].dtype)
        data = to_dict(self)
        flat_values = tf.nest.map_structure(flatten_fn, data)
        flat_values['edge']['source'] += edge_increment
        flat_values['edge']['target'] += edge_increment
        return from_dict(flat_values)

    def unflatten(self, *, force: bool = False):
        if is_scalar(self):
            raise ValueError(
                f"{self.__class__.__qualname__} instance is a scalar "
                "and cannot be unflattened.")
        if is_ragged(self):
            raise ValueError(
                f"{self.__class__.__qualname__} instance is already unflat.")

        def unflatten_fn(x, value_rowids, nrows) -> tf.RaggedTensor:
            if isinstance(x, tf.Tensor):
                return tf.RaggedTensor.from_value_rowids(x, value_rowids, nrows)
            return x

        graph_indicator_node = self.graph_indicator
        graph_indicator_edge = ops.gather(graph_indicator_node, self.edge['source'])
        if force:
            sorted_indices = keras.ops.argsort(graph_indicator_edge)
        num_subgraphs = self.num_subgraphs
        unflat_values = {}
        data = to_dict(self)
        for key in tf.type_spec_from_value(self).__dict__:
            value = data[key]
            if key == 'context':
                unflat_values[key] = value
            elif key == 'node':
                unflat_values[key] = tf.nest.map_structure(
                    lambda x: unflatten_fn(x, graph_indicator_node, num_subgraphs),
                    value)
                row_starts = unflat_values[key]['feature'].row_starts()
                edge_decrement = ops.gather(row_starts, graph_indicator_edge)
                if force:
                    edge_decrement = ops.gather(edge_decrement, sorted_indices)
                    graph_indicator_edge = ops.gather(graph_indicator_edge, sorted_indices)
                edge_decrement = keras.ops.cast(edge_decrement, dtype=self.edge['source'].dtype)
            elif key == 'edge':
                if force:
                    value = tf.nest.map_structure(lambda x: ops.gather(x, sorted_indices), value)
                value['source'] -= edge_decrement
                value['target'] -= edge_decrement
                unflat_values[key] = tf.nest.map_structure(
                    lambda x: unflatten_fn(x, graph_indicator_edge, num_subgraphs),
                    value)
        return from_dict(unflat_values)

    def update(self, values):
        data = to_dict(self)
        for outer_field, mapping in values.items():
            if outer_field == 'context':
                reference_value = data[outer_field]['size']
            elif outer_field == 'edge':
                reference_value = data[outer_field]['source']
            else:
                reference_value = data[outer_field]['feature']
            for inner_field, value in mapping.items():
                if value is None:
                    data[outer_field].pop(inner_field, None)
                    continue
                data[outer_field][inner_field] = _maybe_convert_new_value(
                    value, reference_value
                )
        return self.__class__(**data)

    def replace(self, values):
        data = to_dict(self)
        for outer_field, mapping in values.items():
            if outer_field == 'context':
                reference_value = data[outer_field]['size']
            elif outer_field == 'edge':
                reference_value = data[outer_field]['source']
            else:
                reference_value = data[outer_field]['feature']
            for inner_field, value in mapping.items():
                values[outer_field][inner_field] = _maybe_convert_new_value(
                    value, reference_value
                )
        return self.__class__(**values)

    def with_context(self, *args, **kwargs):
        context = (
            args[0] if len(args) and isinstance(args[0], dict) else kwargs
        )
        return self.update({'context': context})

    def __getitem__(self, index):
        if index is None and is_scalar(self):
            return self.__class__(
                context={key: value[None] for (key, value) in self.context.items()},
                node=self.node,
                edge=self.edge,
            )
        if not is_ragged(self):
            is_flat = True
            tensor = self.unflatten()
        else:
            tensor = self
            is_flat = False
        data = to_dict(tensor)
        if isinstance(index, (slice, int)):
            data = tf.nest.map_structure(lambda x: x[index], data)
        else:
            data = tf.nest.map_structure(lambda x: ops.gather(x, index), data)
        tensor = from_dict(data)
        if is_flat and not is_scalar(tensor):
            return tensor.flatten()
        return tensor

    def __repr__(self):
        return _repr(self)

    def numpy(self):
        """For now, added to work with `keras.Model.predict`."""
        return self

    class Spec:

        def __init__(
            self,
            context: typing.Mapping[str, tf.TensorSpec | tf.RaggedTensorSpec],
            node: typing.Mapping[str, tf.TensorSpec | tf.RaggedTensorSpec],
            edge: typing.Mapping[str, tf.TensorSpec | tf.RaggedTensorSpec],
        ) -> None:
            self.context = context
            self.node = node
            self.edge = edge

        @property
        def shape(self):
            if is_ragged(self):
                return self.node['feature'].shape
            return self.context['size'].shape + [None] + self.node['feature'].shape[1:]

        @classmethod
        def from_input_shape_dict(cls, input_shape: dict[str, tf.TensorShape]) -> 'GraphTensor.Spec':
            for key, value in input_shape.items():
                input_shape[key] = {k: tf.TensorShape(v) for k, v in value.items()}
            return cls(**tf.nest.map_structure(lambda s: tf.TensorSpec(s, dtype=tf.variant), input_shape))

        def __repr__(self):
            return _repr(self)

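# Editor's usage sketch (illustrative, not part of the package), grounded in
# the `__validate__` checks above: `context['size']` holds per-subgraph node
# counts, and `edge['source']`/`edge['target']` index into the node set.
#
#     graph = GraphTensor(
#         context={'size': tf.constant([2, 1], dtype=tf.int32)},
#         node={'feature': tf.random.normal([3, 8])},
#         edge={'source': tf.constant([0, 1, 2], dtype=tf.int32),
#               'target': tf.constant([1, 0, 2], dtype=tf.int32)},
#     )
#     messages = graph.gather('feature', 'source')          # [num_edges, 8]
#     graph = graph.update({'edge': {'feature': messages}})
#     pooled = graph.aggregate('feature', 'target', 'sum')  # [num_nodes, 8]
#     graph = graph.propagate()                             # one message-passing step
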

@tf.experimental.dispatch_for_api(tf.concat, {'values': typing.List[GraphTensor]})
def graph_tensor_concat(
    values: typing.List[GraphTensor],
    axis: int = 0,
    name: str = 'concat'
) -> GraphTensor:
    ragged = [is_ragged(v) for v in values]
    if 0 < sum(ragged) < len(ragged):
        raise ValueError(
            'Nested data of the GraphTensor instances do not have consistent '
            'types: found both tf.RaggedTensor values and tf.Tensor values.')
    else:
        ragged = ragged[0]

    if ragged:
        values = [v.flatten() for v in values]

    flat_values = [tf.nest.flatten(v, expand_composites=True) for v in values]
    flat_values = [tf.concat(f, axis=0) for f in list(zip(*flat_values))]
    num_edges = [keras.ops.shape(v.edge['source'])[0] for v in values]
    num_nodes = [keras.ops.shape(v.node['feature'])[0] for v in values]
    incr = tf.concat([[0], tf.cumsum(num_nodes)[:-1]], axis=0)
    incr = tf.repeat(incr, num_edges)
    value = tf.nest.pack_sequence_as(values[0], flat_values, expand_composites=True)

    edge_update = {
        'source': value.edge['source'] + incr,
        'target': value.edge['target'] + incr,
    }
    value = value.update({'edge': edge_update})
    if not ragged:
        return value
    return value.unflatten()

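# Editor's sketch (not part of the package): with the dispatcher above,
# `tf.concat` merges GraphTensors into a single disjoint graph, offsetting the
# edge indices of each graph by the number of nodes that precede it:
#
#     merged = tf.concat([graph_a, graph_b], axis=0)  # graph_a, graph_b: GraphTensors
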
443
+ # TODO: Clean this up.
444
+ @tf.experimental.dispatch_for_api(tf.stack, {'values': typing.List[GraphTensor]})
445
+ def graph_tensor_stack(
446
+ values: typing.List[GraphTensor],
447
+ axis: int = 0,
448
+ name: str = 'stack'
449
+ ) -> GraphTensor:
450
+ ragged = [is_ragged(v) for v in values]
451
+ if not is_scalar(values[0]):
452
+ raise ValueError(
453
+ 'tf.stack on a list of `GraphTensor`s is currently '
454
+ 'only supported for scalar `GraphTensor`s. '
455
+ )
456
+ if any(ragged):
457
+ raise ValueError(
458
+ 'tf.stack on a list of `GraphTensor`s is currently '
459
+ 'only supported for flattened `GraphTensor`s. '
460
+ )
461
+
462
+ def concat_or_stack(k, v):
463
+ if k.startswith('context'):
464
+ return tf.stack(v, axis=0)
465
+ return tf.concat(v, axis=0)
466
+
467
+ fields = tuple(tf.type_spec_from_value(values[0]).__dict__)
468
+ num_inner_fields = tuple(len(values[0].__dict__[field]) for field in fields)
469
+ outer_keys = []
470
+ for (f, num_fields) in zip(fields, num_inner_fields):
471
+ outer_keys.extend([f] * num_fields)
472
+
473
+ flat_values = [tf.nest.flatten(v, expand_composites=True) for v in values]
474
+ flat_values = [concat_or_stack(k, f) for k, f in zip(outer_keys, list(zip(*flat_values)))]
475
+ value = tf.nest.pack_sequence_as(values[0], flat_values, expand_composites=True)
476
+
477
+ num_edges = [keras.ops.shape(v.edge['source'])[0] for v in values]
478
+ num_nodes = [keras.ops.shape(v.node['feature'])[0] for v in values]
479
+ incr = tf.concat([[0], tf.cumsum(num_nodes)[:-1]], axis=0)
480
+ incr = tf.repeat(incr, num_edges)
481
+ edge_update = {
482
+ 'source': value.edge['source'] + incr,
483
+ 'target': value.edge['target'] + incr,
484
+ }
485
+ value = value.update(
486
+ {
487
+ 'edge': edge_update
488
+ },
489
+ )
490
+ return value
491
+
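# Editor's sketch (not part of the package): `tf.stack` lifts a list of scalar
# (single-graph, non-ragged) GraphTensors into one batched GraphTensor by
# stacking the context fields and concatenating the node and edge fields:
#
#     batched = tf.stack([g0, g1], axis=0)  # g0, g1: scalar GraphTensors
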

def is_scalar(value_or_spec: GraphTensor | GraphTensor.Spec) -> bool:
    return value_or_spec.context['size'].shape.rank == 0


def is_ragged(value_or_spec: GraphTensor | GraphTensor.Spec) -> bool:
    is_ragged = isinstance(
        value_or_spec.node['feature'], (tf.RaggedTensor, tf.RaggedTensorSpec))
    # For specs, only node features with ragged_rank == 1 count as ragged.
    if isinstance(value_or_spec.node['feature'], tf.RaggedTensorSpec):
        is_ragged = (
            is_ragged and value_or_spec.node['feature'].ragged_rank == 1)
    return is_ragged


def to_dict(tensor: GraphTensor) -> dict:
    spec = tf.type_spec_from_value(tensor)
    return {key: dict(tensor.__dict__[key]) for key in spec.__dict__}


def from_dict(data: dict) -> GraphTensor:
    data['context']['size'] = keras.ops.cast(data['context']['size'], tf.int32)
    data['edge']['source'] = keras.ops.cast(data['edge']['source'], tf.int32)
    data['edge']['target'] = keras.ops.cast(data['edge']['target'], tf.int32)
    return GraphTensor(**data)

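# Editor's sketch (not part of the package): `to_dict`/`from_dict` give a
# mutable round trip; `from_dict` re-casts the index fields to int32, matching
# the dtypes `__validate__` requires:
#
#     data = to_dict(graph)
#     data['node']['feature'] = tf.nn.relu(data['node']['feature'])
#     graph = from_dict(data)
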

def is_graph(data):
    if isinstance(data, GraphTensor):
        return True
    elif isinstance(data, dict) and 'size' in data.get('context', {}):
        return True
    return False


def _maybe_convert_new_value(
    new_value: tf.Tensor | tf.RaggedTensor,
    old_value: tf.Tensor | tf.RaggedTensor | None,
) -> tf.Tensor | tf.RaggedTensor:
    if old_value is None:
        return new_value
    is_old_ragged = isinstance(old_value, tf.RaggedTensor)
    is_new_ragged = isinstance(new_value, tf.RaggedTensor)
    if is_old_ragged and not is_new_ragged:
        new_value = old_value.with_flat_values(new_value)
    elif not is_old_ragged and is_new_ragged:
        new_value = new_value.flat_values
    return new_value


def _repr(x: GraphTensor | GraphTensor.Spec):
    if isinstance(x, GraphTensor):
        def _trepr(v: tf.Tensor | tf.RaggedTensor):
            if isinstance(v, tf.Tensor):
                return f'<tf.Tensor: shape={v.shape.as_list()}, dtype={v.dtype.name}>'
            return (
                f'<tf.RaggedTensor: shape={v.shape.as_list()}, '
                f'dtype={v.dtype.name}, ragged_rank={v.ragged_rank}>'
            )
    else:
        def _trepr(v: tf.TensorSpec | tf.RaggedTensorSpec):
            if isinstance(v, tf.TensorSpec):
                return f'<tf.TensorSpec: shape={v.shape.as_list()}, dtype={v.dtype.name}>'
            return (
                f'<tf.RaggedTensorSpec: shape={v.shape.as_list()}, '
                f'dtype={v.dtype.name}, ragged_rank={v.ragged_rank}>'
            )

    context_fields = ',\n    '.join([f'{k!r}: {_trepr(v)}' for k, v in x.context.items()])
    node_fields = ',\n    '.join([f'{k!r}: {_trepr(v)}' for k, v in x.node.items()])
    edge_fields = ',\n    '.join([f'{k!r}: {_trepr(v)}' for k, v in x.edge.items()])

    context_field = 'context={\n    ' + context_fields + '\n  }'
    node_field = 'node={\n    ' + node_fields + '\n  }'
    edge_field = 'edge={\n    ' + edge_fields + '\n  }'

    fields = ',\n  '.join([context_field, node_field, edge_field])
    return x.__class__.__name__ + '(\n  ' + fields + '\n)'