keras-rs-nightly 0.0.1.dev2025050103__py3-none-any.whl → 0.2.2.dev202506100336__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of keras-rs-nightly might be problematic. Click here for more details.

Files changed (33)
  1. keras_rs/layers/__init__.py +12 -0
  2. keras_rs/src/layers/embedding/__init__.py +0 -0
  3. keras_rs/src/layers/embedding/base_distributed_embedding.py +1124 -0
  4. keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
  5. keras_rs/src/layers/embedding/distributed_embedding_config.py +129 -0
  6. keras_rs/src/layers/embedding/embed_reduce.py +309 -0
  7. keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  8. keras_rs/src/layers/embedding/jax/config_conversion.py +398 -0
  9. keras_rs/src/layers/embedding/jax/distributed_embedding.py +892 -0
  10. keras_rs/src/layers/embedding/jax/embedding_lookup.py +255 -0
  11. keras_rs/src/layers/embedding/jax/embedding_utils.py +596 -0
  12. keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  13. keras_rs/src/layers/embedding/tensorflow/config_conversion.py +323 -0
  14. keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +424 -0
  15. keras_rs/src/layers/feature_interaction/dot_interaction.py +2 -2
  16. keras_rs/src/layers/feature_interaction/feature_cross.py +14 -16
  17. keras_rs/src/layers/retrieval/brute_force_retrieval.py +5 -5
  18. keras_rs/src/layers/retrieval/retrieval.py +4 -4
  19. keras_rs/src/losses/pairwise_loss.py +2 -2
  20. keras_rs/src/losses/pairwise_mean_squared_error.py +1 -3
  21. keras_rs/src/metrics/dcg.py +2 -2
  22. keras_rs/src/metrics/ndcg.py +2 -2
  23. keras_rs/src/metrics/ranking_metric.py +4 -4
  24. keras_rs/src/metrics/ranking_metrics_utils.py +8 -8
  25. keras_rs/src/metrics/utils.py +2 -4
  26. keras_rs/src/types.py +43 -14
  27. keras_rs/src/utils/keras_utils.py +26 -6
  28. keras_rs/src/version.py +1 -1
  29. {keras_rs_nightly-0.0.1.dev2025050103.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/METADATA +6 -3
  30. keras_rs_nightly-0.2.2.dev202506100336.dist-info/RECORD +55 -0
  31. {keras_rs_nightly-0.0.1.dev2025050103.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/WHEEL +1 -1
  32. keras_rs_nightly-0.0.1.dev2025050103.dist-info/RECORD +0 -42
  33. {keras_rs_nightly-0.0.1.dev2025050103.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,892 @@
1
+ """JAX implementation of the TPU embedding layer."""
2
+
3
+ import math
4
+ import typing
5
+ from typing import Any, Mapping, Sequence, Union
6
+
7
+ import jax
8
+ import keras
9
+ import numpy as np
10
+ from jax import numpy as jnp
11
+ from jax.experimental import layout as jax_layout
12
+ from jax_tpu_embedding.sparsecore.lib.nn import embedding
13
+ from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
14
+ from jax_tpu_embedding.sparsecore.lib.nn import (
15
+ table_stacking as jte_table_stacking,
16
+ )
17
+ from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
18
+ from keras.src import backend
19
+
20
+ from keras_rs.src import types
21
+ from keras_rs.src.layers.embedding import base_distributed_embedding
22
+ from keras_rs.src.layers.embedding import distributed_embedding_config as config
23
+ from keras_rs.src.layers.embedding.jax import config_conversion
24
+ from keras_rs.src.layers.embedding.jax import (
25
+ embedding_lookup as jte_embedding_lookup,
26
+ )
27
+ from keras_rs.src.layers.embedding.jax import embedding_utils
28
+ from keras_rs.src.types import Nested
29
+ from keras_rs.src.utils import keras_utils
30
+
31
+ ArrayLike = Union[np.ndarray[Any, Any], jax.Array]
32
+ FeatureConfig = config.FeatureConfig
33
+ shard_map = jax.experimental.shard_map.shard_map # type: ignore[attr-defined]
34
+
35
+
36
def _get_partition_spec(
    layout: (
        keras.distribution.TensorLayout
        | jax_layout.Layout
        | jax.sharding.NamedSharding
        | jax.sharding.PartitionSpec
    ),
) -> Any:
    """Unwraps a layout/sharding wrapper down to its partition spec.

    Peels off, in order: a Keras `TensorLayout` (-> its backend layout),
    a JAX `Layout` (-> its sharding), and a `NamedSharding` (-> its spec),
    returning whatever remains — normally a `jax.sharding.PartitionSpec`.
    """
    # Each step unwraps one wrapper level; steps that don't match are skipped.
    unwrap_steps = (
        (keras.distribution.TensorLayout, "backend_layout"),
        (jax_layout.Layout, "sharding"),
        (jax.sharding.NamedSharding, "spec"),
    )
    spec = layout
    for wrapper_type, attr_name in unwrap_steps:
        if isinstance(spec, wrapper_type):
            spec = getattr(spec, attr_name)
    return spec
55
+
56
+
57
class ShardedInitializer(keras.initializers.Initializer):
    """Keras initializer wrapper that materializes values with sharded jit.

    When a layout is provided, the wrapped initializer is jit-compiled with
    `out_shardings` taken from the layout, so values are created directly on
    device with the desired sharding. Without a layout, this behaves exactly
    like the wrapped initializer.
    """

    def __init__(
        self,
        initializer: keras.initializers.Initializer | str,
        layout: keras.distribution.TensorLayout | None,
    ):
        # Resolve string identifiers (e.g. "zeros") to initializer objects.
        self._initializer = (
            keras.initializers.get(initializer)
            if isinstance(initializer, str)
            else initializer
        )
        self._layout = layout

    def __call__(
        self, shape: types.Shape, dtype: types.DType | None = None
    ) -> jax.Array:
        if self._layout is None:
            # No layout requested: plain, unsharded initialization.
            result = self._initializer(shape, dtype)
            return result

        # shape/dtype must be static for tracing; the output sharding comes
        # from the layout so the values land on device pre-sharded.
        jitted_initializer = jax.jit(
            self._initializer,
            out_shardings=self._layout.backend_layout,
            static_argnames=["shape", "dtype"],
        )
        result: jax.Array = jitted_initializer(shape, dtype)
        return result
89
+
90
+
91
class StackedTableInitializer(keras.initializers.Initializer):
    """Initializes a single stacked table from multiple table initializers.

    SparseCore stacks several logical embedding tables into one physical,
    device-sharded table. Each device shard holds a slice of every
    constituent table, so this initializer runs each table's own
    initializer per shard and concatenates the results in stack order.
    """

    def __init__(
        self,
        table_specs: Nested[embedding_spec.TableSpec],
        num_shards: int,
        layout: keras.distribution.TensorLayout,
        seed: int | keras.random.SeedGenerator | jax.Array = 0,
    ):
        # Sort table specs so we can simply concatenate them when assembling
        # the stacked table.
        self._table_specs = sorted(
            keras.tree.flatten(table_specs),
            key=lambda table_spec: (
                table_spec.setting_in_stack.row_offset_in_shard,
            ),
        )
        self._num_shards = num_shards
        self._layout = layout
        # NOTE(review): relies on a private Keras API (`keras.src...`) to
        # normalize the seed into a JAX PRNG key — verify against the pinned
        # Keras version when upgrading.
        self._key = keras.src.backend.jax.random.jax_draw_seed(seed)

    def _initialize_shard(
        self,
        keys: jax.Array,
        shape: tuple[int, int],
        dtype: Any,
        num_shards_per_device: int,
    ) -> jax.Array:
        """Initializes this device's slice of the stacked table.

        `keys` is indexed as `keys[shard, table]`; each (shard, table) pair
        gets its own PRNG key so the random streams are independent.
        """
        del shape  # Unused.
        table_shards: list[jax.Array] = []
        # NOTE: the following ignores padding, rotations in shard, and
        # mod-sharding, assuming all initializers are shard-independent.
        for i in range(num_shards_per_device):
            for j, table_spec in enumerate(self._table_specs):
                setting_in_stack = table_spec.setting_in_stack
                # Each table contributes an equal slice of its padded vocab
                # to every shard.
                table_shard_shape = (
                    setting_in_stack.padded_vocab_size // self._num_shards,
                    setting_in_stack.padded_embedding_dim,
                )
                initializer = table_spec.initializer
                table_shards.append(
                    initializer(keys[i, j], table_shard_shape, dtype)
                )

        return jnp.concatenate(table_shards, axis=0)

    def __call__(
        self, shape: types.Shape, dtype: types.DType | None = None
    ) -> jax.Array:
        # All table specs in a stack share the same stacked_table_spec.
        stacked_table_spec = typing.cast(
            embedding_spec.StackedTableSpec,
            self._table_specs[0].stacked_table_spec,
        )

        # Input shape is governed by the table specs.
        assert shape == (
            stacked_table_spec.stack_vocab_size,
            stacked_table_spec.stack_embedding_dim,
        )

        layout = self._layout
        backend_layout = layout.backend_layout
        backend_mesh = layout.device_mesh.backend_mesh
        num_devices_along_axis = backend_mesh.shape[layout.axes[0]]
        # A device may hold multiple table shards (e.g. one per sparsecore).
        num_shards_per_device = self._num_shards // num_devices_along_axis
        shard_shape = (
            stacked_table_spec.stack_vocab_size // num_devices_along_axis,
            stacked_table_spec.stack_embedding_dim,
        )

        # shard_map runs `_initialize_shard` per device on its local slice of
        # the keys; jit's out_shardings places results per the table layout.
        sharded_initializer = jax.jit(
            shard_map(
                lambda keys: self._initialize_shard(
                    keys, shard_shape, dtype, num_shards_per_device
                ),
                mesh=backend_mesh,
                in_specs=_get_partition_spec(backend_layout),
                out_specs=_get_partition_spec(backend_layout),
            ),
            out_shardings=backend_layout,
        )

        # One key per (shard, table) pair.
        keys = jax.random.split(
            self._key, (self._num_shards, len(self._table_specs))
        )
        # Try extracting seeds from the existing table initializers.
        for i, table_spec in enumerate(self._table_specs):
            initializer = table_spec.initializer
            if isinstance(
                initializer, config_conversion.WrappedKerasInitializer
            ):
                initializer_key = initializer.key()
                if initializer_key is not None:
                    # Honor the table's own seed for its whole key column.
                    col = jax.random.split(initializer_key, self._num_shards)
                    keys = keys.at[:, i].set(col)

        output: jax.Array = sharded_initializer(keys)
        return output
191
+
192
+
193
+ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
194
+ """JAX implementation of the TPU embedding layer."""
195
+
196
    def _create_sparsecore_distribution(
        self, sparsecore_axis_name: str = "sparsecore"
    ) -> tuple[
        keras.distribution.ModelParallel, keras.distribution.TensorLayout
    ]:
        """SparseCore requires a specific layout.

        The mesh must be 1D, must use all TPUs available, and must shard all
        tables across all devices.

        Args:
            sparsecore_axis_name: The name of the sparsecore axis.

        Returns:
            A Keras distribution to use for all sparsecore operations.
        """
        all_devices = jax.devices()
        axes = [sparsecore_axis_name]
        # 1D mesh spanning every available device; tables shard along it.
        device_mesh = keras.distribution.DeviceMesh(
            (len(all_devices),), axes, all_devices
        )
        sparsecore_layout = keras.distribution.TensorLayout(axes, device_mesh)
        # Custom sparsecore layout with tiling.
        # pylint: disable-next=protected-access
        sparsecore_layout._backend_layout = jax_layout.Layout(
            jax_layout.DeviceLocalLayout(
                major_to_minor=(0, 1),
                # NOTE(review): 8-row tiling presumably matches the
                # sparsecore memory layout — confirm before changing.
                _tiling=((8,),),
            ),
            jax.sharding.NamedSharding(
                device_mesh.backend_mesh,
                jax.sharding.PartitionSpec(
                    axes  # type: ignore[no-untyped-call]
                ),
            ),
        )
        layout_map = keras.distribution.LayoutMap(device_mesh=device_mesh)
        path = self.path
        if path is None:
            # Layer hasn't been properly built yet. Use current layer name.
            path = self.name
        # Variables whose path matches "<path>/var..." get the sparsecore
        # layout (see the "var:..." names used for table weights).
        layout_map[path + "/var"] = sparsecore_layout
        sparsecore_distribution = keras.distribution.ModelParallel(
            layout_map=layout_map
        )
        return sparsecore_distribution, sparsecore_layout
242
+
243
+ def _create_cpu_distribution(
244
+ self, cpu_axis_name: str = "cpu"
245
+ ) -> tuple[
246
+ keras.distribution.ModelParallel, keras.distribution.TensorLayout
247
+ ]:
248
+ """Share a variable across all CPU processes."""
249
+ cpu_devices = jax.devices("cpu")
250
+ device_mesh = keras.distribution.DeviceMesh(
251
+ (len(cpu_devices),), [cpu_axis_name], cpu_devices
252
+ )
253
+ replicated_layout = keras.distribution.TensorLayout([], device_mesh)
254
+ layout_map = keras.distribution.LayoutMap(device_mesh=device_mesh)
255
+ cpu_distribution = keras.distribution.ModelParallel(
256
+ layout_map=layout_map
257
+ )
258
+ return cpu_distribution, replicated_layout
259
+
260
+ def _add_sparsecore_weight(
261
+ self,
262
+ name: str,
263
+ shape: tuple[int, int],
264
+ initializer: jax.nn.initializers.Initializer,
265
+ dtype: Any,
266
+ overwrite_with_gradient: bool,
267
+ ) -> keras.Variable:
268
+ var = self.add_weight(
269
+ name=name, shape=shape, initializer=initializer, dtype=dtype
270
+ )
271
+ var.overwrite_with_gradient = overwrite_with_gradient
272
+ return var
273
+
274
+ def _add_table_variable(
275
+ self,
276
+ table_specs: Sequence[embedding_spec.TableSpec],
277
+ num_shards: int,
278
+ add_slot_variables: bool,
279
+ ) -> tuple[keras.Variable, tuple[keras.Variable, ...] | None]:
280
+ stacked_table_spec = typing.cast(
281
+ embedding_spec.StackedTableSpec, table_specs[0].stacked_table_spec
282
+ )
283
+ optimizer = stacked_table_spec.optimizer
284
+ num_slot_variables = optimizer.slot_variables_count()
285
+ table_shape = (
286
+ stacked_table_spec.stack_vocab_size,
287
+ stacked_table_spec.stack_embedding_dim,
288
+ )
289
+
290
+ # Make a stacked embedding table initializer.
291
+ table_initializers = [
292
+ config_conversion.jax_to_keras_initializer(table_spec.initializer)
293
+ for table_spec in table_specs
294
+ ]
295
+ # If all initializers are the same, we can use a single sharded
296
+ # initializer. Otherwise, we need to interleave individual stacked table
297
+ # shards.
298
+ sparsecore_layout = self._sparsecore_layout
299
+ stacked_table_initializer = ShardedInitializer(
300
+ table_initializers[0], sparsecore_layout
301
+ )
302
+ if not all(
303
+ initializer == table_initializers[0]
304
+ for initializer in table_initializers
305
+ ):
306
+ stacked_table_initializer = StackedTableInitializer(
307
+ table_specs, num_shards, sparsecore_layout
308
+ )
309
+
310
+ variable_name = f"var:{stacked_table_spec.stack_name}:table"
311
+ table_variable = self._add_sparsecore_weight(
312
+ name=variable_name,
313
+ shape=table_shape,
314
+ initializer=stacked_table_initializer,
315
+ dtype="float32",
316
+ overwrite_with_gradient=True,
317
+ )
318
+
319
+ slot_variables = None
320
+ if add_slot_variables:
321
+ # All optimizers for a given stacked table are guaranteed to be the
322
+ # same, so we can use a single sharded initializer for the entire
323
+ # stacked table.
324
+ slot_initializers = optimizer.slot_variables_initializers()
325
+ # Try extracting field names from variables, otherwise just use the
326
+ # count.
327
+ slot_names = range(num_slot_variables)
328
+ if hasattr(slot_initializers, "_fields"):
329
+ slot_names = slot_initializers._fields
330
+
331
+ slot_variables = tuple(
332
+ self._add_sparsecore_weight(
333
+ name=f"{variable_name}:slot:{slot_name}",
334
+ shape=table_shape,
335
+ initializer=ShardedInitializer(
336
+ config_conversion.jax_to_keras_initializer(initializer),
337
+ sparsecore_layout,
338
+ ),
339
+ dtype=jnp.float32,
340
+ overwrite_with_gradient=True,
341
+ )
342
+ for slot_name, initializer in zip(slot_names, slot_initializers)
343
+ )
344
+ slot_variables = keras.tree.pack_sequence_as(
345
+ slot_initializers, slot_variables
346
+ )
347
+
348
+ return table_variable, slot_variables
349
+
350
+ @keras_utils.no_automatic_dependency_tracking
351
+ def _sparsecore_init(
352
+ self,
353
+ feature_configs: dict[str, FeatureConfig],
354
+ table_stacking: str | Sequence[str] | Sequence[Sequence[str]],
355
+ ) -> None:
356
+ if not self._has_sparsecore():
357
+ raise ValueError(
358
+ "Not sparse cores available, cannot use explicit sparsecore"
359
+ " placement."
360
+ )
361
+
362
+ self._sc_feature_configs = feature_configs
363
+ self._sparsecore_built = False
364
+ # Fill in any empty default settings.
365
+ for feature_config in keras.tree.flatten(self._sc_feature_configs):
366
+ if feature_config.table.initializer is None:
367
+ table = feature_config.table
368
+ table.initializer = keras.initializers.TruncatedNormal(
369
+ mean=0.0, stddev=1.0 / math.sqrt(float(table.embedding_dim))
370
+ )
371
+
372
+ # Actual stacking of tables is done in build() to ensure the
373
+ # distribution is set up correctly.
374
+ self._table_stacking = table_stacking
375
+
376
    def _sparsecore_build(
        self, input_shapes: Nested[types.Shape] | None = None
    ) -> None:
        # Internal build hook; delegates to the public `sparsecore_build`,
        # which carries the dependency-tracking-exempt decorator.
        self.sparsecore_build(input_shapes)
380
+
381
+ @keras_utils.no_automatic_dependency_tracking
382
+ def sparsecore_build(
383
+ self, input_shapes: Nested[types.Shape] | None = None
384
+ ) -> None:
385
+ del input_shapes # Unused.
386
+
387
+ if self._sparsecore_built:
388
+ return
389
+
390
+ feature_specs = config_conversion.keras_to_jte_feature_configs(
391
+ self._sc_feature_configs
392
+ )
393
+
394
+ # Distribution for sparsecore operations.
395
+ sparsecore_distribution, sparsecore_layout = (
396
+ self._create_sparsecore_distribution()
397
+ )
398
+ self._sparsecore_layout = sparsecore_layout
399
+ self._sparsecore_distribution = sparsecore_distribution
400
+
401
+ # Distribution for CPU operations.
402
+ cpu_distribution, cpu_layout = self._create_cpu_distribution()
403
+ self._cpu_distribution = cpu_distribution
404
+ self._cpu_layout = cpu_layout
405
+
406
+ mesh = sparsecore_distribution.device_mesh.backend_mesh
407
+ global_device_count = mesh.devices.size
408
+ num_sc_per_device = jte_utils.num_sparsecores_per_device(
409
+ mesh.devices.item(0)
410
+ )
411
+ # One table shard per global sparsecore.
412
+ num_variable_shards = global_device_count * num_sc_per_device
413
+
414
+ # Maybe stack tables.
415
+ table_stacking = self._table_stacking
416
+ if table_stacking is not None:
417
+ if isinstance(table_stacking, str):
418
+ if table_stacking == "auto":
419
+ jte_table_stacking.auto_stack_tables(
420
+ feature_specs, global_device_count, num_sc_per_device
421
+ )
422
+ else:
423
+ raise ValueError(
424
+ f"Unsupported table stacking {table_stacking}, must be"
425
+ "None, 'auto', or sequences of table names to stack."
426
+ )
427
+ else:
428
+ if isinstance(table_stacking, list) and len(table_stacking) > 0:
429
+ elem = table_stacking[0]
430
+ # List of lists of table names.
431
+ if isinstance(elem, list):
432
+ for table_names in table_stacking:
433
+ jte_table_stacking.stack_tables(
434
+ feature_specs,
435
+ table_names,
436
+ global_device_count,
437
+ num_sc_per_device,
438
+ )
439
+ # Single list of table names.
440
+ elif isinstance(elem, str):
441
+ jte_table_stacking.stack_tables(
442
+ feature_specs,
443
+ table_stacking,
444
+ global_device_count,
445
+ num_sc_per_device,
446
+ )
447
+ else:
448
+ raise ValueError(
449
+ f"Unsupported table stacking {table_stacking}, "
450
+ "must be None, 'auto', or sequences of table names "
451
+ "to stack."
452
+ )
453
+
454
+ # Adjust any non-stacked tables to prepare for training.
455
+ embedding.prepare_feature_specs_for_training(
456
+ feature_specs, global_device_count, num_sc_per_device
457
+ )
458
+
459
+ # Collect all stacked tables.
460
+ table_specs = embedding_utils.get_table_specs(feature_specs)
461
+ table_stacks = embedding_utils.get_table_stacks(table_specs)
462
+ stacked_table_specs = {
463
+ stack_name: stack[0].stacked_table_spec
464
+ for stack_name, stack in table_stacks.items()
465
+ }
466
+
467
+ # Create variables for all stacked tables and slot variables.
468
+ with sparsecore_distribution.scope():
469
+ self._table_and_slot_variables = {
470
+ table_name: self._add_table_variable(
471
+ table_stack,
472
+ add_slot_variables=self.trainable,
473
+ num_shards=num_variable_shards,
474
+ )
475
+ for table_name, table_stack in table_stacks.items()
476
+ }
477
+
478
+ # Create a step-counter variable for use in custom table gradients.
479
+ # This must be a floating-point type so we can get a real gradient
480
+ # for it. It will automatically be updated with each application of
481
+ # the optimizer, since the next iteration is returned in the
482
+ # gradient.
483
+ sharded_zero_initializer = ShardedInitializer(
484
+ "zeros",
485
+ keras.distribution.TensorLayout(
486
+ [], sparsecore_layout.device_mesh
487
+ ),
488
+ )
489
+ self._iterations = self.add_weight(
490
+ shape=(),
491
+ name="iteration",
492
+ initializer=sharded_zero_initializer,
493
+ dtype="float32",
494
+ trainable=True,
495
+ )
496
+ self._iterations.overwrite_with_gradient = True
497
+
498
+ with cpu_distribution.scope():
499
+ # Create variables to track static buffer size and max IDs for each
500
+ # table during preprocessing. These variables are shared across all
501
+ # processes on CPU. We don't add these via `add_weight` because we
502
+ # can't have them passed to the training function.
503
+ replicated_zeros_initializer = ShardedInitializer(
504
+ "zeros", cpu_layout
505
+ )
506
+
507
+ with backend.name_scope(self.name, caller=self):
508
+ self._preprocessing_buffer_size = {
509
+ table_name: backend.Variable(
510
+ initializer=replicated_zeros_initializer,
511
+ shape=(),
512
+ dtype=backend.standardize_dtype("int32"),
513
+ trainable=False,
514
+ name=table_name + ":preprocessing:buffer_size",
515
+ )
516
+ for table_name in stacked_table_specs.keys()
517
+ }
518
+ self._preprocessing_max_unique_ids_per_partition = {
519
+ table_name: backend.Variable(
520
+ shape=(),
521
+ name=table_name
522
+ + ":preprocessing:max_unique_ids_per_partition",
523
+ initializer=replicated_zeros_initializer,
524
+ dtype=backend.standardize_dtype("int32"),
525
+ trainable=False,
526
+ )
527
+ for table_name in stacked_table_specs.keys()
528
+ }
529
+
530
+ self._preprocessing_max_ids_per_partition = {
531
+ table_name: backend.Variable(
532
+ shape=(),
533
+ name=table_name
534
+ + ":preprocessing:max_ids_per_partition",
535
+ initializer=replicated_zeros_initializer,
536
+ dtype=backend.standardize_dtype("int32"),
537
+ trainable=False,
538
+ )
539
+ for table_name in stacked_table_specs.keys()
540
+ }
541
+
542
+ self._config = jte_embedding_lookup.EmbeddingLookupConfiguration(
543
+ feature_specs,
544
+ mesh=mesh,
545
+ table_partition=_get_partition_spec(sparsecore_layout),
546
+ samples_partition=_get_partition_spec(sparsecore_layout),
547
+ table_layout=sparsecore_layout.backend_layout,
548
+ )
549
+
550
+ self._sparsecore_built = True
551
+
552
    def _sparsecore_symbolic_preprocess(
        self,
        inputs: dict[str, types.Tensor],
        weights: dict[str, types.Tensor] | None,
        training: bool = False,
    ) -> dict[str, dict[str, embedding_utils.ShardedCooMatrix]]:
        """Allow preprocess(...) with `keras.Input`s.

        This is to support creating functional models via:
        ```python
        inputs = keras.Input(shape=(None), dtype="int32")
        weights = keras.Input(shape=(None), dtype="float32")
        preprocessed_inputs = distributed_embedding.preprocess(inputs, weights)
        outputs = distributed_embedding(preprocessed_inputs)
        model = keras.Model(inputs=preprocessed_inputs, outputs=outputs)
        ```

        Args:
            inputs: SparseCore path->tensor input ID's tensors.
            weights: Optional Sparsecore path->tensor input weights tensors.
            training: Whether the layer is training or not.

        Returns:
            Symbolic preprocessed input tensors to the layer/model.
        """
        # Arguments are currently ignored since the input shape is governed
        # by the stacked table configuration.
        del inputs, weights, training

        # Each stacked-table gets a ShardedCooMatrix.
        table_specs = embedding_utils.get_table_specs(
            self._config.feature_specs
        )
        table_stacks = embedding_utils.get_table_stacks(table_specs)
        stacked_table_specs = {
            stack_name: stack[0].stacked_table_spec
            for stack_name, stack in table_stacks.items()
        }

        def _compute_table_output_spec(
            stacked_table_spec: embedding_spec.StackedTableSpec,
        ) -> embedding_utils.ShardedCooMatrix:
            # The true shape of the components in the ShardedCooMatrix depends
            # on the hardware configuration (# devices, sparsecores),
            # properties of the input data (# max IDs, unique IDs), and other
            # hints like a suggested internal buffer size. Some of the
            # calculations are currently a bit in flux as we experiment with
            # memory trade-offs. For the purposes of input/output sizes,
            # however, the size could be viewed as dynamic 1D without affecting
            # the output spec sizes.
            del stacked_table_spec
            return embedding_utils.ShardedCooMatrix(
                # Mark these as `Input`s since that's how they will be used when
                # constructing a functional Keras model.
                shard_starts=keras.Input(shape=tuple(), dtype="int32"),
                shard_ends=keras.Input(shape=tuple(), dtype="int32"),
                col_ids=keras.Input(shape=tuple(), dtype="int32"),
                row_ids=keras.Input(shape=tuple(), dtype="int32"),
                values=keras.Input(shape=tuple(), dtype="float32"),
            )

        preprocessed = keras.tree.map_structure(
            _compute_table_output_spec, stacked_table_specs
        )

        # Wrapped under "inputs" to mirror the concrete preprocess output.
        return {"inputs": preprocessed}
618
+
619
    def _sparsecore_preprocess(
        self,
        inputs: dict[str, types.Tensor],
        weights: dict[str, types.Tensor] | None,
        training: bool = False,
    ) -> dict[str, dict[str, embedding_utils.ShardedCooMatrix]]:
        """Converts raw ID (and weight) tensors into sharded COO inputs.

        During training, also grows the persistent preprocessing bounds
        (max IDs per partition, max unique IDs, static buffer size), taking
        the maximum across all processes' CPU devices, and writes the new
        values back into the feature specs.

        Args:
            inputs: SparseCore path->tensor input ID tensors.
            weights: Optional path->tensor input weight tensors.
            training: Whether to update the preprocessing bounds.

        Returns:
            `{"inputs": <stack name -> ShardedCooMatrix>}`.

        Raises:
            ValueError: If called under jit (inputs are tracers).
        """
        if any(
            isinstance(x, jax.core.Tracer) for x in keras.tree.flatten(inputs)
        ):
            raise ValueError(
                "DistributedEmbedding.preprocess(...) does not support"
                " jit-compilation"
            )

        if not self._sparsecore_built:
            self._sparsecore_build()

        # Support symbolic KerasTensors (i.e. keras.Input).
        if any(
            isinstance(x, keras.KerasTensor) for x in keras.tree.flatten(inputs)
        ):
            return self._sparsecore_symbolic_preprocess(
                inputs, weights, training
            )

        samples = embedding_utils.create_feature_samples(
            self._config.feature_specs, inputs, weights
        )

        layout = self._sparsecore_layout
        mesh = layout.device_mesh.backend_mesh
        global_device_count = mesh.devices.size
        local_device_count = mesh.local_mesh.devices.size
        num_sc_per_device = jte_utils.num_sparsecores_per_device(
            mesh.devices.item(0)
        )

        # Get current buffer size/max_ids.
        previous_max_ids_per_partition = keras.tree.map_structure(
            lambda max_ids_per_partition: max_ids_per_partition.value.item(),
            self._preprocessing_max_ids_per_partition,
        )
        previous_max_unique_ids_per_partition = keras.tree.map_structure(
            lambda max_unique_ids_per_partition: (
                max_unique_ids_per_partition.value.item()
            ),
            self._preprocessing_max_unique_ids_per_partition,
        )
        previous_buffer_size = keras.tree.map_structure(
            lambda buffer_size: buffer_size.value.item(),
            self._preprocessing_buffer_size,
        )

        preprocessed, stats = embedding_utils.stack_and_shard_samples(
            self._config.feature_specs,
            samples,
            local_device_count,
            global_device_count,
            num_sc_per_device,
            static_buffer_size=previous_buffer_size,
        )

        # Extract max unique IDs and buffer sizes.
        # We need to replicate this value across all local CPU devices.
        if training:
            num_local_cpu_devices = jax.local_device_count("cpu")
            local_max_ids_per_partition = {
                table_name: np.repeat(
                    # Maximum across all partitions and previous max.
                    np.maximum(
                        np.max(elems),
                        previous_max_ids_per_partition[table_name],
                    ),
                    num_local_cpu_devices,
                )
                for table_name, elems in stats.max_ids_per_partition.items()
            }
            local_max_unique_ids_per_partition = {
                name: np.repeat(
                    # Maximum across all partitions and previous max.
                    np.maximum(
                        np.max(elems),
                        previous_max_unique_ids_per_partition[name],
                    ),
                    num_local_cpu_devices,
                )
                for name, elems in stats.max_unique_ids_per_partition.items()
            }
            local_buffer_size = {
                table_name: np.repeat(
                    np.maximum(
                        np.max(
                            # Round values up to the next multiple of 8.
                            # Currently using this as a proxy for the actual
                            # required buffer size.
                            ((elems + 7) // 8) * 8
                        )
                        * global_device_count
                        * num_sc_per_device
                        * local_device_count
                        * num_sc_per_device,
                        previous_buffer_size[table_name],
                    ),
                    num_local_cpu_devices,
                )
                for table_name, elems in stats.max_ids_per_partition.items()
            }

            # Aggregate variables across all processes/devices.
            max_across_cpus = jax.pmap(
                lambda x: jax.lax.pmax(  # type: ignore[no-untyped-call]
                    x, "all_cpus"
                ),
                axis_name="all_cpus",
                devices=self._cpu_layout.device_mesh.backend_mesh.devices,
            )
            new_max_ids_per_partition = max_across_cpus(
                local_max_ids_per_partition
            )
            new_max_unique_ids_per_partition = max_across_cpus(
                local_max_unique_ids_per_partition
            )
            new_buffer_size = max_across_cpus(local_buffer_size)

            # Assign new preprocessing parameters.
            with self._cpu_distribution.scope():
                # For each process, all max ids/buffer sizes are replicated
                # across all local devices. Take the value from the first
                # device.
                keras.tree.map_structure(
                    lambda var, values: var.assign(values[0]),
                    self._preprocessing_max_ids_per_partition,
                    new_max_ids_per_partition,
                )
                keras.tree.map_structure(
                    lambda var, values: var.assign(values[0]),
                    self._preprocessing_max_unique_ids_per_partition,
                    new_max_unique_ids_per_partition,
                )
                keras.tree.map_structure(
                    lambda var, values: var.assign(values[0]),
                    self._preprocessing_buffer_size,
                    new_buffer_size,
                )
            # Update parameters in the underlying feature specs.
            int_max_ids_per_partition = keras.tree.map_structure(
                lambda varray: varray.item(), new_max_ids_per_partition
            )
            int_max_unique_ids_per_partition = keras.tree.map_structure(
                lambda varray: varray.item(),
                new_max_unique_ids_per_partition,
            )
            embedding_utils.update_stacked_table_specs(
                self._config.feature_specs,
                int_max_ids_per_partition,
                int_max_unique_ids_per_partition,
            )

        return {"inputs": preprocessed}
778
+
779
+ def _sparsecore_call(
780
+ self,
781
+ inputs: dict[str, types.Tensor],
782
+ weights: dict[str, types.Tensor] | None = None,
783
+ training: bool = False,
784
+ **kwargs: Any,
785
+ ) -> dict[str, types.Tensor]:
786
+ assert weights is None
787
+
788
+ if not self._sparsecore_built:
789
+ self._sparsecore_build()
790
+
791
+ table_and_slots = keras.tree.map_structure(
792
+ lambda var: var.value, self._table_and_slot_variables
793
+ )
794
+ with self._sparsecore_distribution.scope():
795
+ lookup_func = jax.jit(
796
+ jte_embedding_lookup.embedding_lookup, static_argnames="config"
797
+ )
798
+ out: dict[str, types.Tensor] = lookup_func(
799
+ self._config, inputs, table_and_slots, self._iterations.value
800
+ )
801
+ return out
802
+
803
+ def set_embedding_tables(self, tables: Mapping[str, ArrayLike]) -> None:
804
+ """Sets the embedding tables to specific (unsharded) values.
805
+
806
+ Args:
807
+ tables: Mapping of table name -> table values.
808
+ """
809
+ if "default_device" in self._placement_to_path_to_feature_config:
810
+ self._default_device_set_tables(tables)
811
+
812
+ if "sparsecore" in self._placement_to_path_to_feature_config:
813
+ self._sparsecore_set_tables(tables)
814
+
815
+ def _default_device_set_tables(
816
+ self, tables: Mapping[str, ArrayLike]
817
+ ) -> None:
818
+ if not self.built:
819
+ raise ValueError("Layer must first be built before setting tables.")
820
+
821
+ if "default_device" in self._placement_to_path_to_feature_config:
822
+ table_to_embedding_layer = {}
823
+ for (
824
+ path,
825
+ feature_config,
826
+ ) in self._placement_to_path_to_feature_config[
827
+ "default_device"
828
+ ].items():
829
+ table_to_embedding_layer[feature_config.table] = (
830
+ self._default_device_embedding_layers[path]
831
+ )
832
+
833
+ for table, embedding_layer in table_to_embedding_layer.items():
834
+ table_values = tables.get(table.name, None)
835
+ if table_values is not None:
836
+ if embedding_layer.lora_enabled:
837
+ raise ValueError("Cannot set table if LoRA is enabled.")
838
+ # pylint: disable-next=protected-access
839
+ embedding_layer._embeddings.assign(table_values)
840
+
841
    def _sparsecore_set_tables(self, tables: Mapping[str, ArrayLike]) -> None:
        """Assigns sparsecore-placed tables from unsharded per-table values."""
        if not self._sparsecore_built:
            self._sparsecore_build()

        config = self._config
        num_table_shards = config.mesh.devices.size * config.num_sc_per_device
        table_specs = embedding_utils.get_table_specs(config.feature_specs)
        # Re-stack logical tables into their physical stacked form, sharded
        # the same way the table variables are.
        sharded_tables = embedding_utils.stack_and_shard_tables(
            table_specs,
            tables,
            num_table_shards,
        )

        device_tables = jax.device_put(
            jax.tree.map(
                # Flatten shard dimension to allow auto-sharding to split the
                # array.
                lambda table: table.reshape((-1, table.shape[-1])),
                sharded_tables,
            ),
            self._sparsecore_layout.backend_layout,
        )

        # Assign stacked table variables to the device values.
        # (Only index 0 — the table itself — is assigned; slots are kept.)
        keras.tree.map_structure_up_to(
            device_tables,
            lambda table_and_slot_variables,
            table_value: table_and_slot_variables[0].assign(table_value),
            self._table_and_slot_variables,
            device_tables,
        )
872
+
873
+ def _sparsecore_get_embedding_tables(self) -> dict[str, ArrayLike]:
874
+ if not self._sparsecore_built:
875
+ self.sparsecore_build()
876
+
877
+ config = self._config
878
+ num_table_shards = config.mesh.devices.size * config.num_sc_per_device
879
+ table_specs = embedding_utils.get_table_specs(config.feature_specs)
880
+
881
+ # Extract only the table variables, not the gradient slot variables.
882
+ table_variables = {
883
+ name: jax.device_get(table_and_slots[0].value)
884
+ for name, table_and_slots in self._table_and_slot_variables.items()
885
+ }
886
+
887
+ return typing.cast(
888
+ dict[str, ArrayLike],
889
+ embedding_utils.unshard_and_unstack_tables(
890
+ table_specs, table_variables, num_table_shards
891
+ ),
892
+ )