keras-rs-nightly 0.2.2.dev202508190331__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,6 +4,7 @@ This file was autogenerated. Do not edit it by hand,
  since your modifications would be overwritten.
  """

+ from keras_rs.src.losses.list_mle_loss import ListMLELoss as ListMLELoss
  from keras_rs.src.losses.pairwise_hinge_loss import (
      PairwiseHingeLoss as PairwiseHingeLoss,
  )
@@ -822,13 +822,13 @@ class DistributedEmbedding(keras.layers.Layer):
          table_stacking: str | Sequence[Sequence[str]],
      ) -> None:
          del table_stacking
-         table_to_embedding_layer: dict[TableConfig, EmbedReduce] = {}
+         table_config_id_to_embedding_layer: dict[int, EmbedReduce] = {}
          self._default_device_embedding_layers: dict[str, EmbedReduce] = {}

          for path, feature_config in feature_configs.items():
-             if feature_config.table in table_to_embedding_layer:
+             if id(feature_config.table) in table_config_id_to_embedding_layer:
                  self._default_device_embedding_layers[path] = (
-                     table_to_embedding_layer[feature_config.table]
+                     table_config_id_to_embedding_layer[id(feature_config.table)]
                  )
              else:
                  embedding_layer = EmbedReduce(
@@ -838,7 +838,9 @@ class DistributedEmbedding(keras.layers.Layer):
                      embeddings_initializer=feature_config.table.initializer,
                      combiner=feature_config.table.combiner,
                  )
-                 table_to_embedding_layer[feature_config.table] = embedding_layer
+                 table_config_id_to_embedding_layer[id(feature_config.table)] = (
+                     embedding_layer
+                 )
                  self._default_device_embedding_layers[path] = embedding_layer

      def _default_device_build(
@@ -1013,8 +1015,8 @@ class DistributedEmbedding(keras.layers.Layer):

          # The serialized `TableConfig` objects.
          table_config_dicts: list[dict[str, Any]] = []
-         # Mapping from `TableConfig` to index in `table_config_dicts`.
-         table_config_indices: dict[TableConfig, int] = {}
+         # Mapping from `TableConfig` id to index in `table_config_dicts`.
+         table_config_id_to_index: dict[int, int] = {}

          def serialize_feature_config(
              feature_config: FeatureConfig,
@@ -1024,17 +1026,17 @@ class DistributedEmbedding(keras.layers.Layer):
              # key.
              feature_config_dict = feature_config.get_config()

-             if feature_config.table not in table_config_indices:
+             if id(feature_config.table) not in table_config_id_to_index:
                  # Save the serialized `TableConfig` the first time we see it and
                  # remember its index.
-                 table_config_indices[feature_config.table] = len(
+                 table_config_id_to_index[id(feature_config.table)] = len(
                      table_config_dicts
                  )
                  table_config_dicts.append(feature_config_dict["table"])

              # Replace the serialized `TableConfig` with its index.
-             feature_config_dict["table"] = table_config_indices[
-                 feature_config.table
+             feature_config_dict["table"] = table_config_id_to_index[
+                 id(feature_config.table)
              ]
              return feature_config_dict

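Note on the two hunks above: the dedup key used while serializing shared `TableConfig` objects changes from the config object itself to `id(...)`, because the configs are no longer hashable (see the dataclass hunks below). The underlying pattern, serializing each distinct table once and having features reference it by index, is sketched here with hypothetical stand-ins (the `Table` class below is not the keras-rs type):

    from typing import Any


    class Table:  # hypothetical stand-in for TableConfig
        def __init__(self, name: str) -> None:
            self.name = name

        def get_config(self) -> dict[str, Any]:
            return {"name": self.name}


    def serialize(
        features: dict[str, Table],
    ) -> tuple[list[dict[str, Any]], dict[str, int]]:
        table_dicts: list[dict[str, Any]] = []
        table_id_to_index: dict[int, int] = {}  # keyed by id(), not by the object
        feature_dicts: dict[str, int] = {}
        for path, table in features.items():
            if id(table) not in table_id_to_index:
                # Serialize each distinct table object exactly once.
                table_id_to_index[id(table)] = len(table_dicts)
                table_dicts.append(table.get_config())
            # Features store an index, so shared tables stay shared on reload.
            feature_dicts[path] = table_id_to_index[id(table)]
        return table_dicts, feature_dicts


    shared = Table("clicks")
    tables, features = serialize({"a": shared, "b": shared, "c": Table("views")})
    assert features == {"a": 0, "b": 0, "c": 1} and len(tables) == 2
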
@@ -10,7 +10,7 @@ from keras_rs.src.api_export import keras_rs_export


  @keras_rs_export("keras_rs.layers.TableConfig")
- @dataclasses.dataclass(eq=True, unsafe_hash=True, order=True)
+ @dataclasses.dataclass(order=True)
  class TableConfig:
      """Configuration for one embedding table.

@@ -88,7 +88,7 @@ class TableConfig:


  @keras_rs_export("keras_rs.layers.FeatureConfig")
- @dataclasses.dataclass(eq=True, unsafe_hash=True, order=True)
+ @dataclasses.dataclass(order=True)
  class FeatureConfig:
      """Configuration for one embedding feature.

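Dropping `eq=True, unsafe_hash=True` from these decorators makes `TableConfig` and `FeatureConfig` unhashable: with `eq=True` (the dataclass default) and no `unsafe_hash`, Python sets `__hash__` to `None`. That is why the hunks elsewhere in this diff key their lookup dictionaries on `id(config)` rather than on the config object. A small illustration with a hypothetical stand-in class:

    import dataclasses


    @dataclasses.dataclass(order=True)
    class Cfg:  # hypothetical stand-in for TableConfig / FeatureConfig
        name: str


    cfg = Cfg("clicks")
    try:
        hash(cfg)  # eq=True (the default) without unsafe_hash sets __hash__ = None
    except TypeError as err:
        print(err)  # "unhashable type: 'Cfg'"

    # Identity-keyed dicts keep working without hashing the config itself:
    layer_by_id: dict[int, str] = {id(cfg): "EmbedReduce for clicks"}
    assert id(cfg) in layer_by_id
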
@@ -9,13 +9,13 @@ import keras
  import numpy as np
  from jax import numpy as jnp
  from jax.experimental import layout as jax_layout
+ from jax.experimental import multihost_utils
  from jax_tpu_embedding.sparsecore.lib.nn import embedding
  from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
  from jax_tpu_embedding.sparsecore.lib.nn import (
      table_stacking as jte_table_stacking,
  )
  from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
- from keras.src import backend

  from keras_rs.src import types
  from keras_rs.src.layers.embedding import base_distributed_embedding
@@ -28,9 +28,14 @@ from keras_rs.src.layers.embedding.jax import embedding_utils
  from keras_rs.src.types import Nested
  from keras_rs.src.utils import keras_utils

+ if jax.__version_info__ >= (0, 8, 0):
+     from jax import shard_map
+ else:
+     from jax.experimental.shard_map import shard_map  # type: ignore[assignment]
+
+
  ArrayLike = Union[np.ndarray[Any, Any], jax.Array]
  FeatureConfig = config.FeatureConfig
- shard_map = jax.experimental.shard_map.shard_map  # type: ignore[attr-defined]


  def _get_partition_spec(
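
The same version gate appears again in the embedding-lookup module further down: the public `jax.shard_map` entry point is imported on JAX 0.8+, with a fallback to the experimental module on older releases. A self-contained sketch of the gated import in use; the one-axis mesh and the toy per-shard reduction are illustrative and not part of keras-rs:

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, PartitionSpec as P

    if jax.__version_info__ >= (0, 8, 0):
        from jax import shard_map
    else:
        from jax.experimental.shard_map import shard_map  # type: ignore[assignment]

    # One-axis mesh over whatever devices are available.
    mesh = Mesh(jax.devices(), ("x",))

    # Each device sums its local slice; outputs stay sharded along "x".
    per_shard_sum = shard_map(
        lambda x: jnp.sum(x, keepdims=True),
        mesh=mesh,
        in_specs=P("x"),
        out_specs=P("x"),
    )

    values = jnp.arange(4.0 * len(jax.devices()))  # evenly divisible across devices
    print(per_shard_sum(values))  # one partial sum per device
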
@@ -247,23 +252,6 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
          )
          return sparsecore_distribution, sparsecore_layout

-     def _create_cpu_distribution(
-         self, cpu_axis_name: str = "cpu"
-     ) -> tuple[
-         keras.distribution.ModelParallel, keras.distribution.TensorLayout
-     ]:
-         """Share a variable across all CPU processes."""
-         cpu_devices = jax.devices("cpu")
-         device_mesh = keras.distribution.DeviceMesh(
-             (len(cpu_devices),), [cpu_axis_name], cpu_devices
-         )
-         replicated_layout = keras.distribution.TensorLayout([], device_mesh)
-         layout_map = keras.distribution.LayoutMap(device_mesh=device_mesh)
-         cpu_distribution = keras.distribution.ModelParallel(
-             layout_map=layout_map
-         )
-         return cpu_distribution, replicated_layout
-
      def _add_sparsecore_weight(
          self,
          name: str,
@@ -283,7 +271,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
          table_specs: Sequence[embedding_spec.TableSpec],
          num_shards: int,
          add_slot_variables: bool,
-     ) -> tuple[keras.Variable, tuple[keras.Variable, ...] | None]:
+     ) -> embedding.EmbeddingVariables:
          stacked_table_spec = typing.cast(
              embedding_spec.StackedTableSpec, table_specs[0].stacked_table_spec
          )
@@ -352,7 +340,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
              slot_initializers, slot_variables
          )

-         return table_variable, slot_variables
+         return embedding.EmbeddingVariables(table_variable, slot_variables)

      @keras_utils.no_automatic_dependency_tracking
      def _sparsecore_init(
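
These two hunks (and the checkpoint-related ones further down) replace the anonymous `(table, slots)` tuple with `embedding.EmbeddingVariables`, so call sites can write `.table` instead of `[0]`. A tiny sketch of the idea using a hypothetical NamedTuple stand-in (`EmbeddingVars` below is not the jax_tpu_embedding class):

    from typing import NamedTuple, Optional, Tuple


    class EmbeddingVars(NamedTuple):  # hypothetical stand-in, not the real class
        table: str
        slots: Optional[Tuple[str, ...]]


    pair = EmbeddingVars("stacked table weights", ("momentum",))

    # Still behaves like the old plain tuple return value...
    weights, slots = pair
    assert pair[0] == "stacked table weights"

    # ...but callers can use the self-documenting field instead of an index.
    assert pair.table == "stacked table weights"
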
@@ -405,11 +393,6 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
          self._sparsecore_layout = sparsecore_layout
          self._sparsecore_distribution = sparsecore_distribution

-         # Distribution for CPU operations.
-         cpu_distribution, cpu_layout = self._create_cpu_distribution()
-         self._cpu_distribution = cpu_distribution
-         self._cpu_layout = cpu_layout
-
          mesh = sparsecore_distribution.device_mesh.backend_mesh
          global_device_count = mesh.devices.size
          num_sc_per_device = jte_utils.num_sparsecores_per_device(
@@ -464,12 +447,51 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
          )

          # Collect all stacked tables.
-         table_specs = embedding_utils.get_table_specs(feature_specs)
-         table_stacks = embedding_utils.get_table_stacks(table_specs)
-         stacked_table_specs = {
-             stack_name: stack[0].stacked_table_spec
-             for stack_name, stack in table_stacks.items()
-         }
+         table_specs = embedding.get_table_specs(feature_specs)
+         table_stacks = jte_table_stacking.get_table_stacks(table_specs)
+
+         # Update stacked table stats to max of values across involved tables.
+         max_ids_per_partition = {}
+         max_unique_ids_per_partition = {}
+         required_buffer_size_per_device = {}
+         id_drop_counters = {}
+         for stack_name, stack in table_stacks.items():
+             max_ids_per_partition[stack_name] = np.max(
+                 np.asarray(
+                     [s.max_ids_per_partition for s in stack], dtype=np.int32
+                 )
+             )
+             max_unique_ids_per_partition[stack_name] = np.max(
+                 np.asarray(
+                     [s.max_unique_ids_per_partition for s in stack],
+                     dtype=np.int32,
+                 )
+             )
+
+             # Only set the suggested buffer size if set on any individual table.
+             valid_buffer_sizes = [
+                 s.suggested_coo_buffer_size_per_device
+                 for s in stack
+                 if s.suggested_coo_buffer_size_per_device is not None
+             ]
+             if valid_buffer_sizes:
+                 required_buffer_size_per_device[stack_name] = np.max(
+                     np.asarray(valid_buffer_sizes, dtype=np.int32)
+                 )
+
+             id_drop_counters[stack_name] = 0
+
+         aggregated_stats = embedding.SparseDenseMatmulInputStats(
+             max_ids_per_partition=max_ids_per_partition,
+             max_unique_ids_per_partition=max_unique_ids_per_partition,
+             required_buffer_size_per_sc=required_buffer_size_per_device,
+             id_drop_counters=id_drop_counters,
+         )
+         embedding.update_preprocessing_parameters(
+             feature_specs,
+             aggregated_stats,
+             num_sc_per_device,
+         )

          # Create variables for all stacked tables and slot variables.
          with sparsecore_distribution.scope():
@@ -502,50 +524,6 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
              )
          self._iterations.overwrite_with_gradient = True

-         with cpu_distribution.scope():
-             # Create variables to track static buffer size and max IDs for each
-             # table during preprocessing. These variables are shared across all
-             # processes on CPU. We don't add these via `add_weight` because we
-             # can't have them passed to the training function.
-             replicated_zeros_initializer = ShardedInitializer(
-                 "zeros", cpu_layout
-             )
-
-             with backend.name_scope(self.name, caller=self):
-                 self._preprocessing_buffer_size = {
-                     table_name: backend.Variable(
-                         initializer=replicated_zeros_initializer,
-                         shape=(),
-                         dtype=backend.standardize_dtype("int32"),
-                         trainable=False,
-                         name=table_name + ":preprocessing:buffer_size",
-                     )
-                     for table_name in stacked_table_specs.keys()
-                 }
-                 self._preprocessing_max_unique_ids_per_partition = {
-                     table_name: backend.Variable(
-                         shape=(),
-                         name=table_name
-                         + ":preprocessing:max_unique_ids_per_partition",
-                         initializer=replicated_zeros_initializer,
-                         dtype=backend.standardize_dtype("int32"),
-                         trainable=False,
-                     )
-                     for table_name in stacked_table_specs.keys()
-                 }
-
-                 self._preprocessing_max_ids_per_partition = {
-                     table_name: backend.Variable(
-                         shape=(),
-                         name=table_name
-                         + ":preprocessing:max_ids_per_partition",
-                         initializer=replicated_zeros_initializer,
-                         dtype=backend.standardize_dtype("int32"),
-                         trainable=False,
-                     )
-                     for table_name in stacked_table_specs.keys()
-                 }
-
          self._config = jte_embedding_lookup.EmbeddingLookupConfiguration(
              feature_specs,
              mesh=mesh,
@@ -586,10 +564,8 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
          del inputs, weights, training

          # Each stacked-table gets a ShardedCooMatrix.
-         table_specs = embedding_utils.get_table_specs(
-             self._config.feature_specs
-         )
-         table_stacks = embedding_utils.get_table_stacks(table_specs)
+         table_specs = embedding.get_table_specs(self._config.feature_specs)
+         table_stacks = jte_table_stacking.get_table_stacks(table_specs)
          stacked_table_specs = {
              stack_name: stack[0].stacked_table_spec
              for stack_name, stack in table_stacks.items()
@@ -660,125 +636,74 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
              mesh.devices.item(0)
          )

-         # Get current buffer size/max_ids.
-         previous_max_ids_per_partition = keras.tree.map_structure(
-             lambda max_ids_per_partition: max_ids_per_partition.value.item(),
-             self._preprocessing_max_ids_per_partition,
-         )
-         previous_max_unique_ids_per_partition = keras.tree.map_structure(
-             lambda max_unique_ids_per_partition: (
-                 max_unique_ids_per_partition.value.item()
-             ),
-             self._preprocessing_max_unique_ids_per_partition,
-         )
-         previous_buffer_size = keras.tree.map_structure(
-             lambda buffer_size: buffer_size.value.item(),
-             self._preprocessing_buffer_size,
-         )
-
          preprocessed, stats = embedding_utils.stack_and_shard_samples(
              self._config.feature_specs,
              samples,
              local_device_count,
              global_device_count,
              num_sc_per_device,
-             static_buffer_size=previous_buffer_size,
          )

-         # Extract max unique IDs and buffer sizes.
-         # We need to replicate this value across all local CPU devices.
          if training:
-             num_local_cpu_devices = jax.local_device_count("cpu")
-             local_max_ids_per_partition = {
-                 table_name: np.repeat(
-                     # Maximum across all partitions and previous max.
-                     np.maximum(
-                         np.max(elems),
-                         previous_max_ids_per_partition[table_name],
-                     ),
-                     num_local_cpu_devices,
-                 )
-                 for table_name, elems in stats.max_ids_per_partition.items()
-             }
-             local_max_unique_ids_per_partition = {
-                 name: np.repeat(
-                     # Maximum across all partitions and previous max.
-                     np.maximum(
-                         np.max(elems),
-                         previous_max_unique_ids_per_partition[name],
-                     ),
-                     num_local_cpu_devices,
-                 )
-                 for name, elems in stats.max_unique_ids_per_partition.items()
-             }
-             local_buffer_size = {
-                 table_name: np.repeat(
-                     np.maximum(
-                         np.max(
-                             # Round values up to the next multiple of 8.
-                             # Currently using this as a proxy for the actual
-                             # required buffer size.
-                             ((elems + 7) // 8) * 8
-                         )
-                         * global_device_count
-                         * num_sc_per_device
-                         * local_device_count
-                         * num_sc_per_device,
-                         previous_buffer_size[table_name],
-                     ),
-                     num_local_cpu_devices,
-                 )
-                 for table_name, elems in stats.max_ids_per_partition.items()
-             }
+             # Synchronize input statistics across all devices and update the
+             # underlying stacked tables specs in the feature specs.

-             # Aggregate variables across all processes/devices.
-             max_across_cpus = jax.pmap(
-                 lambda x: jax.lax.pmax(  # type: ignore[no-untyped-call]
-                     x, "all_cpus"
-                 ),
-                 axis_name="all_cpus",
-                 devices=self._cpu_layout.device_mesh.backend_mesh.devices,
-             )
-             new_max_ids_per_partition = max_across_cpus(
-                 local_max_ids_per_partition
+             # Gather stats across all processes/devices via process_allgather.
+             all_stats = multihost_utils.process_allgather(stats)
+             all_stats = jax.tree.map(np.max, all_stats)
+
+             # Check if stats changed enough to warrant action.
+             stacked_table_specs = embedding.get_stacked_table_specs(
+                 self._config.feature_specs
             )
-             new_max_unique_ids_per_partition = max_across_cpus(
-                 local_max_unique_ids_per_partition
+             changed = any(
+                 all_stats.max_ids_per_partition[stack_name]
+                 > spec.max_ids_per_partition
+                 or all_stats.max_unique_ids_per_partition[stack_name]
+                 > spec.max_unique_ids_per_partition
+                 or all_stats.required_buffer_size_per_sc[stack_name]
+                 * num_sc_per_device
+                 > (spec.suggested_coo_buffer_size_per_device or 0)
+                 for stack_name, spec in stacked_table_specs.items()
             )
-             new_buffer_size = max_across_cpus(local_buffer_size)
-
-             # Assign new preprocessing parameters.
-             with self._cpu_distribution.scope():
-                 # For each process, all max ids/buffer sizes are replicated
-                 # across all local devices. Take the value from the first
-                 # device.
-                 keras.tree.map_structure(
-                     lambda var, values: var.assign(values[0]),
-                     self._preprocessing_max_ids_per_partition,
-                     new_max_ids_per_partition,
-                 )
-                 keras.tree.map_structure(
-                     lambda var, values: var.assign(values[0]),
-                     self._preprocessing_max_unique_ids_per_partition,
-                     new_max_unique_ids_per_partition,
-                 )
-                 keras.tree.map_structure(
-                     lambda var, values: var.assign(values[0]),
-                     self._preprocessing_buffer_size,
-                     new_buffer_size,
-                 )
-                 # Update parameters in the underlying feature specs.
-                 int_max_ids_per_partition = keras.tree.map_structure(
-                     lambda varray: varray.item(), new_max_ids_per_partition
-                 )
-                 int_max_unique_ids_per_partition = keras.tree.map_structure(
-                     lambda varray: varray.item(),
-                     new_max_unique_ids_per_partition,
+
+             # Update configuration and repeat preprocessing if stats changed.
+             if changed:
+                 for stack_name, spec in stacked_table_specs.items():
+                     all_stats.max_ids_per_partition[stack_name] = np.max(
+                         [
+                             all_stats.max_ids_per_partition[stack_name],
+                             spec.max_ids_per_partition,
+                         ]
+                     )
+                     all_stats.max_unique_ids_per_partition[stack_name] = np.max(
+                         [
+                             all_stats.max_unique_ids_per_partition[stack_name],
+                             spec.max_unique_ids_per_partition,
+                         ]
+                     )
+                     all_stats.required_buffer_size_per_sc[stack_name] = np.max(
+                         [
+                             all_stats.required_buffer_size_per_sc[stack_name],
+                             (
+                                 (spec.suggested_coo_buffer_size_per_device or 0)
+                                 + (num_sc_per_device - 1)
+                             )
+                             // num_sc_per_device,
+                         ]
+                     )
+
+                 embedding.update_preprocessing_parameters(
+                     self._config.feature_specs, all_stats, num_sc_per_device
                  )
-                 embedding_utils.update_stacked_table_specs(
+
+                 # Re-execute preprocessing with consistent input statistics.
+                 preprocessed, _ = embedding_utils.stack_and_shard_samples(
                      self._config.feature_specs,
-                     int_max_ids_per_partition,
-                     int_max_unique_ids_per_partition,
+                     samples,
+                     local_device_count,
+                     global_device_count,
+                     num_sc_per_device,
                  )

          return {"inputs": preprocessed}
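
The rewritten `if training:` branch above drops the CPU-replicated tracking variables and the `jax.pmap`/`pmax` reduction in favour of gathering the preprocessing statistics directly across processes. `multihost_utils.process_allgather` stacks every host's pytree along a new leading axis, and mapping `np.max` over the leaves then yields globally consistent maxima. A minimal sketch of that aggregation step, using an illustrative stats dict rather than the real `SparseDenseMatmulInputStats`:

    import jax
    import numpy as np
    from jax.experimental import multihost_utils

    # Per-host preprocessing statistics, e.g. max IDs seen per table partition.
    local_stats = {
        "table_a": np.array([13, 7], dtype=np.int32),
        "table_b": np.array([5, 21], dtype=np.int32),
    }

    # Every host contributes its copy; each leaf gains a leading "process" axis.
    gathered = multihost_utils.process_allgather(local_stats)

    # Reduce each leaf to a single global maximum, mirroring the hunk above.
    global_stats = jax.tree.map(np.max, gathered)
    print(global_stats)  # one global maximum per table, e.g. 13 and 21 here
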
@@ -826,19 +751,22 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
              raise ValueError("Layer must first be built before setting tables.")

          if "default_device" in self._placement_to_path_to_feature_config:
-             table_to_embedding_layer = {}
+             table_name_to_embedding_layer = {}
              for (
                  path,
                  feature_config,
              ) in self._placement_to_path_to_feature_config[
                  "default_device"
              ].items():
-                 table_to_embedding_layer[feature_config.table] = (
+                 table_name_to_embedding_layer[feature_config.table.name] = (
                      self._default_device_embedding_layers[path]
                  )

-             for table, embedding_layer in table_to_embedding_layer.items():
-                 table_values = tables.get(table.name, None)
+             for (
+                 table_name,
+                 embedding_layer,
+             ) in table_name_to_embedding_layer.items():
+                 table_values = tables.get(table_name, None)
                  if table_values is not None:
                      if embedding_layer.lora_enabled:
                          raise ValueError("Cannot set table if LoRA is enabled.")
@@ -851,8 +779,8 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):

          config = self._config
          num_table_shards = config.mesh.devices.size * config.num_sc_per_device
-         table_specs = embedding_utils.get_table_specs(config.feature_specs)
-         sharded_tables = embedding_utils.stack_and_shard_tables(
+         table_specs = embedding.get_table_specs(config.feature_specs)
+         sharded_tables = jte_table_stacking.stack_and_shard_tables(
              table_specs,
              tables,
              num_table_shards,
@@ -871,8 +799,8 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
          # Assign stacked table variables to the device values.
          keras.tree.map_structure_up_to(
              device_tables,
-             lambda table_and_slot_variables,
-             table_value: table_and_slot_variables[0].assign(table_value),
+             lambda embedding_variables,
+             table_value: embedding_variables.table.assign(table_value),
              self._table_and_slot_variables,
              device_tables,
          )
@@ -883,17 +811,19 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):

          config = self._config
          num_table_shards = config.mesh.devices.size * config.num_sc_per_device
-         table_specs = embedding_utils.get_table_specs(config.feature_specs)
+         table_specs = embedding.get_table_specs(config.feature_specs)

          # Extract only the table variables, not the gradient slot variables.
          table_variables = {
-             name: jax.device_get(table_and_slots[0].value)
-             for name, table_and_slots in self._table_and_slot_variables.items()
+             name: jax.device_get(embedding_variables.table.value)
+             for name, embedding_variables in (
+                 self._table_and_slot_variables.items()
+             )
          }

          return typing.cast(
              dict[str, ArrayLike],
-             embedding_utils.unshard_and_unstack_tables(
+             jte_table_stacking.unshard_and_unstack_tables(
                  table_specs, table_variables, num_table_shards
              ),
          )
@@ -16,9 +16,30 @@ from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
  from keras_rs.src.layers.embedding.jax import embedding_utils
  from keras_rs.src.types import Nested

- ShardedCooMatrix = embedding_utils.ShardedCooMatrix
- shard_map = jax.experimental.shard_map.shard_map  # type: ignore[attr-defined]
+ if jax.__version_info__ >= (0, 8, 0):
+     from jax import shard_map
+ else:
+     from jax.experimental.shard_map import shard_map as exp_shard_map
+
+     def shard_map(  # type: ignore[misc]
+         f: Any = None,
+         /,
+         *,
+         out_specs: Any,
+         in_specs: Any,
+         mesh: Any = None,
+         check_vma: bool = True,
+     ) -> Any:
+         return exp_shard_map(
+             f,
+             mesh=mesh,
+             in_specs=in_specs,
+             out_specs=out_specs,
+             check_rep=check_vma,
+         )  # type: ignore[no-untyped-call]
+

+ ShardedCooMatrix = embedding_utils.ShardedCooMatrix
  ArrayLike: TypeAlias = jax.Array | np.ndarray[Any, Any]
  JaxLayout: TypeAlias = jax.sharding.NamedSharding | jax_layout.Format

@@ -121,7 +142,7 @@ def embedding_lookup(
              mesh=config.mesh,
              in_specs=(pd, pt),
              out_specs=pd,
-             check_rep=False,
+             check_vma=False,
          ),
      )

@@ -220,7 +241,7 @@ def embedding_lookup_bwd(
              mesh=config.mesh,
              in_specs=(pd, pd, pt, preplicate),
              out_specs=pt,
-             check_rep=False,
+             check_vma=False,
          ),
          # in_shardings=(
          #     activation_layout,