keras-rs-nightly 0.0.1.dev2025021903__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. keras_rs/__init__.py +9 -28
  2. keras_rs/layers/__init__.py +37 -0
  3. keras_rs/losses/__init__.py +19 -0
  4. keras_rs/metrics/__init__.py +16 -0
  5. keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
  6. keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
  7. keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
  8. keras_rs/src/layers/embedding/embed_reduce.py +309 -0
  9. keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  10. keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
  11. keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
  12. keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
  13. keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
  14. keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
  15. keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  16. keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
  17. keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
  18. keras_rs/src/layers/feature_interaction/__init__.py +0 -0
  19. keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
  20. keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
  21. keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
  22. keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
  23. keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
  24. keras_rs/src/layers/retrieval/retrieval.py +127 -0
  25. keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
  26. keras_rs/src/losses/__init__.py +0 -0
  27. keras_rs/src/losses/list_mle_loss.py +212 -0
  28. keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
  29. keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
  30. keras_rs/src/losses/pairwise_loss.py +165 -0
  31. keras_rs/src/losses/pairwise_loss_utils.py +39 -0
  32. keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
  33. keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
  34. keras_rs/src/metrics/__init__.py +0 -0
  35. keras_rs/src/metrics/dcg.py +161 -0
  36. keras_rs/src/metrics/mean_average_precision.py +130 -0
  37. keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
  38. keras_rs/src/metrics/ndcg.py +197 -0
  39. keras_rs/src/metrics/precision_at_k.py +117 -0
  40. keras_rs/src/metrics/ranking_metric.py +260 -0
  41. keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
  42. keras_rs/src/metrics/recall_at_k.py +108 -0
  43. keras_rs/src/metrics/utils.py +70 -0
  44. keras_rs/src/types.py +43 -14
  45. keras_rs/src/utils/doc_string_utils.py +53 -0
  46. keras_rs/src/utils/keras_utils.py +52 -3
  47. keras_rs/src/utils/tpu_test_utils.py +120 -0
  48. keras_rs/src/version.py +1 -1
  49. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
  50. keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
  51. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
  52. keras_rs/api/__init__.py +0 -9
  53. keras_rs/api/layers/__init__.py +0 -11
  54. keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
  55. /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
  56. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1151 @@
1
+ import collections
2
+ import dataclasses
3
+ import importlib.util
4
+ import typing
5
+ from typing import Any, Sequence
6
+
7
+ import keras
8
+ import numpy as np
9
+ from keras.src import backend
10
+
11
+ from keras_rs.src import types
12
+ from keras_rs.src.layers.embedding import distributed_embedding_config
13
+ from keras_rs.src.layers.embedding import embed_reduce
14
+ from keras_rs.src.utils import keras_utils
15
+
16
+ FeatureConfig = distributed_embedding_config.FeatureConfig
17
+ TableConfig = distributed_embedding_config.TableConfig
18
+ EmbedReduce = embed_reduce.EmbedReduce
19
+
20
+
21
+ SUPPORTED_PLACEMENTS = ("auto", "default_device", "sparsecore")
22
+
23
+
24
+ @dataclasses.dataclass(eq=True, unsafe_hash=True, order=True)
25
+ class PlacementAndPath:
26
+ placement: str
27
+ path: str
28
+
29
+
30
+ def _ragged_to_dense_inputs(
31
+ inputs: Any, weights: Any | None = None, dense_row_length: int | None = None
32
+ ) -> Any:
33
+ """Converts a ragged set of inputs and weights to dense.
34
+
35
+ If inputs are ragged and weights are `None`, will create a dense set of
36
+ weights to mask out the new padded values.
37
+
38
+ If inputs are not ragged, returns the original `inputs` and `weights`
39
+ unmodified.
40
+
41
+ Args:
42
+ inputs: The inputs array.
43
+ weights: The optional weights array.
44
+ dense_row_length: The output dense row length. If None, uses the length
45
+ of the longest row of the input.
46
+
47
+ Returns:
48
+ Tuple of new (inputs, weights). If the input is a ragged array, returns
49
+ dense numpy arrays. Otherwise, returns the original input and weights.
50
+ """
51
+ x = inputs
52
+ w = weights
53
+ # tf.RaggedTensor or other .numpy()-able types.
54
+ if hasattr(x, "numpy") and callable(getattr(x, "numpy")):
55
+ x = x.numpy()
56
+
57
+ # Ragged numpy array to dense numpy array.
58
+ if isinstance(x, np.ndarray) and len(x) > 0 and x.dtype == np.ndarray:
59
+ # Maybe convert weights to numpy.
60
+ if (
61
+ w is not None
62
+ and hasattr(w, "numpy")
63
+ and callable(getattr(w, "numpy"))
64
+ ):
65
+ w = w.numpy()
66
+
67
+ if dense_row_length is None:
68
+ # Use length of longest row.
69
+ dense_row_length = max([len(row) for row in x])
70
+
71
+ output = np.zeros((len(x), dense_row_length), dtype=x[0].dtype)
72
+ for i, row in enumerate(x):
73
+ output[i, : len(row)] = row
74
+
75
+ output_weights = np.zeros((len(x), dense_row_length), dtype=np.float32)
76
+ if w is None:
77
+ for i, row in enumerate(x):
78
+ output_weights[i, : len(row)] = 1.0
79
+ else:
80
+ for i, row in enumerate(w):
81
+ output_weights[i, : len(row)] = row
82
+
83
+ return output, output_weights
84
+
85
+ # Convert symbolic ragged/sparse keras tensors to dense tensors.
86
+ if isinstance(x, keras.KerasTensor) and (x.ragged or x.sparse):
87
+ inputs = keras.ops.convert_to_tensor(x, ragged=False)
88
+ weights = keras.ops.convert_to_tensor(x, dtype="float32", ragged=False)
89
+
90
+ # If not a ragged array, return the original, unmodified.
91
+ return inputs, weights
92
+
93
+
94
+ class DistributedEmbedding(keras.layers.Layer):
95
+ """DistributedEmbedding, a layer for accelerated large embedding lookups.
96
+
97
+ ---
98
+
99
+ ## Note: `DistributedEmbedding` is in Preview.
100
+
101
+ ---
102
+
103
+ `DistributedEmbedding` is a layer optimized for TPU chips with SparseCore
104
+ and can dramatically improve the speed of embedding lookups and embedding
105
+ training. It works by combining multiple lookups into one invocation, and by
106
+ sharding the embedding tables across the available chips. Note that one will
107
+ only see performance benefits for embedding tables that are large enough to
108
+ require sharding because they don't fit on a single chip. More details
109
+ are provided in the "Placement" section below.
110
+
111
+ On other hardware (GPUs, CPUs, and TPUs without SparseCore),
112
+ `DistributedEmbedding` provides the same API without any specific
113
+ acceleration. No particular distribution scheme is applied besides the one
114
+ set via `keras.distribution.set_distribution`.
115
+
116
+ `DistributedEmbedding` embeds sequences of inputs and reduces them to a
117
+ single embedding by applying a configurable combiner function.
118
+
119
+ ### Configuration
120
+
121
+ #### Features and tables
122
+
123
+ A `DistributedEmbedding` embedding layer is configured via a set of
124
+ `keras_rs.layers.FeatureConfig` objects, which themselves refer to
125
+ `keras_rs.layers.TableConfig` objects.
126
+
127
+ - `TableConfig` defines an embedding table with parameters such as its
128
+ vocabulary size and embedding dimension, as well as a combiner for reduction
129
+ and an optimizer for training.
130
+ - `FeatureConfig` defines what input features the `DistributedEmbedding`
131
+ will handle and which embedding table to use. Note that multiple features
132
+ can use the same embedding table.
133
+
134
+ ```python
135
+ table1 = keras_rs.layers.TableConfig(
136
+ name="table1",
137
+ vocabulary_size=TABLE1_VOCABULARY_SIZE,
138
+ embedding_dim=TABLE1_EMBEDDING_SIZE,
139
+ placement="auto",
140
+ )
141
+ table2 = keras_rs.layers.TableConfig(
142
+ name="table2",
143
+ vocabulary_size=TABLE2_VOCABULARY_SIZE,
144
+ embedding_dim=TABLE2_EMBEDDING_SIZE,
145
+ placement="auto",
146
+ )
147
+
148
+ feature1 = keras_rs.layers.FeatureConfig(
149
+ name="feature1",
150
+ table=table1,
151
+ input_shape=(GLOBAL_BATCH_SIZE,),
152
+ output_shape=(GLOBAL_BATCH_SIZE, TABLE1_EMBEDDING_SIZE),
153
+ )
154
+ feature2 = keras_rs.layers.FeatureConfig(
155
+ name="feature2",
156
+ table=table2,
157
+ input_shape=(GLOBAL_BATCH_SIZE,),
158
+ output_shape=(GLOBAL_BATCH_SIZE, TABLE2_EMBEDDING_SIZE),
159
+ )
160
+
161
+ feature_configs = {
162
+ "feature1": feature1,
163
+ "feature2": feature2,
164
+ }
165
+
166
+ embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
167
+ ```
168
+
169
+ #### Optimizers
170
+
171
+ Each embedding table within `DistributedEmbedding` uses its own optimizer
172
+ for training, which is independent from the optimizer set on the model via
173
+ `model.compile()`.
174
+
175
+ Note that not all optimizers are supported. Currently, the following are
176
+ supported on all backends and accelerators:
177
+
178
+ - `keras.optimizers.Adagrad`
179
+ - `keras.optimizers.Adam`
180
+ - `keras.optimizers.Ftrl`
181
+ - `keras.optimizers.SGD`
182
+
183
+ Also, not all parameters of the optimizers are supported (e.g. the
184
+ `nesterov` option of `SGD`). An error is raised when an unsupported
185
+ optimizer or an unsupported optimizer parameter is used.
186
+
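+ For illustration, a per-table optimizer can be set on the table
+ configuration. This is a minimal sketch, assuming `TableConfig` accepts an
+ `optimizer` argument; the sizes and learning rate are placeholders:
+
+ ```python
+ # Each table trains with its own optimizer, independent of `model.compile()`.
+ table = keras_rs.layers.TableConfig(
+     name="table",
+     vocabulary_size=VOCABULARY_SIZE,
+     embedding_dim=EMBEDDING_DIM,
+     optimizer=keras.optimizers.Adagrad(learning_rate=0.1),
+ )
+ ```
+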
187
+ #### Placement
188
+
189
+ Each embedding table within `DistributedEmbedding` can be either placed on
190
+ the SparseCore chip or the default device placement for the accelerator
191
+ (e.g. HBM of the Tensor Cores on TPU). This is controlled by the `placement`
192
+ attribute of `keras_rs.layers.TableConfig`.
193
+
194
+ - A placement of `"sparsecore"` indicates that the table should be placed on
195
+ the SparseCore chips. An error is raised if this option is selected and
196
+ there are no SparseCore chips.
197
+ - A placement of `"default_device"` indicates that the table should not be
198
+ placed on SparseCore, even if available. Instead the table is placed on
199
+ the device where the model normally goes, i.e. the HBM on TPUs and GPUs.
200
+ In this case, if applicable, the table is distributed using the scheme set
201
+ via `keras.distribution.set_distribution`. On GPUs, CPUs and TPUs without
202
+ SparseCore, this is the only placement available, and is the one selected
203
+ by `"auto"`.
204
+ - A placement of `"auto"` indicates to use `"sparsecore"` if available, and
205
+ `"default_device"` otherwise. This is the default when not specified.
206
+
207
+ To optimize performance on TPU:
208
+
209
+ - Tables that are so large that they need to be sharded should use the
210
+ `"sparsecore"` placement.
211
+ - Tables that are small enough should use `"default_device"` and should
212
+ typically be replicated across TPUs by using the
213
+ `keras.distribution.DataParallel` distribution option, as sketched below.
214
+
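+ As a sketch of these recommendations (sizes are hypothetical):
+
+ ```python
+ # Large table: shard across SparseCore chips when available.
+ large_table = keras_rs.layers.TableConfig(
+     name="large_table",
+     vocabulary_size=50_000_000,
+     embedding_dim=128,
+     placement="sparsecore",
+ )
+ # Small table: keep with the rest of the model and replicate it across TPUs,
+ # e.g. via `keras.distribution.DataParallel`.
+ small_table = keras_rs.layers.TableConfig(
+     name="small_table",
+     vocabulary_size=10_000,
+     embedding_dim=16,
+     placement="default_device",
+ )
+ ```
+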
215
+ ### Usage with TensorFlow on TPU with SparseCore
216
+
217
+ #### Inputs
218
+
219
+ In addition to `tf.Tensor`, `DistributedEmbedding` accepts `tf.RaggedTensor`
220
+ and `tf.SparseTensor` as inputs for the embedding lookups. Ragged tensors
221
+ must be ragged in the dimension with index 1. Note that if weights are
222
+ passed, each weight tensor must be of the same class as the inputs for that
223
+ particular feature and use the exact same ragged row lengths for ragged
224
+ tensors, and the same indices for sparse tensors. All outputs of
225
+ `DistributedEmbedding` are dense tensors.
226
+
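+ For example, a ragged feature and matching weights can be constructed as
+ follows. This is a sketch that assumes a layer configured with a single
+ feature named `"feature1"`; the values are hypothetical:
+
+ ```python
+ # Two samples with 2 and 1 ids; the weights reuse the same row lengths.
+ inputs = {"feature1": tf.ragged.constant([[3, 7], [5]])}
+ weights = {"feature1": tf.ragged.constant([[1.0, 0.5], [1.0]])}
+ outputs = embedding(inputs, weights)  # Outputs are dense tensors.
+ ```
+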
227
+ #### Setup
228
+
229
+ To use `DistributedEmbedding` on TPUs with TensorFlow, one must use a
230
+ `tf.distribute.TPUStrategy`. The `DistributedEmbedding` layer must be
231
+ created under the `TPUStrategy`.
232
+
233
+ ```python
234
+ resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
235
+ topology = tf.tpu.experimental.initialize_tpu_system(resolver)
236
+ device_assignment = tf.tpu.experimental.DeviceAssignment.build(
237
+ topology, num_replicas=resolver.get_tpu_system_metadata().num_cores
238
+ )
239
+ strategy = tf.distribute.TPUStrategy(
240
+ resolver, experimental_device_assignment=device_assignment
241
+ )
242
+
243
+ with strategy.scope():
244
+ embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
245
+ ```
246
+
247
+ #### Usage in a Keras model
248
+
249
+ To use Keras' `model.fit()`, one must compile the model under the
250
+ `TPUStrategy`. Then, `model.fit()`, `model.evaluate()` or `model.predict()`
251
+ can be called directly. The Keras model takes care of running the model
252
+ using the strategy and also automatically distributes the dataset.
253
+
254
+ ```python
255
+ with strategy.scope():
256
+ embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
257
+ model = create_model(embedding)
258
+ model.compile(loss=keras.losses.MeanSquaredError(), optimizer="adam")
259
+
260
+ model.fit(dataset, epochs=10)
261
+ ```
262
+
263
+ #### Direct invocation
264
+
265
+ `DistributedEmbedding` must be invoked via a `strategy.run` call nested in a
266
+ `tf.function`.
267
+
268
+ ```python
269
+ @tf.function
270
+ def embedding_wrapper(tf_fn_inputs, tf_fn_weights=None):
271
+ def strategy_fn(st_fn_inputs, st_fn_weights):
272
+ return embedding(st_fn_inputs, st_fn_weights)
273
+
274
+ return strategy.run(strategy_fn, args=(tf_fn_inputs, tf_fn_weights))
275
+
276
+ embedding_wrapper(my_inputs, my_weights)
277
+ ```
278
+
279
+ When using a dataset, the dataset must be distributed. The iterator can then
280
+ be passed to the `tf.function` that uses `strategy.run`.
281
+
282
+ ```python
283
+ dataset = strategy.experimental_distribute_dataset(dataset)
284
+
285
+ @tf.function
286
+ def run_loop(iterator):
287
+ def step(data):
288
+ (inputs, weights), labels = data
289
+ with tf.GradientTape() as tape:
290
+ result = embedding(inputs, weights)
291
+ loss = keras.losses.mean_squared_error(labels, result)
292
+ tape.gradient(loss, embedding.trainable_variables)
293
+ return result
294
+
295
+ for _ in tf.range(4):
296
+ result = strategy.run(step, args=(next(iterator),))
297
+
298
+ run_loop(iter(dataset))
299
+ ```
300
+
301
+ ### Usage with JAX on TPU with SparseCore
302
+
303
+ #### Setup
304
+
305
+ To use `DistributedEmbedding` on TPUs with JAX, one must create and set a
306
+ Keras `Distribution`.
307
+ ```python
308
+ distribution = keras.distribution.DataParallel(devices=jax.devices("tpu"))
309
+ keras.distribution.set_distribution(distribution)
310
+ ```
311
+
312
+ #### Inputs
313
+
314
+ For JAX, inputs can either be dense tensors, or ragged (nested) NumPy
315
+ arrays. To enable `jit_compile = True`, one must explicitly call
316
+ `layer.preprocess(...)` on the inputs, and then feed the preprocessed
317
+ output to the model. See the next section on preprocessing for details.
318
+
319
+ Ragged input arrays must be ragged in the dimension with index 1. Note that
320
+ if weights are passed, each weight tensor must be of the same class as the
321
+ inputs for that particular feature and use the exact same ragged row lengths
322
+ for ragged tensors. All outputs of `DistributedEmbedding` are dense
323
+ tensors.
324
+
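+ For example, a ragged NumPy feature and matching weights can be built as
+ follows. This is a sketch that assumes a layer configured with a single
+ feature named `"feature1"`; the values are hypothetical:
+
+ ```python
+ # Two samples with 2 and 1 ids; the weights reuse the same row lengths.
+ inputs = {
+     "feature1": np.array([np.array([3, 7]), np.array([5])], dtype=object)
+ }
+ weights = {
+     "feature1": np.array([np.array([1.0, 0.5]), np.array([1.0])], dtype=object)
+ }
+ preprocessed = embedding_layer.preprocess(inputs, weights, training=True)
+ ```
+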
325
+ #### Preprocessing
326
+
327
+ In JAX, SparseCore usage requires specially formatted data that depends
328
+ on properties of the available hardware. This data reformatting
329
+ currently does not support jit-compilation, so it must be applied _prior_
330
+ to passing data into a model.
331
+
332
+ Preprocessing works on dense or ragged NumPy arrays, or on tensors that are
333
+ convertible to dense or ragged NumPy arrays like `tf.RaggedTensor`.
334
+
335
+ One simple way to add preprocessing is to append the function to an input
336
+ pipeline by using a Python generator.
337
+ ```python
338
+ # Create the embedding layer.
339
+ embedding_layer = DistributedEmbedding(feature_configs)
340
+
341
+ # Add preprocessing to a data input pipeline.
342
+ def preprocessed_dataset_generator(dataset):
343
+ for (inputs, weights), labels in iter(dataset):
344
+ yield embedding_layer.preprocess(
345
+ inputs, weights, training=True
346
+ ), labels
347
+
348
+ preprocessed_train_dataset = preprocessed_dataset_generator(train_dataset)
349
+ ```
350
+ This explicit preprocessing stage combines the input and optional weights,
351
+ so the new data can be passed directly into the `inputs` argument of the
352
+ layer or model.
353
+
354
+ **NOTE**: When working in a multi-host setting with data parallelism, the
355
+ data needs to be sharded properly across hosts. If the original dataset is
356
+ of type `tf.data.Dataset`, it will need to be manually sharded _prior_ to
357
+ applying the preprocess generator:
358
+ ```python
359
+ # Manually shard the dataset across hosts.
360
+ train_dataset = distribution.distribute_dataset(train_dataset)
361
+ distribution.auto_shard_dataset = False # Dataset is already sharded.
362
+
363
+ # Add a preprocessing stage to the distributed data input pipeline.
364
+ train_dataset = preprocessed_dataset_generator(train_dataset)
365
+ ```
366
+ If the original dataset is _not_ a `tf.data.Dataset`, it must already be
367
+ pre-sharded across hosts.
368
+
369
+ #### Usage in a Keras model
370
+
371
+ Once the global distribution is set and the input preprocessing pipeline
372
+ is defined, model training can proceed as normal. For example:
373
+ ```python
374
+ # Construct, compile, and fit the model using the preprocessed data.
375
+ model = keras.Sequential(
376
+ [
377
+ embedding_layer,
378
+ keras.layers.Dense(2),
379
+ keras.layers.Dense(3),
380
+ keras.layers.Dense(4),
381
+ ]
382
+ )
383
+ model.compile(optimizer="adam", loss="mse", jit_compile=True)
384
+ model.fit(preprocessed_train_dataset, epochs=10)
385
+ ```
386
+
387
+ #### Direct invocation
388
+
389
+ The `DistributedEmbedding` layer can also be invoked directly. Explicit
390
+ preprocessing is required when used with JIT compilation.
391
+ ```python
392
+ # Call the layer directly.
393
+ activations = embedding_layer(my_inputs, my_weights)
394
+
395
+ # Call the layer with JIT compilation and explicitly preprocessed inputs.
396
+ embedding_layer_jit = jax.jit(embedding_layer)
397
+ preprocessed_inputs = embedding_layer.preprocess(my_inputs, my_weights)
398
+ activations = embedding_layer_jit(preprocessed_inputs)
399
+ ```
400
+
401
+ Similarly, for custom training loops, preprocessing must be applied prior
402
+ to passing the data to the JIT-compiled training step.
403
+ ```python
404
+ # Create an optimizer and loss function.
405
+ optimizer = keras.optimizers.Adam(learning_rate=1e-3)
406
+
407
+ def loss_and_updates(trainable_variables, non_trainable_variables, x, y):
408
+ y_pred, non_trainable_variables = model.stateless_call(
409
+ trainable_variables, non_trainable_variables, x, training=True
410
+ )
411
+ loss = keras.losses.mean_squared_error(y, y_pred)
412
+ return loss, non_trainable_variables
413
+
414
+ grad_fn = jax.value_and_grad(loss_and_updates, has_aux=True)
415
+
416
+ # Create a JIT-compiled training step.
417
+ @jax.jit
418
+ def train_step(state, x, y):
419
+ (
420
+ trainable_variables,
421
+ non_trainable_variables,
422
+ optimizer_variables,
423
+ ) = state
424
+ (loss, non_trainable_variables), grads = grad_fn(
425
+ trainable_variables, non_trainable_variables, x, y
426
+ )
427
+ trainable_variables, optimizer_variables = optimizer.stateless_apply(
428
+ optimizer_variables, grads, trainable_variables
429
+ )
430
+ return loss, (
431
+ trainable_variables,
432
+ non_trainable_variables,
433
+ optimizer_variables,
434
+ )
435
+
436
+ # Build optimizer variables.
437
+ optimizer.build(model.trainable_variables)
438
+
439
+ # Assemble the training state.
440
+ trainable_variables = model.trainable_variables
441
+ non_trainable_variables = model.non_trainable_variables
442
+ optimizer_variables = optimizer.variables
443
+ state = trainable_variables, non_trainable_variables, optimizer_variables
444
+
445
+ # Training loop.
446
+ for (inputs, weights), labels in train_dataset:
447
+ # Explicitly preprocess the data.
448
+ preprocessed_inputs = embedding_layer.preprocess(inputs, weights)
449
+ loss, state = train_step(state, preprocessed_inputs, labels)
450
+ ```
451
+
452
+ Args:
453
+ feature_configs: A nested structure of `keras_rs.layers.FeatureConfig`.
454
+ table_stacking: The table stacking to use. `None` means no table
455
+ stacking. `"auto"` means to stack tables automatically. A list of
456
+ table names or list of lists of table names means to stack the
457
+ tables in the inner lists together. Note that table stacking is not
458
+ supported on older TPUs, in which case the default value of `"auto"`
459
+ will be interpreted as no table stacking.
460
+ **kwargs: Additional arguments to pass to the layer base class.
461
+ """
462
+
463
+ def __init__(
464
+ self,
465
+ feature_configs: types.Nested[FeatureConfig],
466
+ *,
467
+ table_stacking: (
468
+ str | Sequence[str] | Sequence[Sequence[str]]
469
+ ) = "auto",
470
+ **kwargs: Any,
471
+ ) -> None:
472
+ super().__init__(**kwargs)
473
+
474
+ self._init_feature_configs_structures(feature_configs)
475
+
476
+ # Initialize for features placed on "sparsecore".
477
+ if "sparsecore" in self._placement_to_path_to_feature_config:
478
+ self._sparsecore_init(
479
+ self._placement_to_path_to_feature_config["sparsecore"],
480
+ table_stacking,
481
+ )
482
+ # Initialize for features placed on "default_device".
483
+ if "default_device" in self._placement_to_path_to_feature_config:
484
+ self._default_device_init(
485
+ self._placement_to_path_to_feature_config["default_device"],
486
+ table_stacking,
487
+ )
488
+
489
+ @keras_utils.no_automatic_dependency_tracking
490
+ def _init_feature_configs_structures(
491
+ self,
492
+ feature_configs: types.Nested[FeatureConfig],
493
+ ) -> None:
494
+ """Initializations for efficiently transforming nested structures.
495
+
496
+ This layer handles arbitrarily nested structures for input features, and
497
+ therefore for outputs and feature configs. However, as an intermediary
498
+ format we use a two-level representation with nested dicts. The top
499
+ level dict is keyed by placement and the inner dict is keyed by path,
500
+ with the path representing the path in the original deeply nested
501
+ structure. Thanks to this intermediate representation, we can:
502
+ - dispatch the inputs by placement to overridden methods
503
+ - have backend specific implementations support only one level of
504
+ nesting.
505
+
506
+ This method is responsible for creating structures that allow this
507
+ conversion to happen in a few lines of code and efficiently. The
508
+ following attributes are created:
509
+ - self._feature_configs: the deeply nested `FeatureConfig` instances as
510
+ provided by user in `__init__`
511
+ - self._feature_deeply_nested_placement_and_paths: `PlacementAndPath`
512
+ instances in the same deeply nested structure as
513
+ `self._feature_configs`. Needed for `build` because flatten cannot be
514
+ used as it would expand the shape tuples.
515
+ - self._placement_to_path_to_feature_config: `FeatureConfig` instances
516
+ in the same two-level representation keyed by placement and then path.
517
+ Used to go from a flat representation to the intermediate
518
+ representation.
519
+
520
+ With these structures in place, the steps to:
521
+ - go from the deeply nested structure to the two-level structure are:
522
+ - `assert_same_structure` as `self._feature_configs`
523
+ - use `self._feature_deeply_nested_placement_and_paths` to map from
524
+ deeply nested to two-level
525
+ - go from the two-level structure to the deeply nested structure are:
526
+ - `assert_same_structure` as `self._placement_to_path_to_feature_config`
527
+ - use `self._feature_deeply_nested_placement_and_paths` to locate each
528
+ output in the two-level dicts
529
+
530
+ Args:
531
+ feature_configs: The deeply nested structure of `FeatureConfig` or
532
+ `tf.tpu.experimental.embedding.FeatureConfig` as provided by the
533
+ user.
534
+ """
535
+ # Needs to be assigned with `no_automatic_dependency_tracking` to not
536
+ # alter the data structure types.
537
+ self._feature_configs = feature_configs
538
+
539
+ placement_and_paths: list[PlacementAndPath] = []
540
+ paths_and_feature_configs = keras.tree.flatten_with_path(
541
+ self._feature_configs
542
+ )
543
+ self._placement_to_path_to_feature_config: dict[
544
+ str, dict[str, FeatureConfig]
545
+ ] = {}
546
+
547
+ # Lazily initialized.
548
+ has_sparsecore = None
549
+
550
+ for path, feature_config in paths_and_feature_configs:
551
+ if isinstance(feature_config, FeatureConfig):
552
+ placement = feature_config.table.placement
553
+ # Resolve "auto" to an actual placement.
554
+ if placement == "auto":
555
+ if has_sparsecore is None:
556
+ has_sparsecore = self._has_sparsecore()
557
+ placement = (
558
+ "sparsecore" if has_sparsecore else "default_device"
559
+ )
560
+ else:
561
+ # It's a `tf.tpu.experimental.embedding.FeatureConfig`.
562
+ placement = "sparsecore"
563
+
564
+ path = ".".join([str(e) for e in path])
565
+ if placement not in SUPPORTED_PLACEMENTS:
566
+ raise ValueError(
567
+ f"Feature '{path}' with name '{feature_config.name}' has "
568
+ f"unsupported placement '{placement}'."
569
+ )
570
+ placement_and_paths.append(PlacementAndPath(placement, path))
571
+ if placement not in self._placement_to_path_to_feature_config:
572
+ self._placement_to_path_to_feature_config[placement] = {}
573
+ self._placement_to_path_to_feature_config[placement][path] = (
574
+ feature_config
575
+ )
576
+
577
+ self._feature_deeply_nested_placement_and_paths = (
578
+ keras.tree.pack_sequence_as(
579
+ self._feature_configs, placement_and_paths
580
+ )
581
+ )
582
+
583
+ def build(self, input_shapes: types.Nested[types.Shape]) -> None:
584
+ if self.built:
585
+ return
586
+
587
+ self._verify_input_shapes(input_shapes)
588
+
589
+ # Go from deeply nested structure to placement -> path -> input shape.
590
+ placement_to_path_to_input_shape: collections.defaultdict[
591
+ str, dict[str, types.Shape]
592
+ ] = collections.defaultdict(dict)
593
+
594
+ def populate_placement_to_path_to_input_shape(
595
+ pp: PlacementAndPath, input_shape: types.Shape
596
+ ) -> None:
597
+ placement_to_path_to_input_shape[pp.placement][pp.path] = (
598
+ input_shape
599
+ )
600
+
601
+ keras.tree.map_structure_up_to(
602
+ self._feature_deeply_nested_placement_and_paths,
603
+ populate_placement_to_path_to_input_shape,
604
+ self._feature_deeply_nested_placement_and_paths,
605
+ input_shapes,
606
+ )
607
+
608
+ # Build for features placed on "sparsecore".
609
+ if "sparsecore" in placement_to_path_to_input_shape:
610
+ self._sparsecore_build(
611
+ placement_to_path_to_input_shape["sparsecore"]
612
+ )
613
+
614
+ # Build for features placed on "default_device".
615
+ if "default_device" in placement_to_path_to_input_shape:
616
+ self._default_device_build(
617
+ placement_to_path_to_input_shape["default_device"]
618
+ )
619
+
620
+ super().build(input_shapes)
621
+
622
+ def preprocess(
623
+ self,
624
+ inputs: types.Nested[types.Tensor],
625
+ weights: types.Nested[types.Tensor] | None = None,
626
+ training: bool = False,
627
+ ) -> types.Nested[types.Tensor]:
628
+ """Preprocesses and reformats the data for consumption by the model.
629
+
630
+ For the JAX backend, converts the input data to a hardware-dependent
631
+ format required for use with SparseCores. Calling `preprocess`
632
+ explicitly is only necessary to enable `jit_compile = True`.
633
+
634
+ For non-JAX backends, preprocessing will bundle together the inputs and
635
+ weights, and separate the inputs by device placement. This step is
636
+ entirely optional.
637
+
638
+ Args:
639
+ inputs: Ragged or dense set of sample IDs.
640
+ weights: Optional ragged or dense set of sample weights.
641
+ training: If true, will update internal parameters, such as
642
+ required buffer sizes for the preprocessed data.
643
+
644
+ Returns:
645
+ Set of preprocessed inputs that can be fed directly into the
646
+ `inputs` argument of the layer.
647
+ """
648
+ # Verify input structure.
649
+ keras.tree.assert_same_structure(self._feature_configs, inputs)
650
+ if weights is not None:
651
+ keras.tree.assert_same_structure(self._feature_configs, weights)
652
+
653
+ if not self.built:
654
+ input_shapes = keras.tree.map_structure(
655
+ lambda array: backend.standardize_shape(array.shape),
656
+ inputs,
657
+ )
658
+ self.build(input_shapes)
659
+
660
+ # Go from deeply nested to nested dict placement -> path -> input.
661
+ def to_placement_to_path(
662
+ tensors: types.Nested[types.Tensor],
663
+ ) -> dict[str, dict[str, types.Tensor]]:
664
+ result: dict[str, dict[str, types.Tensor]] = {
665
+ p: dict() for p in self._placement_to_path_to_feature_config
666
+ }
667
+
668
+ def populate(pp: PlacementAndPath, x: types.Tensor) -> None:
669
+ result[pp.placement][pp.path] = x
670
+
671
+ keras.tree.map_structure(
672
+ populate,
673
+ self._feature_deeply_nested_placement_and_paths,
674
+ tensors,
675
+ )
676
+ return result
677
+
678
+ placement_to_path_to_inputs = to_placement_to_path(inputs)
679
+
680
+ # Same for weights if present.
681
+ placement_to_path_to_weights = (
682
+ to_placement_to_path(weights) if weights is not None else None
683
+ )
684
+
685
+ placement_to_path_to_preprocessed: dict[
686
+ str, dict[str, dict[str, types.Nested[types.Tensor]]]
687
+ ] = {}
688
+
689
+ # Preprocess for features placed on "sparsecore".
690
+ if "sparsecore" in placement_to_path_to_inputs:
691
+ placement_to_path_to_preprocessed["sparsecore"] = (
692
+ self._sparsecore_preprocess(
693
+ placement_to_path_to_inputs["sparsecore"],
694
+ placement_to_path_to_weights["sparsecore"]
695
+ if placement_to_path_to_weights is not None
696
+ else None,
697
+ training,
698
+ )
699
+ )
700
+
701
+ # Preprocess for features placed on "default_device".
702
+ if "default_device" in placement_to_path_to_inputs:
703
+ placement_to_path_to_preprocessed["default_device"] = (
704
+ self._default_device_preprocess(
705
+ placement_to_path_to_inputs["default_device"],
706
+ placement_to_path_to_weights["default_device"]
707
+ if placement_to_path_to_weights is not None
708
+ else None,
709
+ training,
710
+ )
711
+ )
712
+
713
+ # Mark inputs as preprocessed using an extra level of nesting.
714
+ # This is necessary to detect whether inputs are already preprocessed
715
+ # in `call`.
716
+ output = {
717
+ "preprocessed_inputs_per_placement": (
718
+ placement_to_path_to_preprocessed
719
+ )
720
+ }
721
+ return output
722
+
723
+ def _is_preprocessed(
724
+ self, inputs: types.Nested[types.Tensor | types.Shape]
725
+ ) -> bool:
726
+ """Checks if the input is already preprocessed."""
727
+ return (
728
+ isinstance(inputs, dict)
729
+ and "preprocessed_inputs_per_placement" in inputs
730
+ )
731
+
732
+ def call(
733
+ self,
734
+ inputs: types.Nested[types.Tensor],
735
+ weights: types.Nested[types.Tensor] | None = None,
736
+ training: bool = False,
737
+ ) -> types.Nested[types.Tensor]:
738
+ """Lookup features in embedding tables and apply reduction.
739
+
740
+ Args:
741
+ inputs: A nested structure of 2D tensors to embed and reduce. The
742
+ structure must be the same as the `feature_configs` passed
743
+ during construction. Alternatively, may consist of already
744
+ preprocessed inputs (see `preprocess`).
745
+ weights: An optional nested structure of 2D tensors of weights to
746
+ apply before reduction. When present, the structure must be the
747
+ same as `inputs` and the shapes must match.
748
+ training: Whether we are training or evaluating the model.
749
+
750
+ Returns:
751
+ A nested structure of dense 2D tensors, which are the reduced
752
+ embeddings from the passed features. The structure is the same as
753
+ `inputs`.
754
+ """
755
+ preprocessed_inputs = inputs
756
+ # Preprocess if not already done.
757
+ if not self._is_preprocessed(inputs):
758
+ preprocessed_inputs = self.preprocess(inputs, weights, training)
759
+
760
+ preprocessed_inputs = typing.cast(
761
+ dict[str, dict[str, dict[str, types.Tensor]]], preprocessed_inputs
762
+ )
763
+ # Placement -> path -> preprocessed inputs.
764
+ preprocessed_inputs = preprocessed_inputs[
765
+ "preprocessed_inputs_per_placement"
766
+ ]
767
+
768
+ placement_to_path_to_outputs = {}
769
+
770
+ # Call for features placed on "sparsecore".
771
+ if "sparsecore" in preprocessed_inputs:
772
+ inputs_and_weights = preprocessed_inputs["sparsecore"]
773
+ placement_to_path_to_outputs["sparsecore"] = self._sparsecore_call(
774
+ **inputs_and_weights,
775
+ training=training,
776
+ )
777
+
778
+ # Call for features placed on "default_device".
779
+ if "default_device" in preprocessed_inputs:
780
+ inputs_and_weights = preprocessed_inputs["default_device"]
781
+ placement_to_path_to_outputs["default_device"] = (
782
+ self._default_device_call(
783
+ **inputs_and_weights,
784
+ training=training,
785
+ )
786
+ )
787
+
788
+ # Verify output structure.
789
+ keras.tree.assert_same_structure(
790
+ self._placement_to_path_to_feature_config,
791
+ placement_to_path_to_outputs,
792
+ )
793
+
794
+ # Go from placement -> path -> output to deeply nested structure.
795
+ def populate_output(pp: PlacementAndPath) -> types.Tensor:
796
+ return placement_to_path_to_outputs[pp.placement][pp.path]
797
+
798
+ return keras.tree.map_structure(
799
+ populate_output, self._feature_deeply_nested_placement_and_paths
800
+ )
801
+
802
+ def get_embedding_tables(self) -> dict[str, types.Tensor]:
803
+ """Return the content of the embedding tables by table name.
804
+
805
+ The tables are keyed by the name provided in each `TableConfig`. Note
806
+ that the returned tensors are not the actual embedding table variables
807
+ used internally by `DistributedEmbedding`.
808
+
809
+ Returns:
810
+ A dictionary of table name to tensor for the embedding tables.
811
+ """
812
+ tables = {}
813
+ if "sparsecore" in self._placement_to_path_to_feature_config:
814
+ tables.update(self._sparsecore_get_embedding_tables())
815
+ if "default_device" in self._placement_to_path_to_feature_config:
816
+ tables.update(self._default_device_get_embedding_tables())
817
+ return tables
818
+
819
+ def _default_device_init(
820
+ self,
821
+ feature_configs: dict[str, FeatureConfig],
822
+ table_stacking: str | Sequence[Sequence[str]],
823
+ ) -> None:
824
+ del table_stacking
825
+ table_config_id_to_embedding_layer: dict[int, EmbedReduce] = {}
826
+ self._default_device_embedding_layers: dict[str, EmbedReduce] = {}
827
+
828
+ for path, feature_config in feature_configs.items():
829
+ if id(feature_config.table) in table_config_id_to_embedding_layer:
830
+ self._default_device_embedding_layers[path] = (
831
+ table_config_id_to_embedding_layer[id(feature_config.table)]
832
+ )
833
+ else:
834
+ embedding_layer = EmbedReduce(
835
+ name=feature_config.table.name,
836
+ input_dim=feature_config.table.vocabulary_size,
837
+ output_dim=feature_config.table.embedding_dim,
838
+ embeddings_initializer=feature_config.table.initializer,
839
+ combiner=feature_config.table.combiner,
840
+ )
841
+ table_config_id_to_embedding_layer[id(feature_config.table)] = (
842
+ embedding_layer
843
+ )
844
+ self._default_device_embedding_layers[path] = embedding_layer
845
+
846
+ def _default_device_build(
847
+ self, input_shapes: dict[str, types.Shape]
848
+ ) -> None:
849
+ for path, input_shape in input_shapes.items():
850
+ embedding_layer = self._default_device_embedding_layers[path]
851
+ if not embedding_layer.built:
852
+ embedding_layer.build(input_shape)
853
+
854
+ def _default_device_preprocess(
855
+ self,
856
+ inputs: dict[str, types.Tensor],
857
+ weights: dict[str, types.Tensor] | None,
858
+ training: bool = False,
859
+ ) -> dict[str, dict[str, types.Tensor]]:
860
+ del training
861
+
862
+ # NOTE: This JAX specialization is in the base layer so it is available
863
+ # on all platforms. The superclass jax.DistributedEmbedding layer
864
+ # is currently only imported in linux_x86_64.
865
+ if keras.backend.backend() == "jax":
866
+ feature_configs = self._placement_to_path_to_feature_config[
867
+ "default_device"
868
+ ]
869
+
870
+ # Potentially track new weights. For ragged inputs, if we
871
+ # densify, we will generate a dense weight tensor.
872
+ new_weights: dict[str, types.Tensor] = {}
873
+ use_weights = weights is not None
874
+
875
+ # Convert any ragged inputs to dense.
876
+ for path, config in feature_configs.items():
877
+ feature_inputs = inputs[path]
878
+ feature_weights = weights[path] if weights is not None else None
879
+
880
+ feature_valence = (
881
+ None
882
+ if len(config.input_shape) <= 1
883
+ else config.input_shape[1]
884
+ )
885
+ feature_inputs, feature_weights = _ragged_to_dense_inputs(
886
+ feature_inputs, feature_weights, feature_valence
887
+ )
888
+ # Converting to dense may have introduced a weights array.
889
+ use_weights = use_weights or feature_weights is not None
890
+ inputs[path] = feature_inputs
891
+ new_weights[path] = feature_weights
892
+
893
+ if use_weights:
894
+ weights = new_weights
895
+
896
+ output: dict[str, types.Tensor] = {"inputs": inputs}
897
+ if weights is not None:
898
+ output["weights"] = weights
899
+
900
+ return output
901
+
902
+ def _default_device_call(
903
+ self,
904
+ inputs: dict[str, types.Tensor],
905
+ weights: dict[str, types.Tensor] | None = None,
906
+ training: bool = False,
907
+ ) -> dict[str, types.Tensor]:
908
+ del training # Unused by default.
909
+ if weights is None:
910
+ return {
911
+ path: self._default_device_embedding_layers[path](x)
912
+ for path, x in inputs.items()
913
+ }
914
+ else:
915
+ return {
916
+ path: self._default_device_embedding_layers[path](
917
+ x, weights[path]
918
+ )
919
+ for path, x in inputs.items()
920
+ }
921
+
922
+ def _default_device_get_embedding_tables(self) -> dict[str, types.Tensor]:
923
+ tables = {}
924
+ for path, feature_config in self._placement_to_path_to_feature_config[
925
+ "default_device"
926
+ ].items():
927
+ tables[feature_config.table.name] = (
928
+ self._default_device_embedding_layers[path].embeddings.value
929
+ )
930
+ return tables
931
+
932
+ def _has_sparsecore(self) -> bool:
933
+ # Explicitly check for SparseCore availability.
934
+ # We need this check here rather than in jax/distributed_embedding.py
935
+ # so that we can warn the user about missing dependencies.
936
+ if keras.backend.backend() == "jax":
937
+ # Check if SparseCores are available.
938
+ try:
939
+ import jax
940
+
941
+ tpu_devices = jax.devices("tpu")
942
+ except RuntimeError:
943
+ # No TPUs available.
944
+ return False
945
+
946
+ if len(tpu_devices) > 0:
947
+ device_kind = tpu_devices[0].device_kind
948
+ if device_kind in ["TPU v5", "TPU v6 lite"]:
949
+ return True
950
+
951
+ return False
952
+
953
+ def _sparsecore_init(
954
+ self,
955
+ feature_configs: dict[str, FeatureConfig],
956
+ table_stacking: str | Sequence[Sequence[str]],
957
+ ) -> None:
958
+ del feature_configs, table_stacking
959
+
960
+ if keras.backend.backend() == "jax":
961
+ jax_tpu_embedding_spec = importlib.util.find_spec(
962
+ "jax_tpu_embedding"
963
+ )
964
+ if jax_tpu_embedding_spec is None:
965
+ raise ImportError(
966
+ "Please install jax-tpu-embedding to use "
967
+ "DistributedEmbedding on sparsecore devices."
968
+ )
969
+
970
+ raise self._unsupported_placement_error("sparsecore")
971
+
972
+ def _sparsecore_build(self, input_shapes: dict[str, types.Shape]) -> None:
973
+ del input_shapes
974
+ raise self._unsupported_placement_error("sparsecore")
975
+
976
+ def _sparsecore_preprocess(
977
+ self,
978
+ inputs: dict[str, types.Tensor],
979
+ weights: dict[str, types.Tensor] | None,
980
+ training: bool = False,
981
+ ) -> dict[str, dict[str, types.Tensor]]:
982
+ del training
983
+ output: dict[str, types.Tensor] = {"inputs": inputs}
984
+ if weights is not None:
985
+ output["weights"] = weights
986
+
987
+ return output
988
+
989
+ def _sparsecore_call(
990
+ self,
991
+ inputs: dict[str, types.Tensor],
992
+ weights: dict[str, types.Tensor] | None = None,
993
+ training: bool = False,
994
+ ) -> dict[str, types.Tensor]:
995
+ del inputs, weights, training
996
+ raise self._unsupported_placement_error("sparsecore")
997
+
998
+ def _sparsecore_get_embedding_tables(self) -> dict[str, types.Tensor]:
999
+ raise self._unsupported_placement_error("sparsecore")
1000
+
1001
+ def compute_output_shape(
1002
+ self, input_shapes: types.Nested[types.Shape]
1003
+ ) -> types.Nested[types.Shape]:
1004
+ self._verify_input_shapes(input_shapes)
1005
+ output_shape: types.Nested[types.Shape] = keras.tree.map_structure(
1006
+ lambda fc: fc.output_shape, self._feature_configs
1007
+ )
1008
+ return output_shape
1009
+
1010
+ def get_config(self) -> dict[str, Any]:
1011
+ # Because the Keras serialization creates a tree of serialized objects,
1012
+ # it does not directly support sharing tables between feature configs.
1013
+ # We therefore serialize the tables config as a flat list and then refer
1014
+ # to them by index in each feature config.
1015
+
1016
+ # The serialized `TableConfig` objects.
1017
+ table_config_dicts: list[dict[str, Any]] = []
1018
+ # Mapping from `TableConfig` id to index in `table_config_dicts`.
1019
+ table_config_id_to_index: dict[int, int] = {}
1020
+
1021
+ def serialize_feature_config(
1022
+ feature_config: FeatureConfig,
1023
+ ) -> dict[str, Any]:
1024
+ # Note that for consistency with the contract of `get_config`, the
1025
+ # returned dict contains the serialized `TableConfig` in the "table"
1026
+ # key.
1027
+ feature_config_dict = feature_config.get_config()
1028
+
1029
+ if id(feature_config.table) not in table_config_id_to_index:
1030
+ # Save the serialized `TableConfig` the first time we see it and
1031
+ # remember its index.
1032
+ table_config_id_to_index[id(feature_config.table)] = len(
1033
+ table_config_dicts
1034
+ )
1035
+ table_config_dicts.append(feature_config_dict["table"])
1036
+
1037
+ # Replace the serialized `TableConfig` with its index.
1038
+ feature_config_dict["table"] = table_config_id_to_index[
1039
+ id(feature_config.table)
1040
+ ]
1041
+ return feature_config_dict
1042
+
1043
+ config: dict[str, Any] = super().get_config()
1044
+ config["feature_configs"] = keras.tree.map_structure(
1045
+ serialize_feature_config, self._feature_configs
1046
+ )
1047
+ config["tables"] = table_config_dicts
1048
+ if hasattr(self, "_table_stacking"):
1049
+ config["table_stacking"] = self._table_stacking
1050
+ return config
1051
+
1052
+ @classmethod
1053
+ def from_config(cls, config: dict[str, Any]) -> "DistributedEmbedding":
1054
+ config = config.copy()
1055
+ # We need to reconnect the `TableConfig`s to the `FeatureConfig`s.
1056
+
1057
+ # The serialized `TableConfig` objects.
1058
+ table_config_dicts: list[dict[str, Any]] = config.pop("tables")
1059
+ # The deserialized `TableConfig` objects at the same indices.
1060
+ table_configs: list[TableConfig | None] = [None] * len(
1061
+ table_config_dicts
1062
+ )
1063
+
1064
+ def deserialize_feature_config(
1065
+ feature_config_dict: dict[str, Any],
1066
+ ) -> FeatureConfig | None:
1067
+ # Look for a "name" attribute which is a string to detect a
1068
+ # `FeatureConfig` leaf node. If not, keep recursing.
1069
+ if "name" not in feature_config_dict or not isinstance(
1070
+ feature_config_dict["name"], str
1071
+ ):
1072
+ # Tell `traverse` to recurse.
1073
+ return None
1074
+
1075
+ table_index = feature_config_dict["table"]
1076
+ # Note that for consistency with the contract of `from_config`, the
1077
+ # passed dict must contain the serialized `TableConfig` in the
1078
+ # "table" key.
1079
+ feature_config_dict["table"] = table_config_dicts[table_index]
1080
+ feature_config = FeatureConfig.from_config(feature_config_dict)
1081
+ # But then dedupe `TableConfig`s.
1082
+ if table_configs[table_index] is None:
1083
+ # Remember each new `TableConfig` we see.
1084
+ table_configs[table_index] = feature_config.table
1085
+ else:
1086
+ # And swap duplicates for the original.
1087
+ feature_config.table = table_configs[table_index]
1088
+ return feature_config
1089
+
1090
+ # Because each `FeatureConfig` is serialized as a dict, we cannot use
1091
+ # `map_structure` as it would recurse in the config itself. We use
1092
+ # `traverse` instead with a function that detects leaf nodes.
1093
+ config["feature_configs"] = keras.tree.traverse(
1094
+ deserialize_feature_config, config["feature_configs"]
1095
+ )
1096
+ return cls(**config)
1097
+
1098
+ def _verify_input_shapes(
1099
+ self, input_shapes: types.Nested[types.Shape]
1100
+ ) -> None:
1101
+ """Verifies that the input shapes match the ones in the feature configs.
1102
+
1103
+ Args:
1104
+ input_shapes: The structure of input shapes to verify.
1105
+ """
1106
+ # Support preprocessing.
1107
+ if self._is_preprocessed(input_shapes):
1108
+ # Structure should be :
1109
+ # {
1110
+ # placement: {
1111
+ # inputs: {path: Any},
1112
+ # weights: {path: Any}
1113
+ # }
1114
+ # }
1115
+ #
1116
+ # But the `Any` values could be nested tensors with varying
1117
+ # structure, depending on hardware constraints. This complicates
1118
+ # checking shapes via keras.tree methods. So, assume the
1119
+ # input is a result of explicitly calling the `preprocess(...)`
1120
+ # function, in which case the structure has already been verified.
1121
+ return
1122
+
1123
+ def _verify_input_shape(
1124
+ feature_config: FeatureConfig,
1125
+ input_shape: types.Shape,
1126
+ ) -> None:
1127
+ if not isinstance(input_shape, (tuple, list)) or not all(
1128
+ isinstance(d, (int, type(None))) for d in input_shape
1129
+ ):
1130
+ raise ValueError(f"Received invalid input shape {input_shape}.")
1131
+ if len(input_shape) < 1:
1132
+ raise ValueError(
1133
+ f"Received input shape {input_shape}. Rank must be 1 or "
1134
+ "above."
1135
+ )
1136
+ keras_utils.check_shapes_compatible(
1137
+ feature_config.input_shape, input_shape
1138
+ )
1139
+
1140
+ keras.tree.map_structure_up_to(
1141
+ self._feature_configs,
1142
+ _verify_input_shape,
1143
+ self._feature_configs,
1144
+ input_shapes,
1145
+ )
1146
+
1147
+ def _unsupported_placement_error(self, placement: str) -> Exception:
1148
+ return NotImplementedError(
1149
+ f"Backend '{keras.backend.backend()}' does not support the "
1150
+ f"'{placement}' placement."
1151
+ )