keras-rs-nightly 0.2.2.dev202508190331__py3-none-any.whl → 0.4.1.dev202601250348__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
@@ -4,6 +4,7 @@ This file was autogenerated. Do not edit it by hand,
 since your modifications would be overwritten.
 """
 
+from keras_rs.src.losses.list_mle_loss import ListMLELoss as ListMLELoss
 from keras_rs.src.losses.pairwise_hinge_loss import (
     PairwiseHingeLoss as PairwiseHingeLoss,
 )
@@ -457,6 +457,10 @@ class DistributedEmbedding(keras.layers.Layer):
             tables in the inner lists together. Note that table stacking is not
             supported on older TPUs, in which case the default value of `"auto"`
             will be interpreted as no table stacking.
+        update_stats: If True, `'max_ids_per_partition'`,
+            `'max_unique_ids_per_partition'` and
+            `'suggested_coo_buffer_size_per_device'` are updated during
+            training. This argument can be set to True only for the JAX backend.
         **kwargs: Additional arguments to pass to the layer base class.
     """
 
@@ -467,6 +471,7 @@ class DistributedEmbedding(keras.layers.Layer):
         table_stacking: (
             str | Sequence[str] | Sequence[Sequence[str]]
         ) = "auto",
+        update_stats: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -486,6 +491,8 @@ class DistributedEmbedding(keras.layers.Layer):
             table_stacking,
         )
 
+        self.update_stats = update_stats
+
     @keras_utils.no_automatic_dependency_tracking
     def _init_feature_configs_structures(
         self,
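Note (hedged usage sketch, not taken from the package): the new `update_stats` flag is a plain constructor argument stored on the layer, and per the docstring above it only has an effect on the JAX backend. The `TableConfig`/`FeatureConfig` argument names and the shape of `feature_configs` below are assumptions for illustration; check the keras_rs API reference before relying on them.

    import keras_rs

    # Assumed field names for illustration only.
    table = keras_rs.layers.TableConfig(
        name="movie_id",
        vocabulary_size=100_000,
        embedding_dim=32,
    )
    feature = keras_rs.layers.FeatureConfig(
        name="movie_id",
        table=table,
        input_shape=(256, 1),
        output_shape=(256, 32),
    )
    layer = keras_rs.layers.DistributedEmbedding(
        feature_configs={"movie_id": feature},
        table_stacking="auto",
        update_stats=True,  # JAX backend only, per the docstring above
    )
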
@@ -822,13 +829,13 @@ class DistributedEmbedding(keras.layers.Layer):
         table_stacking: str | Sequence[Sequence[str]],
     ) -> None:
         del table_stacking
-        table_to_embedding_layer: dict[TableConfig, EmbedReduce] = {}
+        table_config_id_to_embedding_layer: dict[int, EmbedReduce] = {}
         self._default_device_embedding_layers: dict[str, EmbedReduce] = {}
 
         for path, feature_config in feature_configs.items():
-            if feature_config.table in table_to_embedding_layer:
+            if id(feature_config.table) in table_config_id_to_embedding_layer:
                 self._default_device_embedding_layers[path] = (
-                    table_to_embedding_layer[feature_config.table]
+                    table_config_id_to_embedding_layer[id(feature_config.table)]
                 )
             else:
                 embedding_layer = EmbedReduce(
@@ -838,7 +845,9 @@ class DistributedEmbedding(keras.layers.Layer):
                     embeddings_initializer=feature_config.table.initializer,
                     combiner=feature_config.table.combiner,
                 )
-                table_to_embedding_layer[feature_config.table] = embedding_layer
+                table_config_id_to_embedding_layer[id(feature_config.table)] = (
+                    embedding_layer
+                )
                 self._default_device_embedding_layers[path] = embedding_layer
 
     def _default_device_build(
@@ -1013,8 +1022,8 @@ class DistributedEmbedding(keras.layers.Layer):
 
         # The serialized `TableConfig` objects.
         table_config_dicts: list[dict[str, Any]] = []
-        # Mapping from `TableConfig` to index in `table_config_dicts`.
-        table_config_indices: dict[TableConfig, int] = {}
+        # Mapping from `TableConfig` id to index in `table_config_dicts`.
+        table_config_id_to_index: dict[int, int] = {}
 
         def serialize_feature_config(
             feature_config: FeatureConfig,
@@ -1024,17 +1033,17 @@ class DistributedEmbedding(keras.layers.Layer):
             # key.
             feature_config_dict = feature_config.get_config()
 
-            if feature_config.table not in table_config_indices:
+            if id(feature_config.table) not in table_config_id_to_index:
                 # Save the serialized `TableConfig` the first time we see it and
                 # remember its index.
-                table_config_indices[feature_config.table] = len(
+                table_config_id_to_index[id(feature_config.table)] = len(
                     table_config_dicts
                 )
                 table_config_dicts.append(feature_config_dict["table"])
 
             # Replace the serialized `TableConfig` with its index.
-            feature_config_dict["table"] = table_config_indices[
-                feature_config.table
+            feature_config_dict["table"] = table_config_id_to_index[
+                id(feature_config.table)
             ]
             return feature_config_dict
 
@@ -10,7 +10,7 @@ from keras_rs.src.api_export import keras_rs_export
 
 
 @keras_rs_export("keras_rs.layers.TableConfig")
-@dataclasses.dataclass(eq=True, unsafe_hash=True, order=True)
+@dataclasses.dataclass(order=True)
 class TableConfig:
     """Configuration for one embedding table.
 
@@ -88,7 +88,7 @@ class TableConfig:
 
 
 @keras_rs_export("keras_rs.layers.FeatureConfig")
-@dataclasses.dataclass(eq=True, unsafe_hash=True, order=True)
+@dataclasses.dataclass(order=True)
 class FeatureConfig:
     """Configuration for one embedding feature.
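Note: dropping `eq=True, unsafe_hash=True` from these two decorators is why the hunks above switch dictionary keys from the config objects themselves to `id(...)`: with `eq` enabled (the default) and no generated hash, dataclass instances are no longer hashable. A standalone illustration of that Python behavior, using a made-up `Cfg` class rather than keras_rs code:

    import dataclasses

    @dataclasses.dataclass(order=True)
    class Cfg:  # hypothetical stand-in, not a keras_rs class
        name: str

    a, b = Cfg("t"), Cfg("t")
    assert a == b  # eq=True by default: value equality
    # hash(a) would raise TypeError: eq without a generated hash sets
    # __hash__ to None, so containers must key on object identity instead.
    registry = {id(a): "layer_a"}
    assert id(b) not in registry  # equal by value, but a distinct object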
 
@@ -9,13 +9,13 @@ import keras
 import numpy as np
 from jax import numpy as jnp
 from jax.experimental import layout as jax_layout
+from jax.experimental import multihost_utils
 from jax_tpu_embedding.sparsecore.lib.nn import embedding
 from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
 from jax_tpu_embedding.sparsecore.lib.nn import (
     table_stacking as jte_table_stacking,
 )
 from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
-from keras.src import backend
 
 from keras_rs.src import types
 from keras_rs.src.layers.embedding import base_distributed_embedding
@@ -28,9 +28,14 @@ from keras_rs.src.layers.embedding.jax import embedding_utils
 from keras_rs.src.types import Nested
 from keras_rs.src.utils import keras_utils
 
+if jax.__version_info__ >= (0, 8, 0):
+    from jax import shard_map
+else:
+    from jax.experimental.shard_map import shard_map  # type: ignore[assignment]
+
+
 ArrayLike = Union[np.ndarray[Any, Any], jax.Array]
 FeatureConfig = config.FeatureConfig
-shard_map = jax.experimental.shard_map.shard_map  # type: ignore[attr-defined]
 
 
 def _get_partition_spec(
@@ -247,23 +252,6 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         )
         return sparsecore_distribution, sparsecore_layout
 
-    def _create_cpu_distribution(
-        self, cpu_axis_name: str = "cpu"
-    ) -> tuple[
-        keras.distribution.ModelParallel, keras.distribution.TensorLayout
-    ]:
-        """Share a variable across all CPU processes."""
-        cpu_devices = jax.devices("cpu")
-        device_mesh = keras.distribution.DeviceMesh(
-            (len(cpu_devices),), [cpu_axis_name], cpu_devices
-        )
-        replicated_layout = keras.distribution.TensorLayout([], device_mesh)
-        layout_map = keras.distribution.LayoutMap(device_mesh=device_mesh)
-        cpu_distribution = keras.distribution.ModelParallel(
-            layout_map=layout_map
-        )
-        return cpu_distribution, replicated_layout
-
     def _add_sparsecore_weight(
         self,
         name: str,
@@ -283,7 +271,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         table_specs: Sequence[embedding_spec.TableSpec],
         num_shards: int,
         add_slot_variables: bool,
-    ) -> tuple[keras.Variable, tuple[keras.Variable, ...] | None]:
+    ) -> embedding.EmbeddingVariables:
         stacked_table_spec = typing.cast(
             embedding_spec.StackedTableSpec, table_specs[0].stacked_table_spec
         )
@@ -352,7 +340,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
             slot_initializers, slot_variables
         )
 
-        return table_variable, slot_variables
+        return embedding.EmbeddingVariables(table_variable, slot_variables)
 
     @keras_utils.no_automatic_dependency_tracking
     def _sparsecore_init(
@@ -405,11 +393,6 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         self._sparsecore_layout = sparsecore_layout
         self._sparsecore_distribution = sparsecore_distribution
 
-        # Distribution for CPU operations.
-        cpu_distribution, cpu_layout = self._create_cpu_distribution()
-        self._cpu_distribution = cpu_distribution
-        self._cpu_layout = cpu_layout
-
         mesh = sparsecore_distribution.device_mesh.backend_mesh
         global_device_count = mesh.devices.size
         num_sc_per_device = jte_utils.num_sparsecores_per_device(
@@ -424,7 +407,9 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         if isinstance(table_stacking, str):
             if table_stacking == "auto":
                 jte_table_stacking.auto_stack_tables(
-                    feature_specs, global_device_count, num_sc_per_device
+                    feature_specs,
+                    global_device_count,
+                    num_sc_per_device,
                 )
             else:
                 raise ValueError(
@@ -464,12 +449,51 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         )
 
         # Collect all stacked tables.
-        table_specs = embedding_utils.get_table_specs(feature_specs)
-        table_stacks = embedding_utils.get_table_stacks(table_specs)
-        stacked_table_specs = {
-            stack_name: stack[0].stacked_table_spec
-            for stack_name, stack in table_stacks.items()
-        }
+        table_specs = embedding.get_table_specs(feature_specs)
+        table_stacks = jte_table_stacking.get_table_stacks(table_specs)
+
+        # Update stacked table stats to max of values across involved tables.
+        max_ids_per_partition = {}
+        max_unique_ids_per_partition = {}
+        required_buffer_size_per_device = {}
+        id_drop_counters = {}
+        for stack_name, stack in table_stacks.items():
+            max_ids_per_partition[stack_name] = np.max(
+                np.asarray(
+                    [s.max_ids_per_partition for s in stack], dtype=np.int32
+                )
+            )
+            max_unique_ids_per_partition[stack_name] = np.max(
+                np.asarray(
+                    [s.max_unique_ids_per_partition for s in stack],
+                    dtype=np.int32,
+                )
+            )
+
+            # Only set the suggested buffer size if set on any individual table.
+            valid_buffer_sizes = [
+                s.suggested_coo_buffer_size_per_device
+                for s in stack
+                if s.suggested_coo_buffer_size_per_device is not None
+            ]
+            if valid_buffer_sizes:
+                required_buffer_size_per_device[stack_name] = np.max(
+                    np.asarray(valid_buffer_sizes, dtype=np.int32)
+                )
+
+            id_drop_counters[stack_name] = 0
+
+        aggregated_stats = embedding.SparseDenseMatmulInputStats(
+            max_ids_per_partition=max_ids_per_partition,
+            max_unique_ids_per_partition=max_unique_ids_per_partition,
+            required_buffer_size_per_sc=required_buffer_size_per_device,
+            id_drop_counters=id_drop_counters,
+        )
+        embedding.update_preprocessing_parameters(
+            feature_specs,
+            aggregated_stats,
+            num_sc_per_device,
+        )
 
         # Create variables for all stacked tables and slot variables.
         with sparsecore_distribution.scope():
@@ -502,50 +526,6 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
             )
             self._iterations.overwrite_with_gradient = True
 
-        with cpu_distribution.scope():
-            # Create variables to track static buffer size and max IDs for each
-            # table during preprocessing. These variables are shared across all
-            # processes on CPU. We don't add these via `add_weight` because we
-            # can't have them passed to the training function.
-            replicated_zeros_initializer = ShardedInitializer(
-                "zeros", cpu_layout
-            )
-
-            with backend.name_scope(self.name, caller=self):
-                self._preprocessing_buffer_size = {
-                    table_name: backend.Variable(
-                        initializer=replicated_zeros_initializer,
-                        shape=(),
-                        dtype=backend.standardize_dtype("int32"),
-                        trainable=False,
-                        name=table_name + ":preprocessing:buffer_size",
-                    )
-                    for table_name in stacked_table_specs.keys()
-                }
-                self._preprocessing_max_unique_ids_per_partition = {
-                    table_name: backend.Variable(
-                        shape=(),
-                        name=table_name
-                        + ":preprocessing:max_unique_ids_per_partition",
-                        initializer=replicated_zeros_initializer,
-                        dtype=backend.standardize_dtype("int32"),
-                        trainable=False,
-                    )
-                    for table_name in stacked_table_specs.keys()
-                }
-
-                self._preprocessing_max_ids_per_partition = {
-                    table_name: backend.Variable(
-                        shape=(),
-                        name=table_name
-                        + ":preprocessing:max_ids_per_partition",
-                        initializer=replicated_zeros_initializer,
-                        dtype=backend.standardize_dtype("int32"),
-                        trainable=False,
-                    )
-                    for table_name in stacked_table_specs.keys()
-                }
-
         self._config = jte_embedding_lookup.EmbeddingLookupConfiguration(
             feature_specs,
             mesh=mesh,
@@ -586,10 +566,8 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         del inputs, weights, training
 
         # Each stacked-table gets a ShardedCooMatrix.
-        table_specs = embedding_utils.get_table_specs(
-            self._config.feature_specs
-        )
-        table_stacks = embedding_utils.get_table_stacks(table_specs)
+        table_specs = embedding.get_table_specs(self._config.feature_specs)
+        table_stacks = jte_table_stacking.get_table_stacks(table_specs)
         stacked_table_specs = {
             stack_name: stack[0].stacked_table_spec
             for stack_name, stack in table_stacks.items()
660
638
  mesh.devices.item(0)
661
639
  )
662
640
 
663
- # Get current buffer size/max_ids.
664
- previous_max_ids_per_partition = keras.tree.map_structure(
665
- lambda max_ids_per_partition: max_ids_per_partition.value.item(),
666
- self._preprocessing_max_ids_per_partition,
667
- )
668
- previous_max_unique_ids_per_partition = keras.tree.map_structure(
669
- lambda max_unique_ids_per_partition: (
670
- max_unique_ids_per_partition.value.item()
671
- ),
672
- self._preprocessing_max_unique_ids_per_partition,
673
- )
674
- previous_buffer_size = keras.tree.map_structure(
675
- lambda buffer_size: buffer_size.value.item(),
676
- self._preprocessing_buffer_size,
677
- )
678
-
679
641
  preprocessed, stats = embedding_utils.stack_and_shard_samples(
680
642
  self._config.feature_specs,
681
643
  samples,
682
644
  local_device_count,
683
645
  global_device_count,
684
646
  num_sc_per_device,
685
- static_buffer_size=previous_buffer_size,
686
- )
687
-
688
- # Extract max unique IDs and buffer sizes.
689
- # We need to replicate this value across all local CPU devices.
690
- if training:
691
- num_local_cpu_devices = jax.local_device_count("cpu")
692
- local_max_ids_per_partition = {
693
- table_name: np.repeat(
694
- # Maximum across all partitions and previous max.
695
- np.maximum(
696
- np.max(elems),
697
- previous_max_ids_per_partition[table_name],
698
- ),
699
- num_local_cpu_devices,
700
- )
701
- for table_name, elems in stats.max_ids_per_partition.items()
702
- }
703
- local_max_unique_ids_per_partition = {
704
- name: np.repeat(
705
- # Maximum across all partitions and previous max.
706
- np.maximum(
707
- np.max(elems),
708
- previous_max_unique_ids_per_partition[name],
709
- ),
710
- num_local_cpu_devices,
711
- )
712
- for name, elems in stats.max_unique_ids_per_partition.items()
713
- }
714
- local_buffer_size = {
715
- table_name: np.repeat(
716
- np.maximum(
717
- np.max(
718
- # Round values up to the next multiple of 8.
719
- # Currently using this as a proxy for the actual
720
- # required buffer size.
721
- ((elems + 7) // 8) * 8
722
- )
723
- * global_device_count
724
- * num_sc_per_device
725
- * local_device_count
726
- * num_sc_per_device,
727
- previous_buffer_size[table_name],
728
- ),
729
- num_local_cpu_devices,
730
- )
731
- for table_name, elems in stats.max_ids_per_partition.items()
732
- }
647
+ )
733
648
 
734
- # Aggregate variables across all processes/devices.
735
- max_across_cpus = jax.pmap(
736
- lambda x: jax.lax.pmax( # type: ignore[no-untyped-call]
737
- x, "all_cpus"
738
- ),
739
- axis_name="all_cpus",
740
- devices=self._cpu_layout.device_mesh.backend_mesh.devices,
741
- )
742
- new_max_ids_per_partition = max_across_cpus(
743
- local_max_ids_per_partition
649
+ if training and self.update_stats:
650
+ # Synchronize input statistics across all devices and update the
651
+ # underlying stacked tables specs in the feature specs.
652
+
653
+ # Gather stats across all processes/devices via process_allgather.
654
+ all_stats = multihost_utils.process_allgather(stats)
655
+ all_stats = jax.tree.map(np.max, all_stats)
656
+
657
+ # Check if stats changed enough to warrant action.
658
+ stacked_table_specs = embedding.get_stacked_table_specs(
659
+ self._config.feature_specs
744
660
  )
745
- new_max_unique_ids_per_partition = max_across_cpus(
746
- local_max_unique_ids_per_partition
661
+ changed = any(
662
+ all_stats.max_ids_per_partition[stack_name]
663
+ > spec.max_ids_per_partition
664
+ or all_stats.max_unique_ids_per_partition[stack_name]
665
+ > spec.max_unique_ids_per_partition
666
+ or all_stats.required_buffer_size_per_sc[stack_name]
667
+ * num_sc_per_device
668
+ > (spec.suggested_coo_buffer_size_per_device or 0)
669
+ for stack_name, spec in stacked_table_specs.items()
747
670
  )
748
- new_buffer_size = max_across_cpus(local_buffer_size)
749
-
750
- # Assign new preprocessing parameters.
751
- with self._cpu_distribution.scope():
752
- # For each process, all max ids/buffer sizes are replicated
753
- # across all local devices. Take the value from the first
754
- # device.
755
- keras.tree.map_structure(
756
- lambda var, values: var.assign(values[0]),
757
- self._preprocessing_max_ids_per_partition,
758
- new_max_ids_per_partition,
759
- )
760
- keras.tree.map_structure(
761
- lambda var, values: var.assign(values[0]),
762
- self._preprocessing_max_unique_ids_per_partition,
763
- new_max_unique_ids_per_partition,
764
- )
765
- keras.tree.map_structure(
766
- lambda var, values: var.assign(values[0]),
767
- self._preprocessing_buffer_size,
768
- new_buffer_size,
769
- )
770
- # Update parameters in the underlying feature specs.
771
- int_max_ids_per_partition = keras.tree.map_structure(
772
- lambda varray: varray.item(), new_max_ids_per_partition
773
- )
774
- int_max_unique_ids_per_partition = keras.tree.map_structure(
775
- lambda varray: varray.item(),
776
- new_max_unique_ids_per_partition,
671
+
672
+ # Update configuration and repeat preprocessing if stats changed.
673
+ if changed:
674
+ for stack_name, spec in stacked_table_specs.items():
675
+ all_stats.max_ids_per_partition[stack_name] = np.max(
676
+ [
677
+ all_stats.max_ids_per_partition[stack_name],
678
+ spec.max_ids_per_partition,
679
+ ]
680
+ )
681
+ all_stats.max_unique_ids_per_partition[stack_name] = np.max(
682
+ [
683
+ all_stats.max_unique_ids_per_partition[stack_name],
684
+ spec.max_unique_ids_per_partition,
685
+ ]
686
+ )
687
+ all_stats.required_buffer_size_per_sc[stack_name] = np.max(
688
+ [
689
+ all_stats.required_buffer_size_per_sc[stack_name],
690
+ (
691
+ (spec.suggested_coo_buffer_size_per_device or 0)
692
+ + (num_sc_per_device - 1)
693
+ )
694
+ // num_sc_per_device,
695
+ ]
696
+ )
697
+
698
+ embedding.update_preprocessing_parameters(
699
+ self._config.feature_specs, all_stats, num_sc_per_device
777
700
  )
778
- embedding_utils.update_stacked_table_specs(
701
+
702
+ # Re-execute preprocessing with consistent input statistics.
703
+ preprocessed, _ = embedding_utils.stack_and_shard_samples(
779
704
  self._config.feature_specs,
780
- int_max_ids_per_partition,
781
- int_max_unique_ids_per_partition,
705
+ samples,
706
+ local_device_count,
707
+ global_device_count,
708
+ num_sc_per_device,
782
709
  )
783
710
 
784
711
  return {"inputs": preprocessed}
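Note (minimal sketch, not package code): the rewritten preprocessing path above gathers per-host input statistics with `multihost_utils.process_allgather` and then reduces every leaf to a global maximum with `jax.tree.map(np.max, ...)`. The table names and plain dicts below are invented stand-ins for the real stats structure.

    import numpy as np
    import jax
    from jax.experimental import multihost_utils

    local_stats = {
        "stacked_table_0": np.array([4, 7, 2], dtype=np.int32),
        "stacked_table_1": np.array([9, 1], dtype=np.int32),
    }
    # Stack every host's copy of the stats, then reduce each leaf to its max.
    gathered = multihost_utils.process_allgather(local_stats)
    global_max = jax.tree.map(np.max, gathered)
    print(global_max)  # e.g. {'stacked_table_0': 7, 'stacked_table_1': 9}
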
@@ -826,19 +753,22 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
             raise ValueError("Layer must first be built before setting tables.")
 
         if "default_device" in self._placement_to_path_to_feature_config:
-            table_to_embedding_layer = {}
+            table_name_to_embedding_layer = {}
             for (
                 path,
                 feature_config,
            ) in self._placement_to_path_to_feature_config[
                 "default_device"
             ].items():
-                table_to_embedding_layer[feature_config.table] = (
+                table_name_to_embedding_layer[feature_config.table.name] = (
                     self._default_device_embedding_layers[path]
                 )
 
-            for table, embedding_layer in table_to_embedding_layer.items():
-                table_values = tables.get(table.name, None)
+            for (
+                table_name,
+                embedding_layer,
+            ) in table_name_to_embedding_layer.items():
+                table_values = tables.get(table_name, None)
                 if table_values is not None:
                     if embedding_layer.lora_enabled:
                         raise ValueError("Cannot set table if LoRA is enabled.")
@@ -851,8 +781,8 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
 
         config = self._config
         num_table_shards = config.mesh.devices.size * config.num_sc_per_device
-        table_specs = embedding_utils.get_table_specs(config.feature_specs)
-        sharded_tables = embedding_utils.stack_and_shard_tables(
+        table_specs = embedding.get_table_specs(config.feature_specs)
+        sharded_tables = jte_table_stacking.stack_and_shard_tables(
             table_specs,
             tables,
             num_table_shards,
@@ -871,8 +801,8 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         # Assign stacked table variables to the device values.
         keras.tree.map_structure_up_to(
             device_tables,
-            lambda table_and_slot_variables,
-            table_value: table_and_slot_variables[0].assign(table_value),
+            lambda embedding_variables,
+            table_value: embedding_variables.table.assign(table_value),
             self._table_and_slot_variables,
             device_tables,
         )
@@ -883,17 +813,19 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
 
         config = self._config
         num_table_shards = config.mesh.devices.size * config.num_sc_per_device
-        table_specs = embedding_utils.get_table_specs(config.feature_specs)
+        table_specs = embedding.get_table_specs(config.feature_specs)
 
         # Extract only the table variables, not the gradient slot variables.
         table_variables = {
-            name: jax.device_get(table_and_slots[0].value)
-            for name, table_and_slots in self._table_and_slot_variables.items()
+            name: jax.device_get(embedding_variables.table.value)
+            for name, embedding_variables in (
+                self._table_and_slot_variables.items()
+            )
         }
 
         return typing.cast(
             dict[str, ArrayLike],
-            embedding_utils.unshard_and_unstack_tables(
+            jte_table_stacking.unshard_and_unstack_tables(
                 table_specs, table_variables, num_table_shards
             ),
         )
@@ -16,9 +16,30 @@ from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
 from keras_rs.src.layers.embedding.jax import embedding_utils
 from keras_rs.src.types import Nested
 
-ShardedCooMatrix = embedding_utils.ShardedCooMatrix
-shard_map = jax.experimental.shard_map.shard_map  # type: ignore[attr-defined]
+if jax.__version_info__ >= (0, 8, 0):
+    from jax import shard_map
+else:
+    from jax.experimental.shard_map import shard_map as exp_shard_map
+
+    def shard_map(  # type: ignore[misc]
+        f: Any = None,
+        /,
+        *,
+        out_specs: Any,
+        in_specs: Any,
+        mesh: Any = None,
+        check_vma: bool = True,
+    ) -> Any:
+        return exp_shard_map(
+            f,
+            mesh=mesh,
+            in_specs=in_specs,
+            out_specs=out_specs,
+            check_rep=check_vma,
+        )  # type: ignore[no-untyped-call]
+
 
+ShardedCooMatrix = embedding_utils.ShardedCooMatrix
 ArrayLike: TypeAlias = jax.Array | np.ndarray[Any, Any]
 JaxLayout: TypeAlias = jax.sharding.NamedSharding | jax_layout.Format
 
@@ -121,7 +142,7 @@ def embedding_lookup(
             mesh=config.mesh,
             in_specs=(pd, pt),
             out_specs=pd,
-            check_rep=False,
+            check_vma=False,
         ),
     )
 
@@ -220,7 +241,7 @@ def embedding_lookup_bwd(
             mesh=config.mesh,
             in_specs=(pd, pd, pt, preplicate),
             out_specs=pt,
-            check_rep=False,
+            check_vma=False,
         ),
         # in_shardings=(
         #     activation_layout,
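Note (illustrative sketch, not package code): with the compatibility shim defined in the earlier hunk, call sites like the two above can pass `check_vma` uniformly; on JAX < 0.8 the shim forwards it as `check_rep`. The mesh setup and toy function below are assumptions for the example.

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import PartitionSpec as P

    mesh = jax.sharding.Mesh(np.array(jax.devices()), ("x",))
    per_device_sum = shard_map(
        lambda x: jnp.sum(x, keepdims=True),  # local reduction on each shard
        mesh=mesh,
        in_specs=(P("x"),),
        out_specs=P("x"),
        check_vma=False,
    )
    # Assumes the leading dimension divides the number of devices.
    print(per_device_sum(jnp.arange(8.0)))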