keras-rs-nightly 0.0.1.dev2025021903__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (56)
  1. keras_rs/__init__.py +9 -28
  2. keras_rs/layers/__init__.py +37 -0
  3. keras_rs/losses/__init__.py +19 -0
  4. keras_rs/metrics/__init__.py +16 -0
  5. keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
  6. keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
  7. keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
  8. keras_rs/src/layers/embedding/embed_reduce.py +309 -0
  9. keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  10. keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
  11. keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
  12. keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
  13. keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
  14. keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
  15. keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  16. keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
  17. keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
  18. keras_rs/src/layers/feature_interaction/__init__.py +0 -0
  19. keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
  20. keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
  21. keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
  22. keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
  23. keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
  24. keras_rs/src/layers/retrieval/retrieval.py +127 -0
  25. keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
  26. keras_rs/src/losses/__init__.py +0 -0
  27. keras_rs/src/losses/list_mle_loss.py +212 -0
  28. keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
  29. keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
  30. keras_rs/src/losses/pairwise_loss.py +165 -0
  31. keras_rs/src/losses/pairwise_loss_utils.py +39 -0
  32. keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
  33. keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
  34. keras_rs/src/metrics/__init__.py +0 -0
  35. keras_rs/src/metrics/dcg.py +161 -0
  36. keras_rs/src/metrics/mean_average_precision.py +130 -0
  37. keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
  38. keras_rs/src/metrics/ndcg.py +197 -0
  39. keras_rs/src/metrics/precision_at_k.py +117 -0
  40. keras_rs/src/metrics/ranking_metric.py +260 -0
  41. keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
  42. keras_rs/src/metrics/recall_at_k.py +108 -0
  43. keras_rs/src/metrics/utils.py +70 -0
  44. keras_rs/src/types.py +43 -14
  45. keras_rs/src/utils/doc_string_utils.py +53 -0
  46. keras_rs/src/utils/keras_utils.py +52 -3
  47. keras_rs/src/utils/tpu_test_utils.py +120 -0
  48. keras_rs/src/version.py +1 -1
  49. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
  50. keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
  51. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
  52. keras_rs/api/__init__.py +0 -9
  53. keras_rs/api/layers/__init__.py +0 -11
  54. keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
  55. /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
  56. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
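The bulk of this release is the new `DistributedEmbedding` layer (the `keras_rs/src/layers/embedding/` tree above) together with ranking losses and metrics. For orientation only, here is a minimal sketch of how the configuration classes from `distributed_embedding_config.py` are typically wired together; `FeatureConfig` appears in the diff below, while `TableConfig` and all constructor arguments shown here are assumptions that should be checked against the package's API reference.

```python
# Orientation-only sketch of the embedding API surface added in this release.
# Class names mirror the modules listed above; constructor arguments are
# assumptions, not authoritative signatures.
import keras_rs

# A table config describes one embedding table (name, size, initializer, ...).
movie_table = keras_rs.layers.TableConfig(
    name="movie",
    vocabulary_size=50_000,
    embedding_dim=32,
)

# A feature config maps one stream of input IDs onto a table.
feature_configs = {
    "movie_id": keras_rs.layers.FeatureConfig(
        name="movie_id",
        table=movie_table,
        input_shape=(256,),      # per-replica batch of IDs
        output_shape=(256, 32),  # per-replica batch of embeddings
    ),
}

# The layer places tables on SparseCore (TPU) or the default device depending
# on the backend and available hardware; table_stacking="auto" lets the
# library group compatible tables.
embedding_layer = keras_rs.layers.DistributedEmbedding(
    feature_configs, table_stacking="auto"
)
```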
@@ -0,0 +1,829 @@
+ """JAX implementation of the TPU embedding layer."""
+
+ import math
+ import typing
+ from typing import Any, Mapping, Sequence, Union
+
+ import jax
+ import keras
+ import numpy as np
+ from jax import numpy as jnp
+ from jax.experimental import layout as jax_layout
+ from jax.experimental import multihost_utils
+ from jax_tpu_embedding.sparsecore.lib.nn import embedding
+ from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
+ from jax_tpu_embedding.sparsecore.lib.nn import (
+     table_stacking as jte_table_stacking,
+ )
+ from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
+
+ from keras_rs.src import types
+ from keras_rs.src.layers.embedding import base_distributed_embedding
+ from keras_rs.src.layers.embedding import distributed_embedding_config as config
+ from keras_rs.src.layers.embedding.jax import config_conversion
+ from keras_rs.src.layers.embedding.jax import (
+     embedding_lookup as jte_embedding_lookup,
+ )
+ from keras_rs.src.layers.embedding.jax import embedding_utils
+ from keras_rs.src.types import Nested
+ from keras_rs.src.utils import keras_utils
+
+ if jax.__version_info__ >= (0, 8, 0):
+     from jax import shard_map
+ else:
+     from jax.experimental.shard_map import shard_map  # type: ignore[assignment]
+
+
+ ArrayLike = Union[np.ndarray[Any, Any], jax.Array]
+ FeatureConfig = config.FeatureConfig
+
+
+ def _get_partition_spec(
+     layout: (
+         keras.distribution.TensorLayout
+         | jax_layout.Format
+         | jax.sharding.NamedSharding
+         | jax.sharding.PartitionSpec
+     ),
+ ) -> Any:
+     """Extracts the partition spec from a layout or sharding."""
+     if isinstance(layout, keras.distribution.TensorLayout):
+         layout = layout.backend_layout
+
+     if isinstance(layout, jax_layout.Format):
+         layout = layout.sharding
+
+     if isinstance(layout, jax.sharding.NamedSharding):
+         layout = layout.spec
+
+     return layout
+
+
+ class ShardedInitializer(keras.initializers.Initializer):
+     """Wraps an initializer to prepare for use with embedding tables.
+
+     Jit-compiles the function and applies optimal output sharding to
+     allow initialization on device.
+     """
+
+     def __init__(
+         self,
+         initializer: keras.initializers.Initializer | str,
+         layout: keras.distribution.TensorLayout | None,
+     ):
+         if isinstance(initializer, str):
+             initializer = keras.initializers.get(initializer)
+
+         self._initializer = initializer
+         self._layout = layout
+
+     def __call__(
+         self, shape: types.Shape, dtype: types.DType | None = None
+     ) -> jax.Array:
+         if self._layout is not None:
+             compiled_initializer = jax.jit(
+                 self._initializer,
+                 out_shardings=self._layout.backend_layout,
+                 static_argnames=["shape", "dtype"],
+             )
+             output: jax.Array = compiled_initializer(shape, dtype)
+             return output
+
+         output = self._initializer(shape, dtype)
+         return output
+
+
+ class StackedTableInitializer(keras.initializers.Initializer):
+     """Initializes a single stacked table from multiple table initializers."""
+
+     def __init__(
+         self,
+         table_specs: Nested[embedding_spec.TableSpec],
+         num_shards: int,
+         layout: keras.distribution.TensorLayout,
+         seed: int | keras.random.SeedGenerator | jax.Array = 0,
+     ):
+         # Sort table specs so we can simply concatenate them when assembling the
+         # stacked table.
+         self._table_specs = sorted(
+             keras.tree.flatten(table_specs),
+             key=lambda table_spec: (
+                 table_spec.setting_in_stack.row_offset_in_shard,
+             ),
+         )
+         self._num_shards = num_shards
+         self._layout = layout
+         self._key = keras.src.backend.jax.random.jax_draw_seed(seed)
+
+     def _initialize_shard(
+         self,
+         keys: jax.Array,
+         shape: tuple[int, int],
+         dtype: Any,
+         num_shards_per_device: int,
+     ) -> jax.Array:
+         """Initializes a single shard of a stacked table."""
+         del shape  # Unused.
+         table_shards: list[jax.Array] = []
+         # NOTE: the following ignores padding, rotations in shard, and
+         # mod-sharding, assuming all initializers are shard-independent.
+         for i in range(num_shards_per_device):
+             for j, table_spec in enumerate(self._table_specs):
+                 setting_in_stack = table_spec.setting_in_stack
+                 table_shard_shape = (
+                     setting_in_stack.padded_vocab_size // self._num_shards,
+                     setting_in_stack.padded_embedding_dim,
+                 )
+                 initializer = table_spec.initializer
+                 table_shards.append(
+                     initializer(keys[i, j], table_shard_shape, dtype)
+                 )
+
+         return jnp.concatenate(table_shards, axis=0)
+
+     def __call__(
+         self, shape: types.Shape, dtype: types.DType | None = None
+     ) -> jax.Array:
+         stacked_table_spec = typing.cast(
+             embedding_spec.StackedTableSpec,
+             self._table_specs[0].stacked_table_spec,
+         )
+
+         # Input shape is governed by the table specs.
+         assert shape == (
+             stacked_table_spec.stack_vocab_size,
+             stacked_table_spec.stack_embedding_dim,
+         )
+
+         layout = self._layout
+         backend_layout = layout.backend_layout
+         backend_mesh = layout.device_mesh.backend_mesh
+         num_devices_along_axis = backend_mesh.shape[layout.axes[0]]
+         num_shards_per_device = self._num_shards // num_devices_along_axis
+         shard_shape = (
+             stacked_table_spec.stack_vocab_size // num_devices_along_axis,
+             stacked_table_spec.stack_embedding_dim,
+         )
+
+         sharded_initializer = jax.jit(
+             shard_map(
+                 lambda keys: self._initialize_shard(
+                     keys, shard_shape, dtype, num_shards_per_device
+                 ),
+                 mesh=backend_mesh,
+                 in_specs=_get_partition_spec(backend_layout),
+                 out_specs=_get_partition_spec(backend_layout),
+             ),
+             out_shardings=backend_layout,
+         )
+
+         keys = jax.random.split(
+             self._key, (self._num_shards, len(self._table_specs))
+         )
+         # Try extracting seeds from the existing table initializers.
+         for i, table_spec in enumerate(self._table_specs):
+             initializer = table_spec.initializer
+             if isinstance(
+                 initializer, config_conversion.WrappedKerasInitializer
+             ):
+                 initializer_key = initializer.key()
+                 if initializer_key is not None:
+                     col = jax.random.split(initializer_key, self._num_shards)
+                     keys = keys.at[:, i].set(col)
+
+         output: jax.Array = sharded_initializer(keys)
+         return output
+
+
+ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
+     """JAX implementation of the TPU embedding layer."""
+
+     def _create_sparsecore_distribution(
+         self, sparsecore_axis_name: str = "sparsecore"
+     ) -> tuple[
+         keras.distribution.ModelParallel, keras.distribution.TensorLayout
+     ]:
+         """SparseCore requires a specific layout.
+
+         The mesh must be 1D, must use all TPUs available, and must shard all
+         tables across all devices.
+
+         Args:
+             sparsecore_axis_name: The name of the sparsecore axis.
+
+         Returns:
+             A Keras distribution to use for all sparsecore operations.
+         """
+         all_devices = jax.devices()
+         axes = [sparsecore_axis_name]
+         device_mesh = keras.distribution.DeviceMesh(
+             (len(all_devices),), axes, all_devices
+         )
+         sparsecore_layout = keras.distribution.TensorLayout(axes, device_mesh)
+         # Custom sparsecore layout with tiling.
+         LayoutClass = (
+             jax_layout.Layout
+             if jax.__version_info__ >= (0, 6, 3)
+             else jax_layout.DeviceLocalLayout  # type: ignore
+         )
+         layout = (
+             LayoutClass(major_to_minor=(0, 1), tiling=((8,),))  # type: ignore
+             if jax.__version_info__ >= (0, 7, 1)
+             else LayoutClass(major_to_minor=(0, 1), _tiling=((8,),))  # type: ignore
+         )
+         # pylint: disable-next=protected-access
+         sparsecore_layout._backend_layout = jax_layout.Format(
+             layout,  # type: ignore
+             jax.sharding.NamedSharding(
+                 device_mesh.backend_mesh,
+                 jax.sharding.PartitionSpec(
+                     axes  # type: ignore[no-untyped-call]
+                 ),
+             ),
+         )
+         layout_map = keras.distribution.LayoutMap(device_mesh=device_mesh)
+         path = self.path
+         if path is None:
+             # Layer hasn't been properly built yet. Use current layer name.
+             path = self.name
+         layout_map[path + "/var"] = sparsecore_layout
+         sparsecore_distribution = keras.distribution.ModelParallel(
+             layout_map=layout_map
+         )
+         return sparsecore_distribution, sparsecore_layout
+
+     def _add_sparsecore_weight(
+         self,
+         name: str,
+         shape: tuple[int, int],
+         initializer: jax.nn.initializers.Initializer,
+         dtype: Any,
+         overwrite_with_gradient: bool,
+     ) -> keras.Variable:
+         var = self.add_weight(
+             name=name, shape=shape, initializer=initializer, dtype=dtype
+         )
+         var.overwrite_with_gradient = overwrite_with_gradient
+         return var
+
+     def _add_table_variable(
+         self,
+         table_specs: Sequence[embedding_spec.TableSpec],
+         num_shards: int,
+         add_slot_variables: bool,
+     ) -> embedding.EmbeddingVariables:
+         stacked_table_spec = typing.cast(
+             embedding_spec.StackedTableSpec, table_specs[0].stacked_table_spec
+         )
+         optimizer = stacked_table_spec.optimizer
+         num_slot_variables = optimizer.slot_variables_count()
+         table_shape = (
+             stacked_table_spec.stack_vocab_size,
+             stacked_table_spec.stack_embedding_dim,
+         )
+
+         # Make a stacked embedding table initializer.
+         table_initializers = [
+             config_conversion.jax_to_keras_initializer(table_spec.initializer)
+             for table_spec in table_specs
+         ]
+         # If all initializers are the same, we can use a single sharded
+         # initializer. Otherwise, we need to interleave individual stacked table
+         # shards.
+         sparsecore_layout = self._sparsecore_layout
+         stacked_table_initializer = ShardedInitializer(
+             table_initializers[0], sparsecore_layout
+         )
+         if not all(
+             initializer == table_initializers[0]
+             for initializer in table_initializers
+         ):
+             stacked_table_initializer = StackedTableInitializer(
+                 table_specs, num_shards, sparsecore_layout
+             )
+
+         variable_name = f"var:{stacked_table_spec.stack_name}:table"
+         table_variable = self._add_sparsecore_weight(
+             name=variable_name,
+             shape=table_shape,
+             initializer=stacked_table_initializer,
+             dtype="float32",
+             overwrite_with_gradient=True,
+         )
+
+         slot_variables = None
+         if add_slot_variables:
+             # All optimizers for a given stacked table are guaranteed to be the
+             # same, so we can use a single sharded initializer for the entire
+             # stacked table.
+             slot_initializers = optimizer.slot_variables_initializers()
+             # Try extracting field names from variables, otherwise just use the
+             # count.
+             slot_names = range(num_slot_variables)
+             if hasattr(slot_initializers, "_fields"):
+                 slot_names = slot_initializers._fields
+
+             slot_variables = tuple(
+                 self._add_sparsecore_weight(
+                     name=f"{variable_name}:slot:{slot_name}",
+                     shape=table_shape,
+                     initializer=ShardedInitializer(
+                         config_conversion.jax_to_keras_initializer(initializer),
+                         sparsecore_layout,
+                     ),
+                     dtype=jnp.float32,
+                     overwrite_with_gradient=True,
+                 )
+                 for slot_name, initializer in zip(slot_names, slot_initializers)
+             )
+             slot_variables = keras.tree.pack_sequence_as(
+                 slot_initializers, slot_variables
+             )
+
+         return embedding.EmbeddingVariables(table_variable, slot_variables)
+
+     @keras_utils.no_automatic_dependency_tracking
+     def _sparsecore_init(
+         self,
+         feature_configs: dict[str, FeatureConfig],
+         table_stacking: str | Sequence[str] | Sequence[Sequence[str]],
+     ) -> None:
+         if not self._has_sparsecore():
+             raise ValueError(
+                 "No sparse cores available, cannot use explicit sparsecore"
+                 " placement."
+             )
+
+         self._sc_feature_configs = feature_configs
+         self._sparsecore_built = False
+         # Fill in any empty default settings.
+         for feature_config in keras.tree.flatten(self._sc_feature_configs):
+             if feature_config.table.initializer is None:
+                 table = feature_config.table
+                 table.initializer = keras.initializers.TruncatedNormal(
+                     mean=0.0, stddev=1.0 / math.sqrt(float(table.embedding_dim))
+                 )
+
+         # Actual stacking of tables is done in build() to ensure the
+         # distribution is set up correctly.
+         self._table_stacking = table_stacking
+
+     def _sparsecore_build(
+         self, input_shapes: Nested[types.Shape] | None = None
+     ) -> None:
+         self.sparsecore_build(input_shapes)
+
+     @keras_utils.no_automatic_dependency_tracking
+     def sparsecore_build(
+         self, input_shapes: Nested[types.Shape] | None = None
+     ) -> None:
+         del input_shapes  # Unused.
+
+         if self._sparsecore_built:
+             return
+
+         feature_specs = config_conversion.keras_to_jte_feature_configs(
+             self._sc_feature_configs
+         )
+
+         # Distribution for sparsecore operations.
+         sparsecore_distribution, sparsecore_layout = (
+             self._create_sparsecore_distribution()
+         )
+         self._sparsecore_layout = sparsecore_layout
+         self._sparsecore_distribution = sparsecore_distribution
+
+         mesh = sparsecore_distribution.device_mesh.backend_mesh
+         global_device_count = mesh.devices.size
+         num_sc_per_device = jte_utils.num_sparsecores_per_device(
+             mesh.devices.item(0)
+         )
+         # One table shard per global sparsecore.
+         num_variable_shards = global_device_count * num_sc_per_device
+
+         # Maybe stack tables.
+         table_stacking = self._table_stacking
+         if table_stacking is not None:
+             if isinstance(table_stacking, str):
+                 if table_stacking == "auto":
+                     jte_table_stacking.auto_stack_tables(
+                         feature_specs, global_device_count, num_sc_per_device
+                     )
+                 else:
+                     raise ValueError(
+                         f"Unsupported table stacking {table_stacking}, must be"
+                         " None, 'auto', or sequences of table names to stack."
+                     )
+             else:
+                 if isinstance(table_stacking, list) and len(table_stacking) > 0:
+                     elem = table_stacking[0]
+                     # List of lists of table names.
+                     if isinstance(elem, list):
+                         for table_names in table_stacking:
+                             jte_table_stacking.stack_tables(
+                                 feature_specs,
+                                 table_names,
+                                 global_device_count,
+                                 num_sc_per_device,
+                             )
+                     # Single list of table names.
+                     elif isinstance(elem, str):
+                         jte_table_stacking.stack_tables(
+                             feature_specs,
+                             table_stacking,
+                             global_device_count,
+                             num_sc_per_device,
+                         )
+                 else:
+                     raise ValueError(
+                         f"Unsupported table stacking {table_stacking}, "
+                         "must be None, 'auto', or sequences of table names "
+                         "to stack."
+                     )
+
+         # Adjust any non-stacked tables to prepare for training.
+         embedding.prepare_feature_specs_for_training(
+             feature_specs, global_device_count, num_sc_per_device
+         )
+
+         # Collect all stacked tables.
+         table_specs = embedding.get_table_specs(feature_specs)
+         table_stacks = jte_table_stacking.get_table_stacks(table_specs)
+
+         # Update stacked table stats to max of values across involved tables.
+         max_ids_per_partition = {}
+         max_unique_ids_per_partition = {}
+         required_buffer_size_per_device = {}
+         id_drop_counters = {}
+         for stack_name, stack in table_stacks.items():
+             max_ids_per_partition[stack_name] = np.max(
+                 np.asarray(
+                     [s.max_ids_per_partition for s in stack], dtype=np.int32
+                 )
+             )
+             max_unique_ids_per_partition[stack_name] = np.max(
+                 np.asarray(
+                     [s.max_unique_ids_per_partition for s in stack],
+                     dtype=np.int32,
+                 )
+             )
+
+             # Only set the suggested buffer size if set on any individual table.
+             valid_buffer_sizes = [
+                 s.suggested_coo_buffer_size_per_device
+                 for s in stack
+                 if s.suggested_coo_buffer_size_per_device is not None
+             ]
+             if valid_buffer_sizes:
+                 required_buffer_size_per_device[stack_name] = np.max(
+                     np.asarray(valid_buffer_sizes, dtype=np.int32)
+                 )
+
+             id_drop_counters[stack_name] = 0
+
+         aggregated_stats = embedding.SparseDenseMatmulInputStats(
+             max_ids_per_partition=max_ids_per_partition,
+             max_unique_ids_per_partition=max_unique_ids_per_partition,
+             required_buffer_size_per_sc=required_buffer_size_per_device,
+             id_drop_counters=id_drop_counters,
+         )
+         embedding.update_preprocessing_parameters(
+             feature_specs,
+             aggregated_stats,
+             num_sc_per_device,
+         )
+
+         # Create variables for all stacked tables and slot variables.
+         with sparsecore_distribution.scope():
+             self._table_and_slot_variables = {
+                 table_name: self._add_table_variable(
+                     table_stack,
+                     add_slot_variables=self.trainable,
+                     num_shards=num_variable_shards,
+                 )
+                 for table_name, table_stack in table_stacks.items()
+             }
+
+             # Create a step-counter variable for use in custom table gradients.
+             # This must be a floating-point type so we can get a real gradient
+             # for it. It will automatically be updated with each application of
+             # the optimizer, since the next iteration is returned in the
+             # gradient.
+             sharded_zero_initializer = ShardedInitializer(
+                 "zeros",
+                 keras.distribution.TensorLayout(
+                     [], sparsecore_layout.device_mesh
+                 ),
+             )
+             self._iterations = self.add_weight(
+                 shape=(),
+                 name="iteration",
+                 initializer=sharded_zero_initializer,
+                 dtype="float32",
+                 trainable=True,
+             )
+             self._iterations.overwrite_with_gradient = True
+
+         self._config = jte_embedding_lookup.EmbeddingLookupConfiguration(
+             feature_specs,
+             mesh=mesh,
+             table_partition=_get_partition_spec(sparsecore_layout),
+             samples_partition=_get_partition_spec(sparsecore_layout),
+             table_layout=sparsecore_layout.backend_layout,
+         )
+
+         self._sparsecore_built = True
+
+     def _sparsecore_symbolic_preprocess(
+         self,
+         inputs: dict[str, types.Tensor],
+         weights: dict[str, types.Tensor] | None,
+         training: bool = False,
+     ) -> dict[str, dict[str, embedding_utils.ShardedCooMatrix]]:
+         """Allow preprocess(...) with `keras.Input`s.
+
+         This is to support creating functional models via:
+         ```python
+         inputs = keras.Input(shape=(None,), dtype="int32")
+         weights = keras.Input(shape=(None,), dtype="float32")
+         preprocessed_inputs = distributed_embedding.preprocess(inputs, weights)
+         outputs = distributed_embedding(preprocessed_inputs)
+         model = keras.Model(inputs=preprocessed_inputs, outputs=outputs)
+         ```
+
+         Args:
+             inputs: Sparsecore mapping of path to input ID tensors.
+             weights: Optional sparsecore mapping of path to input weight
+                 tensors.
+             training: Whether the layer is training or not.
+
+         Returns:
+             Symbolic preprocessed input tensors to the layer/model.
+         """
+         # Arguments are currently ignored since the input shape is governed
+         # by the stacked table configuration.
+         del inputs, weights, training
+
+         # Each stacked table gets a ShardedCooMatrix.
+         table_specs = embedding.get_table_specs(self._config.feature_specs)
+         table_stacks = jte_table_stacking.get_table_stacks(table_specs)
+         stacked_table_specs = {
+             stack_name: stack[0].stacked_table_spec
+             for stack_name, stack in table_stacks.items()
+         }
+
+         def _compute_table_output_spec(
+             stacked_table_spec: embedding_spec.StackedTableSpec,
+         ) -> embedding_utils.ShardedCooMatrix:
+             # The true shape of the components in the ShardedCooMatrix depends
+             # on the hardware configuration (# devices, sparsecores),
+             # properties of the input data (# max IDs, unique IDs), and other
+             # hints like a suggested internal buffer size. Some of the
+             # calculations are currently a bit in flux as we experiment with
+             # memory trade-offs. For the purposes of input/output sizes,
+             # however, the size could be viewed as dynamic 1D without affecting
+             # the output spec sizes.
+             del stacked_table_spec
+             return embedding_utils.ShardedCooMatrix(
+                 # Mark these as `Input`s since that's how they will be used when
+                 # constructing a functional Keras model.
+                 shard_starts=keras.Input(shape=tuple(), dtype="int32"),
+                 shard_ends=keras.Input(shape=tuple(), dtype="int32"),
+                 col_ids=keras.Input(shape=tuple(), dtype="int32"),
+                 row_ids=keras.Input(shape=tuple(), dtype="int32"),
+                 values=keras.Input(shape=tuple(), dtype="float32"),
+             )
+
+         preprocessed = keras.tree.map_structure(
+             _compute_table_output_spec, stacked_table_specs
+         )
+
+         return {"inputs": preprocessed}
+
+     def _sparsecore_preprocess(
+         self,
+         inputs: dict[str, types.Tensor],
+         weights: dict[str, types.Tensor] | None,
+         training: bool = False,
+     ) -> dict[str, dict[str, embedding_utils.ShardedCooMatrix]]:
+         if any(
+             isinstance(x, jax.core.Tracer) for x in keras.tree.flatten(inputs)
+         ):
+             raise ValueError(
+                 "DistributedEmbedding.preprocess(...) does not support"
+                 " jit-compilation."
+             )
+
+         if not self._sparsecore_built:
+             self._sparsecore_build()
+
+         # Support symbolic KerasTensors (i.e. keras.Input).
+         if any(
+             isinstance(x, keras.KerasTensor) for x in keras.tree.flatten(inputs)
+         ):
+             return self._sparsecore_symbolic_preprocess(
+                 inputs, weights, training
+             )
+
+         samples = embedding_utils.create_feature_samples(
+             self._config.feature_specs, inputs, weights
+         )
+
+         layout = self._sparsecore_layout
+         mesh = layout.device_mesh.backend_mesh
+         global_device_count = mesh.devices.size
+         local_device_count = mesh.local_mesh.devices.size
+         num_sc_per_device = jte_utils.num_sparsecores_per_device(
+             mesh.devices.item(0)
+         )
+
+         preprocessed, stats = embedding_utils.stack_and_shard_samples(
+             self._config.feature_specs,
+             samples,
+             local_device_count,
+             global_device_count,
+             num_sc_per_device,
+         )
+
+         if training:
+             # Synchronize input statistics across all devices and update the
+             # underlying stacked table specs in the feature specs.
+
+             # Gather stats across all processes/devices via process_allgather.
+             all_stats = multihost_utils.process_allgather(stats)
+             all_stats = jax.tree.map(np.max, all_stats)
+
+             # Check if stats changed enough to warrant action.
+             stacked_table_specs = embedding.get_stacked_table_specs(
+                 self._config.feature_specs
+             )
+             changed = any(
+                 all_stats.max_ids_per_partition[stack_name]
+                 > spec.max_ids_per_partition
+                 or all_stats.max_unique_ids_per_partition[stack_name]
+                 > spec.max_unique_ids_per_partition
+                 or all_stats.required_buffer_size_per_sc[stack_name]
+                 * num_sc_per_device
+                 > (spec.suggested_coo_buffer_size_per_device or 0)
+                 for stack_name, spec in stacked_table_specs.items()
+             )
+
+             # Update configuration and repeat preprocessing if stats changed.
+             if changed:
+                 for stack_name, spec in stacked_table_specs.items():
+                     all_stats.max_ids_per_partition[stack_name] = np.max(
+                         [
+                             all_stats.max_ids_per_partition[stack_name],
+                             spec.max_ids_per_partition,
+                         ]
+                     )
+                     all_stats.max_unique_ids_per_partition[stack_name] = np.max(
+                         [
+                             all_stats.max_unique_ids_per_partition[stack_name],
+                             spec.max_unique_ids_per_partition,
+                         ]
+                     )
+                     all_stats.required_buffer_size_per_sc[stack_name] = np.max(
+                         [
+                             all_stats.required_buffer_size_per_sc[stack_name],
+                             (
+                                 (spec.suggested_coo_buffer_size_per_device or 0)
+                                 + (num_sc_per_device - 1)
+                             )
+                             // num_sc_per_device,
+                         ]
+                     )
+
+                 embedding.update_preprocessing_parameters(
+                     self._config.feature_specs, all_stats, num_sc_per_device
+                 )
+
+                 # Re-execute preprocessing with consistent input statistics.
+                 preprocessed, _ = embedding_utils.stack_and_shard_samples(
+                     self._config.feature_specs,
+                     samples,
+                     local_device_count,
+                     global_device_count,
+                     num_sc_per_device,
+                 )
+
+         return {"inputs": preprocessed}
+
+     def _sparsecore_call(
+         self,
+         inputs: dict[str, types.Tensor],
+         weights: dict[str, types.Tensor] | None = None,
+         training: bool = False,
+         **kwargs: Any,
+     ) -> dict[str, types.Tensor]:
+         assert weights is None
+
+         if not self._sparsecore_built:
+             self._sparsecore_build()
+
+         table_and_slots = keras.tree.map_structure(
+             lambda var: var.value, self._table_and_slot_variables
+         )
+         with self._sparsecore_distribution.scope():
+             lookup_func = jax.jit(
+                 jte_embedding_lookup.embedding_lookup, static_argnames="config"
+             )
+             out: dict[str, types.Tensor] = lookup_func(
+                 self._config, inputs, table_and_slots, self._iterations.value
+             )
+         return out
+
+     def set_embedding_tables(self, tables: Mapping[str, ArrayLike]) -> None:
+         """Sets the embedding tables to specific (unsharded) values.
+
+         Args:
+             tables: Mapping of table name -> table values.
+         """
+         if "default_device" in self._placement_to_path_to_feature_config:
+             self._default_device_set_tables(tables)
+
+         if "sparsecore" in self._placement_to_path_to_feature_config:
+             self._sparsecore_set_tables(tables)
+
+     def _default_device_set_tables(
+         self, tables: Mapping[str, ArrayLike]
+     ) -> None:
+         if not self.built:
+             raise ValueError("Layer must first be built before setting tables.")
+
+         if "default_device" in self._placement_to_path_to_feature_config:
+             table_name_to_embedding_layer = {}
+             for (
+                 path,
+                 feature_config,
+             ) in self._placement_to_path_to_feature_config[
+                 "default_device"
+             ].items():
+                 table_name_to_embedding_layer[feature_config.table.name] = (
+                     self._default_device_embedding_layers[path]
+                 )
+
+             for (
+                 table_name,
+                 embedding_layer,
+             ) in table_name_to_embedding_layer.items():
+                 table_values = tables.get(table_name, None)
+                 if table_values is not None:
+                     if embedding_layer.lora_enabled:
+                         raise ValueError("Cannot set table if LoRA is enabled.")
+                     # pylint: disable-next=protected-access
+                     embedding_layer._embeddings.assign(table_values)
+
+     def _sparsecore_set_tables(self, tables: Mapping[str, ArrayLike]) -> None:
+         if not self._sparsecore_built:
+             self._sparsecore_build()
+
+         config = self._config
+         num_table_shards = config.mesh.devices.size * config.num_sc_per_device
+         table_specs = embedding.get_table_specs(config.feature_specs)
+         sharded_tables = jte_table_stacking.stack_and_shard_tables(
+             table_specs,
+             tables,
+             num_table_shards,
+         )
+
+         device_tables = jax.device_put(
+             jax.tree.map(
+                 # Flatten shard dimension to allow auto-sharding to split the
+                 # array.
+                 lambda table: table.reshape((-1, table.shape[-1])),
+                 sharded_tables,
+             ),
+             self._sparsecore_layout.backend_layout,
+         )
+
+         # Assign stacked table variables to the device values.
+         keras.tree.map_structure_up_to(
+             device_tables,
+             lambda embedding_variables,
+             table_value: embedding_variables.table.assign(table_value),
+             self._table_and_slot_variables,
+             device_tables,
+         )
+
+     def _sparsecore_get_embedding_tables(self) -> dict[str, ArrayLike]:
+         if not self._sparsecore_built:
+             self.sparsecore_build()
+
+         config = self._config
+         num_table_shards = config.mesh.devices.size * config.num_sc_per_device
+         table_specs = embedding.get_table_specs(config.feature_specs)
+
+         # Extract only the table variables, not the gradient slot variables.
+         table_variables = {
+             name: jax.device_get(embedding_variables.table.value)
+             for name, embedding_variables in (
+                 self._table_and_slot_variables.items()
+             )
+         }
+
+         return typing.cast(
+             dict[str, ArrayLike],
+             jte_table_stacking.unshard_and_unstack_tables(
+                 table_specs, table_variables, num_table_shards
+             ),
+         )