keras-rs-nightly 0.0.1.dev2025050103__py3-none-any.whl → 0.2.2.dev202506100336__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (33)
  1. keras_rs/layers/__init__.py +12 -0
  2. keras_rs/src/layers/embedding/__init__.py +0 -0
  3. keras_rs/src/layers/embedding/base_distributed_embedding.py +1124 -0
  4. keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
  5. keras_rs/src/layers/embedding/distributed_embedding_config.py +129 -0
  6. keras_rs/src/layers/embedding/embed_reduce.py +309 -0
  7. keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  8. keras_rs/src/layers/embedding/jax/config_conversion.py +398 -0
  9. keras_rs/src/layers/embedding/jax/distributed_embedding.py +892 -0
  10. keras_rs/src/layers/embedding/jax/embedding_lookup.py +255 -0
  11. keras_rs/src/layers/embedding/jax/embedding_utils.py +596 -0
  12. keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  13. keras_rs/src/layers/embedding/tensorflow/config_conversion.py +323 -0
  14. keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +424 -0
  15. keras_rs/src/layers/feature_interaction/dot_interaction.py +2 -2
  16. keras_rs/src/layers/feature_interaction/feature_cross.py +14 -16
  17. keras_rs/src/layers/retrieval/brute_force_retrieval.py +5 -5
  18. keras_rs/src/layers/retrieval/retrieval.py +4 -4
  19. keras_rs/src/losses/pairwise_loss.py +2 -2
  20. keras_rs/src/losses/pairwise_mean_squared_error.py +1 -3
  21. keras_rs/src/metrics/dcg.py +2 -2
  22. keras_rs/src/metrics/ndcg.py +2 -2
  23. keras_rs/src/metrics/ranking_metric.py +4 -4
  24. keras_rs/src/metrics/ranking_metrics_utils.py +8 -8
  25. keras_rs/src/metrics/utils.py +2 -4
  26. keras_rs/src/types.py +43 -14
  27. keras_rs/src/utils/keras_utils.py +26 -6
  28. keras_rs/src/version.py +1 -1
  29. {keras_rs_nightly-0.0.1.dev2025050103.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/METADATA +6 -3
  30. keras_rs_nightly-0.2.2.dev202506100336.dist-info/RECORD +55 -0
  31. {keras_rs_nightly-0.0.1.dev2025050103.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/WHEEL +1 -1
  32. keras_rs_nightly-0.0.1.dev2025050103.dist-info/RECORD +0 -42
  33. {keras_rs_nightly-0.0.1.dev2025050103.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/top_level.txt +0 -0
keras_rs/src/layers/embedding/base_distributed_embedding.py
@@ -0,0 +1,1124 @@
1
+ import collections
2
+ import importlib.util
3
+ import typing
4
+ from typing import Any, Sequence
5
+
6
+ import keras
7
+ import numpy as np
8
+ from keras.src import backend
9
+
10
+ from keras_rs.src import types
11
+ from keras_rs.src.layers.embedding import distributed_embedding_config
12
+ from keras_rs.src.layers.embedding import embed_reduce
13
+ from keras_rs.src.utils import keras_utils
14
+
15
+ FeatureConfig = distributed_embedding_config.FeatureConfig
16
+ TableConfig = distributed_embedding_config.TableConfig
17
+ EmbedReduce = embed_reduce.EmbedReduce
18
+
19
+
20
+ SUPPORTED_PLACEMENTS = ("auto", "default_device", "sparsecore")
21
+
22
+
23
+ PlacementAndPath = collections.namedtuple(
24
+ "PlacementAndPath", ["placement", "path"]
25
+ )
26
+
27
+
28
+ def _ragged_to_dense_inputs(
29
+ inputs: Any, weights: Any | None = None, dense_row_length: int | None = None
30
+ ) -> Any:
31
+ """Converts a ragged set of inputs and weights to dense.
32
+
33
+ If inputs are ragged and weights are `None`, will create a dense set of
34
+ weights to mask out the new padded values.
35
+
36
+ If inputs are not ragged, returns the original `inputs` and `weights`
37
+ unmodified.
38
+
39
+ Args:
40
+ inputs: The inputs array.
41
+ weights: The optional weights array.
42
+ dense_row_length: The output dense row length. If None, uses the length
43
+ of the longest row of the input.
44
+
45
+ Returns:
46
+ Tuple of new (inputs, weights). If the input is a ragged array, returns
47
+ dense numpy arrays. Otherwise, returns the original input and weights.
48
+ """
49
+ x = inputs
50
+ w = weights
51
+ # tf.Ragged or other .numpy()-able types.
52
+ if hasattr(x, "numpy") and callable(getattr(x, "numpy")):
53
+ x = x.numpy()
54
+
55
+ # Ragged numpy array to dense numpy array.
56
+ if isinstance(x, np.ndarray) and len(x) > 0 and x.dtype == np.ndarray:
57
+ # Maybe convert weights to numpy.
58
+ if (
59
+ w is not None
60
+ and hasattr(w, "numpy")
61
+ and callable(getattr(w, "numpy"))
62
+ ):
63
+ w = w.numpy()
64
+
65
+ if dense_row_length is None:
66
+ # Use length of longest row.
67
+ dense_row_length = max([len(row) for row in x])
68
+
69
+ output = np.zeros((len(x), dense_row_length), dtype=x[0].dtype)
70
+ for i, row in enumerate(x):
71
+ output[i, : len(row)] = row
72
+
73
+ output_weights = np.zeros((len(x), dense_row_length), dtype=np.float32)
74
+ if w is None:
75
+ for i, row in enumerate(x):
76
+ output_weights[i, : len(row)] = 1.0
77
+ else:
78
+ for i, row in enumerate(w):
79
+ output_weights[i, : len(row)] = row
80
+
81
+ return output, output_weights
82
+
83
+ # Convert symbolic ragged/sparse keras tensors to dense tensors.
84
+ if isinstance(x, keras.KerasTensor) and (x.ragged or x.sparse):
85
+ inputs = keras.ops.convert_to_tensor(x, ragged=False)
86
+ weights = keras.ops.convert_to_tensor(x, dtype="float32", ragged=False)
87
+
88
+ # If not a ragged array, return the original, unmodified.
89
+ return inputs, weights
90
+
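For illustration, here is a minimal sketch of what this helper produces for a ragged batch (the id values are invented; the expected results are shown as comments):

```python
import numpy as np

# Hypothetical ragged batch: an object array whose rows have different lengths.
ragged = np.array(
    [np.array([3, 1, 4], dtype=np.int64), np.array([1], dtype=np.int64)],
    dtype=object,
)

dense_ids, dense_weights = _ragged_to_dense_inputs(ragged)
# dense_ids     -> [[3, 1, 4], [1, 0, 0]]
# dense_weights -> [[1., 1., 1.], [1., 0., 0.]]  (generated mask for the padding)
```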
91
+
92
+ class DistributedEmbedding(keras.layers.Layer):
93
+ """DistributedEmbedding, a layer for accelerated large embedding lookups.
94
+
95
+ ---
96
+
97
+ ## Note: `DistributedEmbedding` is in Preview.
98
+
99
+ ---
100
+
101
+ `DistributedEmbedding` is a layer optimized for TPU chips with SparseCore
102
+ and can dramatically improve the speed of embedding lookups and embedding
103
+ training. It works by combining multiple lookups into one invocation, and by
104
+ sharding the embedding tables across the available chips. Note that one will
105
+ only see performance benefits for embedding tables that are large enough
106
+ to require sharding because they don't fit on a single chip. More details
107
+ are provided in the "Placement" section below.
108
+
109
+ On other hardware, GPUs, CPUs and TPUs without SparseCore,
110
+ `DistributedEmbedding` provides the same API without any specific
111
+ acceleration. No particular distribution scheme is applied besides the one
112
+ set via `keras.distribution.set_distribution`.
113
+
114
+ `DistributedEmbedding` embeds sequences of inputs and reduces them to a
115
+ single embedding by applying a configurable combiner function.
116
+
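As a small illustration of "embed and reduce" (a sketch with invented values; the real layer uses its own embedding tables and the combiner configured per table):

```python
import numpy as np

# One input row contains the ids [2, 5]. With a "mean" combiner, the layer
# looks up one embedding vector per id and averages them into a single vector.
embedding_table = np.array(
    [[0.0, 0.0], [0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8], [0.9, 1.0]]
)
ids = [2, 5]
reduced = embedding_table[ids].mean(axis=0)  # -> [0.6, 0.7]
```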
117
+ ### Configuration
118
+
119
+ #### Features and tables
120
+
121
+ A `DistributedEmbedding` embedding layer is configured via a set of
122
+ `keras_rs.layers.FeatureConfig` objects, which themselves refer to
123
+ `keras_rs.layers.TableConfig` objects.
124
+
125
+ - `TableConfig` defines an embedding table with parameters such as its
126
+ vocabulary size, embedding dimension, as well as a combiner for reduction
127
+ and optimizer for training.
128
+ - `FeatureConfig` defines what input features the `DistributedEmbedding`
129
+ will handle and which embedding table to use. Note that multiple features
130
+ can use the same embedding table.
131
+
132
+ ```python
133
+ table1 = keras_rs.layers.TableConfig(
134
+ name="table1",
135
+ vocabulary_size=TABLE1_VOCABULARY_SIZE,
136
+ embedding_dim=TABLE1_EMBEDDING_SIZE,
137
+ placement="auto",
138
+ )
139
+ table2 = keras_rs.layers.TableConfig(
140
+ name="table2",
141
+ vocabulary_size=TABLE2_VOCABULARY_SIZE,
142
+ embedding_dim=TABLE2_EMBEDDING_SIZE,
143
+ placement="auto",
144
+ )
145
+
146
+ feature1 = keras_rs.layers.FeatureConfig(
147
+ name="feature1",
148
+ table=table1,
149
+ input_shape=(PER_REPLICA_BATCH_SIZE,),
150
+ output_shape=(PER_REPLICA_BATCH_SIZE, TABLE1_EMBEDDING_SIZE),
151
+ )
152
+ feature2 = keras_rs.layers.FeatureConfig(
153
+ name="feature2",
154
+ table=table2,
155
+ input_shape=(PER_REPLICA_BATCH_SIZE,),
156
+ output_shape=(PER_REPLICA_BATCH_SIZE, TABLE2_EMBEDDING_SIZE),
157
+ )
158
+
159
+ feature_configs = {
160
+ "feature1": feature1,
161
+ "feature2": feature2,
162
+ }
163
+
164
+ embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
165
+ ```
166
+
167
+ #### Optimizers
168
+
169
+ Each embedding table within `DistributedEmbedding` uses its own optimizer
170
+ for training, which is independent from the optimizer set on the model via
171
+ `model.compile()`.
172
+
173
+ Note that not all optimizers are supported. Currently, the following are
174
+ supported on all backends and accelerators:
175
+
176
+ - `keras.optimizers.Adagrad`
177
+ - `keras.optimizers.SGD`
178
+
179
+ The following are additionally available when using the TensorFlow backend:
180
+
181
+ - `keras.optimizers.Adam`
182
+ - `keras.optimizers.Ftrl`
183
+
184
+ Also, not all parameters of the optimizers are supported (e.g. the
185
+ `nesterov` option of `SGD`). An error is raised when an unsupported
186
+ optimizer or an unsupported optimizer parameter is used.
187
+
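For example, a per-table optimizer can be set when defining the table. This is a sketch that assumes `TableConfig` accepts an `optimizer` argument, consistent with the description above; the sizes are placeholders:

```python
import keras
import keras_rs

table1 = keras_rs.layers.TableConfig(
    name="table1",
    vocabulary_size=TABLE1_VOCABULARY_SIZE,
    embedding_dim=TABLE1_EMBEDDING_SIZE,
    # Optimizer used for this table's embedding variables, independent of
    # the optimizer passed to `model.compile()`.
    optimizer=keras.optimizers.Adagrad(learning_rate=0.1),
    placement="auto",
)
```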
188
+ #### Placement
189
+
190
+ Each embedding table within `DistributedEmbedding` can be either placed on
191
+ the SparseCore chip or the default device placement for the accelerator
192
+ (e.g. HBM of the Tensor Cores on TPU). This is controlled by the `placement`
193
+ attribute of `keras_rs.layers.TableConfig`.
194
+
195
+ - A placement of `"sparsecore"` indicates that the table should be placed on
196
+ the SparseCore chips. An error is raised if this option is selected and
197
+ there are no SparseCore chips.
198
+ - A placement of `"default_device"` indicates that the table should not be
199
+ placed on SparseCore, even if available. Instead the table is placed on
200
+ the device where the model normally goes, i.e. the HBM on TPUs and GPUs.
201
+ In this case, if applicable, the table is distributed using the scheme set
202
+ via `keras.distribution.set_distribution`. On GPUs, CPUs and TPUs without
203
+ SparseCore, this is the only placement available, and is the one selected
204
+ by `"auto"`.
205
+ - A placement of `"auto"` indicates to use `"sparsecore"` if available, and
206
+ `"default_device"` otherwise. This is the default when not specified.
207
+
208
+ To optimize performance on TPU:
209
+
210
+ - Tables that are so large that they need to be sharded should use the
211
+ `"sparsecore"` placement.
212
+ - Tables that are small enough should use `"default_device"` and should
213
+ typically be replicated across TPUs by using the
214
+ `keras.distribution.DataParallel` distribution option.
215
+
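For instance, a large table can be pinned to SparseCore while a small one stays on the default device (a sketch; the sizes are placeholders):

```python
import keras
import keras_rs

big_table = keras_rs.layers.TableConfig(
    name="big_table",
    vocabulary_size=50_000_000,  # large enough to require sharding
    embedding_dim=128,
    placement="sparsecore",
)
small_table = keras_rs.layers.TableConfig(
    name="small_table",
    vocabulary_size=10_000,  # small enough to replicate
    embedding_dim=16,
    placement="default_device",
)
# Replicate the "default_device" tables across accelerators.
keras.distribution.set_distribution(keras.distribution.DataParallel())
```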
216
+ ### Usage with TensorFlow on TPU with SparseCore
217
+
218
+ #### Inputs
219
+
220
+ In addition to `tf.Tensor`, `DistributedEmbedding` accepts `tf.RaggedTensor`
221
+ and `tf.SparseTensor` as inputs for the embedding lookups. Ragged tensors
222
+ must be ragged in the dimension with index 1. Note that if weights are
223
+ passed, each weight tensor must be of the same class as the inputs for that
224
+ particular feature and use the exact same ragged row lengths for ragged
225
+ tensors, and the same indices for sparse tensors. All the outputs of
226
+ `DistributedEmbedding` are dense tensors.
227
+
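For example, ragged inputs and their matching ragged weights could be built as follows (a sketch; the ids and weights are invented, and the feature names match the configuration example above):

```python
import tensorflow as tf

# Ragged in dimension 1; the weights use the exact same row lengths.
inputs = {
    "feature1": tf.ragged.constant([[3, 1, 4], [1]], dtype=tf.int32),
    "feature2": tf.ragged.constant([[2], [7, 1]], dtype=tf.int32),
}
weights = {
    "feature1": tf.ragged.constant([[1.0, 0.5, 0.25], [1.0]]),
    "feature2": tf.ragged.constant([[1.0], [0.5, 0.5]]),
}
# These are then fed to the layer or model as shown in the sections below.
```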
228
+ #### Setup
229
+
230
+ To use `DistributedEmbedding` on TPUs with TensorFlow, one must use a
231
+ `tf.distribute.TPUStrategy`. The `DistributedEmbedding` layer must be
232
+ created under the `TPUStrategy`.
233
+
234
+ ```python
235
+ resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
236
+ topology = tf.tpu.experimental.initialize_tpu_system(resolver)
237
+ device_assignment = tf.tpu.experimental.DeviceAssignment.build(
238
+ topology, num_replicas=resolver.get_tpu_system_metadata().num_cores
239
+ )
240
+ strategy = tf.distribute.TPUStrategy(
241
+ resolver, experimental_device_assignment=device_assignment
242
+ )
243
+
244
+ with strategy.scope():
245
+ embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
246
+ ```
247
+
248
+ #### Usage in a Keras model
249
+
250
+ To use Keras' `model.fit()`, one must compile the model under the
251
+ `TPUStrategy`. Then, `model.fit()`, `model.evaluate()` or `model.predict()`
252
+ can be called directly. The Keras model takes care of running the model
253
+ using the strategy and also automatically distributes the dataset.
254
+
255
+ ```python
256
+ with strategy.scope():
257
+ embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
258
+ model = create_model(embedding)
259
+ model.compile(loss=keras.losses.MeanSquaredError(), optimizer="adam")
260
+
261
+ model.fit(dataset, epochs=10)
262
+ ```
263
+
264
+ #### Direct invocation
265
+
266
+ `DistributedEmbedding` must be invoked via a `strategy.run` call nested in a
267
+ `tf.function`.
268
+
269
+ ```python
270
+ @tf.function
271
+ def embedding_wrapper(tf_fn_inputs, tf_fn_weights=None):
272
+ def strategy_fn(st_fn_inputs, st_fn_weights):
273
+ return embedding(st_fn_inputs, st_fn_weights)
274
+
275
+ return strategy.run(strategy_fn, args=(tf_fn_inputs, tf_fn_weights))
276
+
277
+ embedding_wrapper(my_inputs, my_weights)
278
+ ```
279
+
280
+ When using a dataset, the dataset must be distributed. The iterator can then
281
+ be passed to the `tf.function` that uses `strategy.run`.
282
+
283
+ ```python
284
+ dataset = strategy.experimental_distribute_dataset(dataset)
285
+
286
+ @tf.function
287
+ def run_loop(iterator):
288
+ def step(data):
289
+ (inputs, weights), labels = data
290
+ with tf.GradientTape() as tape:
291
+ result = embedding(inputs, weights)
292
+ loss = keras.losses.mean_squared_error(labels, result)
293
+ tape.gradient(loss, embedding.trainable_variables)
294
+ return result
295
+
296
+ for _ in tf.range(4):
297
+ result = strategy.run(step, args=(next(iterator),))
298
+
299
+ run_loop(iter(dataset))
300
+ ```
301
+
302
+ ### Usage with JAX on TPU with SparseCore
303
+
304
+ #### Setup
305
+
306
+ To use `DistributedEmbedding` on TPUs with JAX, one must create and set a
307
+ Keras `Distribution`.
308
+ ```python
309
+ distribution = keras.distribution.DataParallel(devices=jax.devices("tpu"))
310
+ keras.distribution.set_distribution(distribution)
311
+ ```
312
+
313
+ #### Inputs
314
+
315
+ For JAX, inputs can either be dense tensors, or ragged (nested) NumPy
316
+ arrays. To enable `jit_compile = True`, one must explicitly call
317
+ `layer.preprocess(...)` on the inputs, and then feed the preprocessed
318
+ output to the model. See the next section on preprocessing for details.
319
+
320
+ Ragged input arrays must be ragged in the dimension with index 1. Note that
321
+ if weights are passed, each weight tensor must be of the same class as the
322
+ inputs for that particular feature and use the exact same ragged row lengths
323
+ for ragged tensors. All the outputs of `DistributedEmbedding` are dense
324
+ tensors.
325
+
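A ragged (nested) NumPy input can be written as an object array of per-row arrays, for example (values invented for illustration):

```python
import numpy as np

# Ragged in dimension 1: each row carries its own number of ids.
ragged_ids = np.array(
    [np.array([3, 1, 4]), np.array([1]), np.array([5, 9])], dtype=object
)
# Optional weights must use the exact same row lengths as the ids.
ragged_weights = np.array(
    [np.array([1.0, 0.5, 0.25]), np.array([1.0]), np.array([0.5, 0.5])],
    dtype=object,
)
```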
326
+ #### Preprocessing
327
+
328
+ In JAX, SparseCore usage requires specially formatted data that depends
329
+ on properties of the available hardware. This data reformatting
330
+ currently does not support jit-compilation, so it must be applied _prior_
331
+ to passing data into a model.
332
+
333
+ Preprocessing works on dense or ragged NumPy arrays, or on tensors that are
334
+ convertible to dense or ragged NumPy arrays like `tf.RaggedTensor`.
335
+
336
+ One simple way to add preprocessing is to append the function to an input
337
+ pipeline by using a Python generator.
338
+ ```python
339
+ # Create the embedding layer.
340
+ embedding_layer = DistributedEmbedding(feature_configs)
341
+
342
+ # Add preprocessing to a data input pipeline.
343
+ def train_dataset_generator():
344
+ for (inputs, weights), labels in iter(train_dataset):
345
+ yield embedding_layer.preprocess(
346
+ inputs, weights, training=True
347
+ ), labels
348
+
349
+ preprocessed_train_dataset = train_dataset_generator()
350
+ ```
351
+ This explicit preprocessing stage combines the input and optional weights,
352
+ so the new data can be passed directly into the `inputs` argument of the
353
+ layer or model.
354
+
355
+ #### Usage in a Keras model
356
+
357
+ Once the global distribution is set and the input preprocessing pipeline
358
+ is defined, model training can proceed as normal. For example:
359
+ ```python
360
+ # Construct, compile, and fit the model using the preprocessed data.
361
+ model = keras.Sequential(
362
+ [
363
+ embedding_layer,
364
+ keras.layers.Dense(2),
365
+ keras.layers.Dense(3),
366
+ keras.layers.Dense(4),
367
+ ]
368
+ )
369
+ model.compile(optimizer="adam", loss="mse", jit_compile=True)
370
+ model.fit(preprocessed_train_dataset, epochs=10)
371
+ ```
372
+
373
+ #### Direct invocation
374
+
375
+ The `DistributedEmbedding` layer can also be invoked directly. Explicit
376
+ preprocessing is required when used with JIT compilation.
377
+ ```python
378
+ # Call the layer directly.
379
+ activations = embedding_layer(my_inputs, my_weights)
380
+
381
+ # Call the layer with JIT compilation and explicitly preprocessed inputs.
382
+ embedding_layer_jit = jax.jit(embedding_layer)
383
+ preprocessed_inputs = embedding_layer.preprocess(my_inputs, my_weights)
384
+ activations = embedding_layer_jit(preprocessed_inputs)
385
+ ```
386
+
387
+ Similarly, for custom training loops, preprocessing must be applied prior
388
+ to passing the data to the JIT-compiled training step.
389
+ ```python
390
+ # Create an optimizer and loss function.
391
+ optimizer = keras.optimizers.Adam(learning_rate=1e-3)
392
+
393
+ def loss_and_updates(trainable_variables, non_trainable_variables, x, y):
394
+ y_pred, non_trainable_variables = model.stateless_call(
395
+ trainable_variables, non_trainable_variables, x, training=True
396
+ )
397
+ loss = keras.losses.mean_squared_error(y, y_pred)
398
+ return loss, non_trainable_variables
399
+
400
+ grad_fn = jax.value_and_grad(loss_and_updates, has_aux=True)
401
+
402
+ # Create a JIT-compiled training step.
403
+ @jax.jit
404
+ def train_step(state, x, y):
405
+ (
406
+ trainable_variables,
407
+ non_trainable_variables,
408
+ optimizer_variables,
409
+ ) = state
410
+ (loss, non_trainable_variables), grads = grad_fn(
411
+ trainable_variables, non_trainable_variables, x, y
412
+ )
413
+ trainable_variables, optimizer_variables = optimizer.stateless_apply(
414
+ optimizer_variables, grads, trainable_variables
415
+ )
416
+ return loss, (
417
+ trainable_variables,
418
+ non_trainable_variables,
419
+ optimizer_variables,
420
+ )
421
+
422
+ # Build optimizer variables.
423
+ optimizer.build(model.trainable_variables)
424
+
425
+ # Assemble the training state.
426
+ trainable_variables = model.trainable_variables
427
+ non_trainable_variables = model.non_trainable_variables
428
+ optimizer_variables = optimizer.variables
429
+ state = trainable_variables, non_trainable_variables, optimizer_variables
430
+
431
+ # Training loop.
432
+ for (inputs, weights), labels in train_dataset:
433
+ # Explicitly preprocess the data.
434
+ preprocessed_inputs = embedding_layer.preprocess(inputs, weights)
435
+ loss, state = train_step(state, preprocessed_inputs, labels)
436
+ ```
437
+
438
+ Args:
439
+ feature_configs: A nested structure of `keras_rs.layers.FeatureConfig`.
440
+ table_stacking: The table stacking to use. `None` means no table
441
+ stacking. `"auto"` means to stack tables automatically. A list of
442
+ table names or list of lists of table names means to stack the
443
+ tables in the inner lists together. Note that table stacking is not
444
+ supported on older TPUs, in which case the default value of `"auto"`
445
+ will be interpreted as no table stacking.
446
+ **kwargs: Additional arguments to pass to the layer base class.
447
+ """
448
+
449
+ def __init__(
450
+ self,
451
+ feature_configs: types.Nested[FeatureConfig],
452
+ *,
453
+ table_stacking: (
454
+ str | Sequence[str] | Sequence[Sequence[str]]
455
+ ) = "auto",
456
+ **kwargs: Any,
457
+ ) -> None:
458
+ super().__init__(**kwargs)
459
+
460
+ self._init_feature_configs_structures(feature_configs)
461
+
462
+ # Initialize for features placed on "sparsecore".
463
+ if "sparsecore" in self._placement_to_path_to_feature_config:
464
+ self._sparsecore_init(
465
+ self._placement_to_path_to_feature_config["sparsecore"],
466
+ table_stacking,
467
+ )
468
+ # Initialize for features placed on "default_device".
469
+ if "default_device" in self._placement_to_path_to_feature_config:
470
+ self._default_device_init(
471
+ self._placement_to_path_to_feature_config["default_device"],
472
+ table_stacking,
473
+ )
474
+
475
+ @keras_utils.no_automatic_dependency_tracking
476
+ def _init_feature_configs_structures(
477
+ self,
478
+ feature_configs: types.Nested[FeatureConfig],
479
+ ) -> None:
480
+ """Initializations for efficiently transforming nested structures.
481
+
482
+ This layer handles arbitrarily nested structures for input features, and
483
+ therefore for outputs and feature configs. However, as an intermediary
484
+ format we use a two-level representation with nested dicts. The top
485
+ level dict is keyed by placement and the inner dict is keyed by path,
486
+ with the path representing the path in the original deeply nested
487
+ structure. Thanks to this intermediate representation, we can:
488
+ - dispatch the inputs by placement to overridden methods
489
+ - have backend specific implementations support only one level of
490
+ nesting.
491
+
492
+ This method is responsible for creating structures that allow this
493
+ conversion to happen in a few lines of code and efficiently. The
494
+ following attributes are created:
495
+ - self._feature_configs: the deeply nested `FeatureConfig` instances as
496
+ provided by user in `__init__`
497
+ - self._feature_deeply_nested_placement_and_paths: `PlacementAndPath`
498
+ instances in the same deeply nested structure as
499
+ `self._feature_configs`. Needed for `build` because flatten cannot be
500
+ used as it would expand the shape tuples.
501
+ - self._placement_to_path_to_feature_config: `FeatureConfig` instances
502
+ in the same two-level representation keyed by placement and then path.
503
+ Used to go from a flat representation to the intermediate
504
+ representation.
505
+
506
+ With these structures in place, the steps to:
507
+ - go from the deeply nested structure to the two-level structure are:
508
+ - `assert_same_structure` as `self._feature_configs`
509
+ - `flatten`
510
+ - `pack_sequence_as` `self._placement_to_path_to_feature_config`
511
+ - go from the two-level structure to the deeply nested structure:
512
+ - `assert_same_structure` as `self._placement_to_path_to_feature_config`
513
+ - `flatten`
514
+ - `pack_sequence_as` `self._feature_configs`
515
+
516
+ Args:
517
+ feature_configs: The deeply nested structure of `FeatureConfig` or
518
+ `tf.tpu.experimental.embedding.FeatureConfig` as provided by the
519
+ user.
520
+ """
521
+ # Needs to be assigned with `no_automatic_dependency_tracking` to not
522
+ # alter the data structure types.
523
+ self._feature_configs = feature_configs
524
+
525
+ placement_and_paths: list[PlacementAndPath] = []
526
+ paths_and_feature_configs = keras.tree.flatten_with_path(
527
+ self._feature_configs
528
+ )
529
+ self._placement_to_path_to_feature_config: dict[
530
+ str, dict[str, FeatureConfig]
531
+ ] = {}
532
+
533
+ # Lazily initialized.
534
+ has_sparsecore = None
535
+
536
+ for path, feature_config in paths_and_feature_configs:
537
+ if isinstance(feature_config, FeatureConfig):
538
+ placement = feature_config.table.placement
539
+ # Resolve "auto" to an actual placement.
540
+ if placement == "auto":
541
+ if has_sparsecore is None:
542
+ has_sparsecore = self._has_sparsecore()
543
+ placement = (
544
+ "sparsecore" if has_sparsecore else "default_device"
545
+ )
546
+ else:
547
+ # It's a `tf.tpu.experimental.embedding.FeatureConfig`.
548
+ placement = "sparsecore"
549
+
550
+ path = ".".join([str(e) for e in path])
551
+ if placement not in SUPPORTED_PLACEMENTS:
552
+ raise ValueError(
553
+ f"Feature '{path}' with name '{feature_config.name}' has "
554
+ f"unsupported placement '{placement}'."
555
+ )
556
+ placement_and_paths.append(PlacementAndPath(placement, path))
557
+ if placement not in self._placement_to_path_to_feature_config:
558
+ self._placement_to_path_to_feature_config[placement] = {}
559
+ self._placement_to_path_to_feature_config[placement][path] = (
560
+ feature_config
561
+ )
562
+
563
+ self._feature_deeply_nested_placement_and_paths = (
564
+ keras.tree.pack_sequence_as(
565
+ self._feature_configs, placement_and_paths
566
+ )
567
+ )
568
+
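To make the flatten/pack round trip concrete, here is a toy sketch with plain values standing in for feature configs (the exact printed path format may vary slightly across Keras versions):

```python
import keras

nested = {"user": {"id": 1}, "item": [2, 3]}

# Paths reported by `flatten_with_path` correspond to the "path" keys used
# in the two-level representation, e.g. ("item", 0) -> "item.0".
print(keras.tree.flatten_with_path(nested))

# Round trip: flatten to leaves, then pack back into the original layout.
flat = keras.tree.flatten(nested)
assert keras.tree.pack_sequence_as(nested, flat) == nested
```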
569
+ def build(self, input_shapes: types.Nested[types.Shape]) -> None:
570
+ if self.built:
571
+ return
572
+
573
+ self._verify_input_shapes(input_shapes)
574
+
575
+ # Go from deeply nested structure to placement -> path -> input shape.
576
+ placement_to_path_to_input_shape: collections.defaultdict[
577
+ str, dict[str, types.Shape]
578
+ ] = collections.defaultdict(dict)
579
+
580
+ def populate_placement_to_path_to_input_shape(
581
+ placement_and_path: PlacementAndPath, input_shape: types.Shape
582
+ ) -> None:
583
+ placement_to_path_to_input_shape[placement_and_path.placement][
584
+ placement_and_path.path
585
+ ] = input_shape
586
+
587
+ keras.tree.map_structure_up_to(
588
+ self._feature_configs,
589
+ populate_placement_to_path_to_input_shape,
590
+ self._feature_deeply_nested_placement_and_paths,
591
+ input_shapes,
592
+ )
593
+
594
+ # Build for features placed on "sparsecore".
595
+ if "sparsecore" in placement_to_path_to_input_shape:
596
+ self._sparsecore_build(
597
+ placement_to_path_to_input_shape["sparsecore"]
598
+ )
599
+
600
+ # Build for features placed on "default_device".
601
+ if "default_device" in placement_to_path_to_input_shape:
602
+ self._default_device_build(
603
+ placement_to_path_to_input_shape["default_device"]
604
+ )
605
+
606
+ super().build(input_shapes)
607
+
608
+ def preprocess(
609
+ self,
610
+ inputs: types.Nested[types.Tensor],
611
+ weights: types.Nested[types.Tensor] | None = None,
612
+ training: bool = False,
613
+ ) -> types.Nested[types.Tensor]:
614
+ """Preprocesses and reformats the data for consumption by the model.
615
+
616
+ For the JAX backend, converts the input data to a hardware-dependent
617
+ format required for use with SparseCores. Calling `preprocess`
618
+ explicitly is only necessary to enable `jit_compile = True`.
619
+
620
+ For non-JAX backends, preprocessing will bundle together the inputs and
621
+ weights, and separate the inputs by device placement. This step is
622
+ entirely optional.
623
+
624
+ Args:
625
+ inputs: Ragged or dense set of sample IDs.
626
+ weights: Optional ragged or dense set of sample weights.
627
+ training: If true, will update internal parameters, such as
628
+ required buffer sizes for the preprocessed data.
629
+
630
+ Returns:
631
+ Set of preprocessed inputs that can be fed directly into the
632
+ `inputs` argument of the layer.
633
+ """
634
+ # Verify input structure.
635
+ keras.tree.assert_same_structure(self._feature_configs, inputs)
636
+
637
+ if not self.built:
638
+ input_shapes = keras.tree.map_structure_up_to(
639
+ self._feature_configs,
640
+ lambda array: backend.standardize_shape(array.shape),
641
+ inputs,
642
+ )
643
+ self.build(input_shapes)
644
+
645
+ # Go from deeply nested structure of inputs to flat inputs.
646
+ flat_inputs = keras.tree.flatten(inputs)
647
+
648
+ # Go from flat to nested dict placement -> path -> input.
649
+ placement_to_path_to_inputs = keras.tree.pack_sequence_as(
650
+ self._placement_to_path_to_feature_config, flat_inputs
651
+ )
652
+
653
+ if weights is not None:
654
+ # Same for weights if present.
655
+ keras.tree.assert_same_structure(self._feature_configs, weights)
656
+ flat_weights = keras.tree.flatten(weights)
657
+ placement_to_path_to_weights = keras.tree.pack_sequence_as(
658
+ self._placement_to_path_to_feature_config, flat_weights
659
+ )
660
+ else:
661
+ # Populate keys for weights.
662
+ placement_to_path_to_weights = {
663
+ k: None for k in placement_to_path_to_inputs
664
+ }
665
+
666
+ placement_to_path_to_preprocessed: dict[
667
+ str, dict[str, dict[str, types.Nested[types.Tensor]]]
668
+ ] = {}
669
+
670
+ # Preprocess for features placed on "sparsecore".
671
+ if "sparsecore" in placement_to_path_to_inputs:
672
+ placement_to_path_to_preprocessed["sparsecore"] = (
673
+ self._sparsecore_preprocess(
674
+ placement_to_path_to_inputs["sparsecore"],
675
+ placement_to_path_to_weights["sparsecore"],
676
+ training,
677
+ )
678
+ )
679
+
680
+ # Preprocess for features placed on "default_device".
681
+ if "default_device" in placement_to_path_to_inputs:
682
+ placement_to_path_to_preprocessed["default_device"] = (
683
+ self._default_device_preprocess(
684
+ placement_to_path_to_inputs["default_device"],
685
+ placement_to_path_to_weights["default_device"],
686
+ training,
687
+ )
688
+ )
689
+
690
+ # Mark inputs as preprocessed using an extra level of nesting.
691
+ # This is necessary to detect whether inputs are already preprocessed
692
+ # in `call`.
693
+ output = {
694
+ "preprocessed_inputs_per_placement": (
695
+ placement_to_path_to_preprocessed
696
+ )
697
+ }
698
+ return output
699
+
700
+ def _is_preprocessed(
701
+ self, inputs: types.Nested[types.Tensor | types.Shape]
702
+ ) -> bool:
703
+ """Checks if the input is already preprocessed."""
704
+ return (
705
+ isinstance(inputs, dict)
706
+ and "preprocessed_inputs_per_placement" in inputs
707
+ )
708
+
709
+ def call(
710
+ self,
711
+ inputs: types.Nested[types.Tensor],
712
+ weights: types.Nested[types.Tensor] | None = None,
713
+ training: bool = False,
714
+ ) -> types.Nested[types.Tensor]:
715
+ """Lookup features in embedding tables and apply reduction.
716
+
717
+ Args:
718
+ inputs: A nested structure of 2D tensors to embed and reduce. The
719
+ structure must be the same as the `feature_configs` passed
720
+ during construction. Alternatively, may consist of already
721
+ preprocessed inputs (see `preprocess`).
722
+ weights: An optional nested structure of 2D tensors of weights to
723
+ apply before reduction. When present, the structure must be the
724
+ same as `inputs` and the shapes must match.
725
+ training: Whether we are training or evaluating the model.
726
+
727
+ Returns:
728
+ A nested structure of dense 2D tensors, which are the reduced
729
+ embeddings from the passed features. The structure is the same as
730
+ `inputs`.
731
+ """
732
+ preprocessed_inputs = inputs
733
+ # Preprocess if not already done.
734
+ if not self._is_preprocessed(inputs):
735
+ preprocessed_inputs = self.preprocess(inputs, weights, training)
736
+
737
+ preprocessed_inputs = typing.cast(
738
+ dict[str, dict[str, dict[str, types.Tensor]]], preprocessed_inputs
739
+ )
740
+ # Placement -> path -> preprocessed inputs.
741
+ preprocessed_inputs = preprocessed_inputs[
742
+ "preprocessed_inputs_per_placement"
743
+ ]
744
+
745
+ placement_to_path_to_outputs = {}
746
+
747
+ # Call for features placed on "sparsecore".
748
+ if "sparsecore" in preprocessed_inputs:
749
+ inputs_and_weights = preprocessed_inputs["sparsecore"]
750
+ placement_to_path_to_outputs["sparsecore"] = self._sparsecore_call(
751
+ **inputs_and_weights,
752
+ training=training,
753
+ )
754
+
755
+ # Call for features placed on "default_device".
756
+ if "default_device" in preprocessed_inputs:
757
+ inputs_and_weights = preprocessed_inputs["default_device"]
758
+ placement_to_path_to_outputs["default_device"] = (
759
+ self._default_device_call(
760
+ **inputs_and_weights,
761
+ training=training,
762
+ )
763
+ )
764
+
765
+ # Verify output structure.
766
+ keras.tree.assert_same_structure(
767
+ self._placement_to_path_to_feature_config,
768
+ placement_to_path_to_outputs,
769
+ )
770
+
771
+ # Go from placement -> path -> output to flat outputs.
772
+ flat_outputs = keras.tree.flatten(placement_to_path_to_outputs)
773
+
774
+ # Go from flat outputs to deeply nested structure.
775
+ return keras.tree.pack_sequence_as(self._feature_configs, flat_outputs)
776
+
777
+ def get_embedding_tables(self) -> dict[str, types.Tensor]:
778
+ """Return the content of the embedding tables by table name.
779
+
780
+ The tables are keyed by the name provided in each `TableConfig`. Note
781
+ that the returned tensors are not the actual embedding table variables
782
+ used internally by `DistributedEmbedding`.
783
+
784
+ Returns:
785
+ A dictionary of table name to tensor for the embedding tables.
786
+ """
787
+ tables = {}
788
+ if "sparsecore" in self._placement_to_path_to_feature_config:
789
+ tables.update(self._sparsecore_get_embedding_tables())
790
+ if "default_device" in self._placement_to_path_to_feature_config:
791
+ tables.update(self._default_device_get_embedding_tables())
792
+ return tables
793
+
794
+ def _default_device_init(
795
+ self,
796
+ feature_configs: dict[str, FeatureConfig],
797
+ table_stacking: str | Sequence[Sequence[str]],
798
+ ) -> None:
799
+ del table_stacking
800
+ table_to_embedding_layer: dict[TableConfig, EmbedReduce] = {}
801
+ self._default_device_embedding_layers: dict[str, EmbedReduce] = {}
802
+
803
+ for path, feature_config in feature_configs.items():
804
+ if feature_config.table in table_to_embedding_layer:
805
+ self._default_device_embedding_layers[path] = (
806
+ table_to_embedding_layer[feature_config.table]
807
+ )
808
+ else:
809
+ embedding_layer = EmbedReduce(
810
+ name=feature_config.table.name,
811
+ input_dim=feature_config.table.vocabulary_size,
812
+ output_dim=feature_config.table.embedding_dim,
813
+ embeddings_initializer=feature_config.table.initializer,
814
+ combiner=feature_config.table.combiner,
815
+ )
816
+ table_to_embedding_layer[feature_config.table] = embedding_layer
817
+ self._default_device_embedding_layers[path] = embedding_layer
818
+
819
+ def _default_device_build(
820
+ self, input_shapes: dict[str, types.Shape]
821
+ ) -> None:
822
+ for path, input_shape in input_shapes.items():
823
+ embedding_layer = self._default_device_embedding_layers[path]
824
+ if not embedding_layer.built:
825
+ embedding_layer.build(input_shape)
826
+
827
+ def _default_device_preprocess(
828
+ self,
829
+ inputs: dict[str, types.Tensor],
830
+ weights: dict[str, types.Tensor] | None,
831
+ training: bool = False,
832
+ ) -> dict[str, dict[str, types.Tensor]]:
833
+ del training
834
+
835
+ # NOTE: This JAX specialization is in the base layer so it is available
836
+ # on all platforms. The subclass jax.DistributedEmbedding layer
837
+ # is currently only imported in linux_x86_64.
838
+ if keras.backend.backend() == "jax":
839
+ feature_configs = self._placement_to_path_to_feature_config[
840
+ "default_device"
841
+ ]
842
+
843
+ # Potentially track new weights. For ragged inputs, if we
844
+ # densify, we will generate a dense weight tensor.
845
+ new_weights: dict[str, types.Tensor] = {}
846
+ use_weights = weights is not None
847
+
848
+ # Convert any ragged inputs to dense.
849
+ for path, config in feature_configs.items():
850
+ feature_inputs = inputs[path]
851
+ feature_weights = weights[path] if weights is not None else None
852
+
853
+ feature_valence = (
854
+ None
855
+ if len(config.input_shape) <= 1
856
+ else config.input_shape[1]
857
+ )
858
+ feature_inputs, feature_weights = _ragged_to_dense_inputs(
859
+ feature_inputs, feature_weights, feature_valence
860
+ )
861
+ # Converting to dense may have introduced a weights array.
862
+ use_weights = use_weights or feature_weights is not None
863
+ inputs[path] = feature_inputs
864
+ new_weights[path] = feature_weights
865
+
866
+ if use_weights:
867
+ weights = new_weights
868
+
869
+ output: dict[str, types.Tensor] = {"inputs": inputs}
870
+ if weights is not None:
871
+ output["weights"] = weights
872
+
873
+ return output
874
+
875
+ def _default_device_call(
876
+ self,
877
+ inputs: dict[str, types.Tensor],
878
+ weights: dict[str, types.Tensor] | None = None,
879
+ training: bool = False,
880
+ ) -> dict[str, types.Tensor]:
881
+ del training # Unused by default.
882
+ if weights is None:
883
+ return {
884
+ path: self._default_device_embedding_layers[path](x)
885
+ for path, x in inputs.items()
886
+ }
887
+ else:
888
+ return {
889
+ path: self._default_device_embedding_layers[path](
890
+ x, weights[path]
891
+ )
892
+ for path, x in inputs.items()
893
+ }
894
+
895
+ def _default_device_get_embedding_tables(self) -> dict[str, types.Tensor]:
896
+ tables = {}
897
+ for path, feature_config in self._placement_to_path_to_feature_config[
898
+ "default_device"
899
+ ].items():
900
+ tables[feature_config.table.name] = (
901
+ self._default_device_embedding_layers[path].embeddings.value
902
+ )
903
+ return tables
904
+
905
+ def _has_sparsecore(self) -> bool:
906
+ # Explicitly check for SparseCore availability.
907
+ # We need this check here rather than in jax/distributed_embedding.py
908
+ # so that we can warn the user about missing dependencies.
909
+ if keras.backend.backend() == "jax":
910
+ # Check if SparseCores are available.
911
+ try:
912
+ import jax
913
+
914
+ tpu_devices = jax.devices("tpu")
915
+ except RuntimeError:
916
+ # No TPUs available.
917
+ return False
918
+
919
+ if len(tpu_devices) > 0:
920
+ device_kind = tpu_devices[0].device_kind
921
+ if device_kind in ["TPU v5", "TPU v6 lite"]:
922
+ return True
923
+
924
+ return False
925
+
926
+ def _sparsecore_init(
927
+ self,
928
+ feature_configs: dict[str, FeatureConfig],
929
+ table_stacking: str | Sequence[Sequence[str]],
930
+ ) -> None:
931
+ del feature_configs, table_stacking
932
+
933
+ if keras.backend.backend() == "jax":
934
+ jax_tpu_embedding_spec = importlib.util.find_spec(
935
+ "jax_tpu_embedding"
936
+ )
937
+ if jax_tpu_embedding_spec is None:
938
+ raise ImportError(
939
+ "Please install jax-tpu-embedding to use "
940
+ "DistributedEmbedding on sparsecore devices."
941
+ )
942
+
943
+ raise self._unsupported_placement_error("sparsecore")
944
+
945
+ def _sparsecore_build(self, input_shapes: dict[str, types.Shape]) -> None:
946
+ del input_shapes
947
+ raise self._unsupported_placement_error("sparsecore")
948
+
949
+ def _sparsecore_preprocess(
950
+ self,
951
+ inputs: dict[str, types.Tensor],
952
+ weights: dict[str, types.Tensor] | None,
953
+ training: bool = False,
954
+ ) -> dict[str, dict[str, types.Tensor]]:
955
+ del training
956
+ output: dict[str, types.Tensor] = {"inputs": inputs}
957
+ if weights is not None:
958
+ output["weights"] = weights
959
+
960
+ return output
961
+
962
+ def _sparsecore_call(
963
+ self,
964
+ inputs: dict[str, types.Tensor],
965
+ weights: dict[str, types.Tensor] | None = None,
966
+ training: bool = False,
967
+ ) -> dict[str, types.Tensor]:
968
+ del inputs, weights, training
969
+ raise self._unsupported_placement_error("sparsecore")
970
+
971
+ def _sparsecore_get_embedding_tables(self) -> dict[str, types.Tensor]:
972
+ raise self._unsupported_placement_error("sparsecore")
973
+
974
+ def compute_output_shape(
975
+ self, input_shapes: types.Nested[types.Shape]
976
+ ) -> types.Nested[types.Shape]:
977
+ self._verify_input_shapes(input_shapes)
978
+ output_shape: types.Nested[types.Shape] = keras.tree.map_structure(
979
+ lambda fc: fc.output_shape, self._feature_configs
980
+ )
981
+ return output_shape
982
+
983
+ def get_config(self) -> dict[str, Any]:
984
+ # Because the Keras serialization creates a tree of serialized objects,
985
+ # it does not directly support sharing tables between feature configs.
986
+ # We therefore serialize the tables config as a flat list and then refer
987
+ # to them by index in each feature config.
988
+
989
+ # The serialized `TableConfig` objects.
990
+ table_config_dicts: list[dict[str, Any]] = []
991
+ # Mapping from `TableConfig` to index in `table_config_dicts`.
992
+ table_config_indices: dict[TableConfig, int] = {}
993
+
994
+ def serialize_feature_config(
995
+ feature_config: FeatureConfig,
996
+ ) -> dict[str, Any]:
997
+ # Note that for consistency with the contract of `get_config`, the
998
+ # returned dict contains the serialized `TableConfig` in the "table"
999
+ # key.
1000
+ feature_config_dict = feature_config.get_config()
1001
+
1002
+ if feature_config.table not in table_config_indices:
1003
+ # Save the serialized `TableConfig` the first time we see it and
1004
+ # remember its index.
1005
+ table_config_indices[feature_config.table] = len(
1006
+ table_config_dicts
1007
+ )
1008
+ table_config_dicts.append(feature_config_dict["table"])
1009
+
1010
+ # Replace the serialized `TableConfig` with its index.
1011
+ feature_config_dict["table"] = table_config_indices[
1012
+ feature_config.table
1013
+ ]
1014
+ return feature_config_dict
1015
+
1016
+ config: dict[str, Any] = super().get_config()
1017
+ config["feature_configs"] = keras.tree.map_structure(
1018
+ serialize_feature_config, self._feature_configs
1019
+ )
1020
+ config["tables"] = table_config_dicts
1021
+ if hasattr(self, "_table_stacking"):
1022
+ config["table_stacking"] = self._table_stacking
1023
+ return config
1024
+
1025
+ @classmethod
1026
+ def from_config(cls, config: dict[str, Any]) -> "DistributedEmbedding":
1027
+ config = config.copy()
1028
+ # We need to reconnect the `TableConfig`s to the `FeatureConfig`s.
1029
+
1030
+ # The serialized `TableConfig` objects.
1031
+ table_config_dicts: list[dict[str, Any]] = config.pop("tables")
1032
+ # The deserialized `TableConfig` objects at the same indices.
1033
+ table_configs: list[TableConfig | None] = [None] * len(
1034
+ table_config_dicts
1035
+ )
1036
+
1037
+ def deserialize_feature_config(
1038
+ feature_config_dict: dict[str, Any],
1039
+ ) -> FeatureConfig | None:
1040
+ # Look for a "name" attribute which is a string to detect a
1041
+ # `FeatureConfig` leaf node. If not, keep recursing.
1042
+ if "name" not in feature_config_dict or not isinstance(
1043
+ feature_config_dict["name"], str
1044
+ ):
1045
+ # Tell `traverse` to recurse.
1046
+ return None
1047
+
1048
+ table_index = feature_config_dict["table"]
1049
+ # Note that for consistency with the contract of `from_config`, the
1050
+ # passed dict must contain the serialized `TableConfig` in the
1051
+ # "table" key.
1052
+ feature_config_dict["table"] = table_config_dicts[table_index]
1053
+ feature_config = FeatureConfig.from_config(feature_config_dict)
1054
+ # But then dedupe `TableConfig`s.
1055
+ if table_configs[table_index] is None:
1056
+ # Remember each new `TableConfig` we see.
1057
+ table_configs[table_index] = feature_config.table
1058
+ else:
1059
+ # And swap duplicates for the original.
1060
+ feature_config.table = table_configs[table_index]
1061
+ return feature_config
1062
+
1063
+ # Because each `FeatureConfig` is serialized as a dict, we cannot use
1064
+ # `map_structure` as it would recurse in the config itself. We use
1065
+ # `traverse` instead with a function that detects leaf nodes.
1066
+ config["feature_configs"] = keras.tree.traverse(
1067
+ deserialize_feature_config, config["feature_configs"]
1068
+ )
1069
+ return cls(**config)
1070
+
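A sketch of the round trip this index-based scheme enables, with two features sharing one table (sizes and shapes are placeholders):

```python
import keras_rs

shared_table = keras_rs.layers.TableConfig(
    name="shared_table", vocabulary_size=1000, embedding_dim=16
)
feature_a = keras_rs.layers.FeatureConfig(
    name="feature_a",
    table=shared_table,
    input_shape=(16,),
    output_shape=(16, 16),
)
feature_b = keras_rs.layers.FeatureConfig(
    name="feature_b",
    table=shared_table,
    input_shape=(16,),
    output_shape=(16, 16),
)

layer = keras_rs.layers.DistributedEmbedding({"a": feature_a, "b": feature_b})
restored = keras_rs.layers.DistributedEmbedding.from_config(layer.get_config())

# The shared table is serialized once and re-shared after deserialization.
assert restored._feature_configs["a"].table is restored._feature_configs["b"].table
```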
1071
+ def _verify_input_shapes(
1072
+ self, input_shapes: types.Nested[types.Shape]
1073
+ ) -> None:
1074
+ """Verifies that the input shapes match the ones in the feature configs.
1075
+
1076
+ Args:
1077
+ input_shapes: The structure of input shapes to verify.
1078
+ """
1079
+ # Support preprocessing.
1080
+ if self._is_preprocessed(input_shapes):
1081
+ # Structure should be :
1082
+ # {
1083
+ # placement: {
1084
+ # inputs: {path: Any},
1085
+ # weights: {path: Any}
1086
+ # }
1087
+ # }
1088
+ #
1089
+ # But the `Any` values could be nested tensors with varying
1090
+ # structure, depending on hardware constraints. This complicates
1091
+ # checking shapes via keras.tree methods. So, assume the
1092
+ # input is a result of explicitly calling the `preprocess(...)`
1093
+ # function, in which case the structure has already been verified.
1094
+ return
1095
+
1096
+ def _verify_input_shape(
1097
+ feature_config: FeatureConfig,
1098
+ input_shape: types.Shape,
1099
+ ) -> None:
1100
+ if not isinstance(input_shape, (tuple, list)) or not all(
1101
+ isinstance(d, (int, type(None))) for d in input_shape
1102
+ ):
1103
+ raise ValueError(f"Received invalid input shape {input_shape}.")
1104
+ if len(input_shape) < 1:
1105
+ raise ValueError(
1106
+ f"Received input shape {input_shape}. Rank must be 1 or "
1107
+ "above."
1108
+ )
1109
+ keras_utils.check_shapes_compatible(
1110
+ feature_config.input_shape, input_shape
1111
+ )
1112
+
1113
+ keras.tree.map_structure_up_to(
1114
+ self._feature_configs,
1115
+ _verify_input_shape,
1116
+ self._feature_configs,
1117
+ input_shapes,
1118
+ )
1119
+
1120
+ def _unsupported_placement_error(self, placement: str) -> Exception:
1121
+ return NotImplementedError(
1122
+ f"Backend '{keras.backend.backend()}' does not support the "
1123
+ f"'{placement}' placement."
1124
+ )