keras-rs-nightly 0.2.2.dev202507030337__tar.gz → 0.3.1.dev202511120334__tar.gz

This diff shows the content changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

This release of keras-rs-nightly has been flagged as potentially problematic.
Files changed (63)
  1. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/PKG-INFO +4 -3
  2. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/api/losses/__init__.py +1 -0
  3. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/base_distributed_embedding.py +85 -55
  4. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/distributed_embedding_config.py +6 -3
  5. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/jax/distributed_embedding.py +111 -195
  6. keras_rs_nightly-0.3.1.dev202511120334/keras_rs/src/layers/embedding/jax/embedding_utils.py +244 -0
  7. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/tensorflow/config_conversion.py +62 -22
  8. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +18 -6
  9. keras_rs_nightly-0.3.1.dev202511120334/keras_rs/src/losses/list_mle_loss.py +212 -0
  10. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/ranking_metrics_utils.py +19 -0
  11. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/version.py +1 -1
  12. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs_nightly.egg-info/PKG-INFO +4 -3
  13. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs_nightly.egg-info/SOURCES.txt +1 -0
  14. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/pyproject.toml +4 -3
  15. keras_rs_nightly-0.2.2.dev202507030337/keras_rs/src/layers/embedding/jax/embedding_utils.py +0 -596
  16. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/README.md +0 -0
  17. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/api/__init__.py +0 -0
  18. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/api/layers/__init__.py +0 -0
  19. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/api/metrics/__init__.py +0 -0
  20. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/__init__.py +0 -0
  21. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/api_export.py +0 -0
  22. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/__init__.py +0 -0
  23. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/__init__.py +0 -0
  24. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/distributed_embedding.py +0 -0
  25. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/embed_reduce.py +0 -0
  26. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  27. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/jax/checkpoint_utils.py +0 -0
  28. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/jax/config_conversion.py +0 -0
  29. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/jax/embedding_lookup.py +0 -0
  30. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  31. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/feature_interaction/__init__.py +0 -0
  32. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/feature_interaction/dot_interaction.py +0 -0
  33. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/feature_interaction/feature_cross.py +0 -0
  34. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/__init__.py +0 -0
  35. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/brute_force_retrieval.py +0 -0
  36. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/hard_negative_mining.py +0 -0
  37. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/remove_accidental_hits.py +0 -0
  38. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/retrieval.py +0 -0
  39. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/sampling_probability_correction.py +0 -0
  40. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/__init__.py +0 -0
  41. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_hinge_loss.py +0 -0
  42. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_logistic_loss.py +0 -0
  43. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_loss.py +0 -0
  44. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_loss_utils.py +0 -0
  45. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_mean_squared_error.py +0 -0
  46. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_soft_zero_one_loss.py +0 -0
  47. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/__init__.py +0 -0
  48. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/dcg.py +0 -0
  49. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/mean_average_precision.py +0 -0
  50. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/mean_reciprocal_rank.py +0 -0
  51. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/ndcg.py +0 -0
  52. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/precision_at_k.py +0 -0
  53. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/ranking_metric.py +0 -0
  54. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/recall_at_k.py +0 -0
  55. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/utils.py +0 -0
  56. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/types.py +0 -0
  57. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/utils/__init__.py +0 -0
  58. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/utils/doc_string_utils.py +0 -0
  59. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/utils/keras_utils.py +0 -0
  60. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs_nightly.egg-info/dependency_links.txt +0 -0
  61. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs_nightly.egg-info/requires.txt +0 -0
  62. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs_nightly.egg-info/top_level.txt +0 -0
  63. {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/setup.cfg +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: keras-rs-nightly
-Version: 0.2.2.dev202507030337
+Version: 0.3.1.dev202511120334
 Summary: Multi-backend recommender systems with Keras 3.
 Author-email: Keras team <keras-users@googlegroups.com>
 License: Apache License 2.0
@@ -8,8 +8,9 @@ Project-URL: Home, https://keras.io/keras_rs
 Project-URL: Repository, https://github.com/keras-team/keras-rs
 Classifier: Development Status :: 3 - Alpha
 Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
 Classifier: Programming Language :: Python :: 3 :: Only
 Classifier: Operating System :: Unix
 Classifier: Operating System :: Microsoft :: Windows
@@ -17,7 +18,7 @@ Classifier: Operating System :: MacOS
 Classifier: Intended Audience :: Science/Research
 Classifier: Topic :: Scientific/Engineering
 Classifier: Topic :: Software Development
-Requires-Python: >=3.10
+Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 Requires-Dist: keras
 Requires-Dist: ml-dtypes
keras_rs/api/losses/__init__.py
@@ -4,6 +4,7 @@ This file was autogenerated. Do not edit it by hand,
 since your modifications would be overwritten.
 """
 
+from keras_rs.src.losses.list_mle_loss import ListMLELoss as ListMLELoss
 from keras_rs.src.losses.pairwise_hinge_loss import (
     PairwiseHingeLoss as PairwiseHingeLoss,
 )
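The API-visible addition in this release is the `ListMLELoss` export above, backed by the new `keras_rs/src/losses/list_mle_loss.py` module from the file list. A minimal usage sketch, assuming `ListMLELoss` follows the same `keras.losses.Loss` call convention as the existing pairwise ranking losses (relevance labels and scores of shape `(batch_size, list_size)`); its constructor arguments are not visible in this diff:

```python
import keras
import keras_rs

# Hypothetical data: relevance labels and predicted scores for two
# lists of three items each.
y_true = keras.ops.array([[1.0, 0.0, 2.0], [0.0, 1.0, 0.0]])
y_pred = keras.ops.array([[0.3, 0.1, 0.8], [0.2, 0.9, 0.4]])

# ListMLE maximizes the likelihood of the label-induced permutation
# under a Plackett-Luce model of the predicted scores.
loss = keras_rs.losses.ListMLELoss()
print(loss(y_true, y_pred))  # Scalar loss averaged over the batch.
```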
keras_rs/src/layers/embedding/base_distributed_embedding.py
@@ -1,4 +1,5 @@
 import collections
+import dataclasses
 import importlib.util
 import typing
 from typing import Any, Sequence
@@ -20,9 +21,10 @@ EmbedReduce = embed_reduce.EmbedReduce
 SUPPORTED_PLACEMENTS = ("auto", "default_device", "sparsecore")
 
 
-PlacementAndPath = collections.namedtuple(
-    "PlacementAndPath", ["placement", "path"]
-)
+@dataclasses.dataclass(eq=True, unsafe_hash=True, order=True)
+class PlacementAndPath:
+    placement: str
+    path: str
 
 
 def _ragged_to_dense_inputs(
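The namedtuple becomes a dataclass with matching semantics: `eq`, `unsafe_hash` and `order` reproduce the value equality, hashability and field-wise ordering the namedtuple provided. A plausible motivation, not stated in the diff, is that `keras.tree` traverses tuples and namedtuples as nested structures, while a dataclass instance is an opaque leaf, which the new `map_structure`-based code further down relies on. A quick sketch of the preserved semantics (standalone copy of the class for illustration):

```python
import dataclasses

@dataclasses.dataclass(eq=True, unsafe_hash=True, order=True)
class PlacementAndPath:
    placement: str
    path: str

# Like the namedtuple it replaces: equal by value, hashable, ordered.
a = PlacementAndPath("sparsecore", "feature1")
b = PlacementAndPath("sparsecore", "feature1")
assert a == b and hash(a) == hash(b)
assert {a: 1}[b] == 1  # Usable as a dict key.
assert PlacementAndPath("default_device", "x") < a  # "d..." < "s..."
```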
@@ -146,14 +148,14 @@ class DistributedEmbedding(keras.layers.Layer):
     feature1 = keras_rs.layers.FeatureConfig(
         name="feature1",
         table=table1,
-        input_shape=(PER_REPLICA_BATCH_SIZE,),
-        output_shape=(PER_REPLICA_BATCH_SIZE, TABLE1_EMBEDDING_SIZE),
+        input_shape=(GLOBAL_BATCH_SIZE,),
+        output_shape=(GLOBAL_BATCH_SIZE, TABLE1_EMBEDDING_SIZE),
     )
     feature2 = keras_rs.layers.FeatureConfig(
         name="feature2",
         table=table2,
-        input_shape=(PER_REPLICA_BATCH_SIZE,),
-        output_shape=(PER_REPLICA_BATCH_SIZE, TABLE2_EMBEDDING_SIZE),
+        input_shape=(GLOBAL_BATCH_SIZE,),
+        output_shape=(GLOBAL_BATCH_SIZE, TABLE2_EMBEDDING_SIZE),
     )
 
     feature_configs = {
@@ -337,18 +339,33 @@ class DistributedEmbedding(keras.layers.Layer):
     embedding_layer = DistributedEmbedding(feature_configs)
 
     # Add preprocessing to a data input pipeline.
-    def train_dataset_generator():
-        for (inputs, weights), labels in iter(train_dataset):
+    def preprocessed_dataset_generator(dataset):
+        for (inputs, weights), labels in iter(dataset):
             yield embedding_layer.preprocess(
                 inputs, weights, training=True
             ), labels
 
-    preprocessed_train_dataset = train_dataset_generator()
+    preprocessed_train_dataset = preprocessed_dataset_generator(train_dataset)
     ```
     This explicit preprocessing stage combines the input and optional weights,
     so the new data can be passed directly into the `inputs` argument of the
     layer or model.
 
+    **NOTE**: When working in a multi-host setting with data parallelism, the
+    data needs to be sharded properly across hosts. If the original dataset is
+    of type `tf.data.Dataset`, it will need to be manually sharded _prior_ to
+    applying the preprocess generator:
+    ```python
+    # Manually shard the dataset across hosts.
+    train_dataset = distribution.distribute_dataset(train_dataset)
+    distribution.auto_shard_dataset = False  # Dataset is already sharded.
+
+    # Add a preprocessing stage to the distributed data input pipeline.
+    train_dataset = preprocessed_dataset_generator(train_dataset)
+    ```
+    If the original dataset is _not_ a `tf.data.Dataset`, it must already be
+    pre-sharded across hosts.
+
     #### Usage in a Keras model
 
     Once the global distribution is set and the input preprocessing pipeline
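The new NOTE pins down the multi-host contract. A condensed sketch of the whole pattern under the docstring's assumptions; `train_dataset`, `embedding_layer`, `preprocessed_dataset_generator` and `model` are the names from the surrounding docstring and are not defined here:

```python
import keras

# Data-parallel distribution across all available devices (JAX backend).
distribution = keras.distribution.DataParallel()
keras.distribution.set_distribution(distribution)

# Shard the tf.data.Dataset across hosts once, up front...
train_dataset = distribution.distribute_dataset(train_dataset)
# ...then tell Keras not to shard it a second time.
distribution.auto_shard_dataset = False

# Preprocess after sharding, so each host only touches its own shard.
train_dataset = preprocessed_dataset_generator(train_dataset)
model.fit(train_dataset)
```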
@@ -503,12 +520,12 @@ class DistributedEmbedding(keras.layers.Layer):
         With these structures in place, the steps to:
         - go from the deeply nested structure to the two-level structure are:
           - `assert_same_struct` as `self._feature_configs`
-          - `flatten`
-          - `pack_sequence_as` `self._placement_to_path_to_feature_config`
+          - use `self._feature_deeply_nested_placement_and_paths` to map from
+            deeply nested to two-level
         - go from the two-level structure to the deeply nested structure:
-          - `assert_same_struct` as `self._placement_to_path_to_feature_config`
-          - `flatten`
-          - `pack_sequence_as` `self._feature_configs`
+          - `assert_same_struct` as `self._placement_to_path_to_feature_config`
+          - use `self._feature_deeply_nested_placement_and_paths` to locate each
+            output in the two-level dicts
 
         Args:
             feature_configs: The deeply nested structure of `FeatureConfig` or
@@ -575,14 +592,14 @@ class DistributedEmbedding(keras.layers.Layer):
         ] = collections.defaultdict(dict)
 
         def populate_placement_to_path_to_input_shape(
-            placement_and_path: PlacementAndPath, input_shape: types.Shape
+            pp: PlacementAndPath, input_shape: types.Shape
         ) -> None:
-            placement_to_path_to_input_shape[placement_and_path.placement][
-                placement_and_path.path
-            ] = input_shape
+            placement_to_path_to_input_shape[pp.placement][pp.path] = (
+                input_shape
+            )
 
         keras.tree.map_structure_up_to(
-            self._feature_configs,
+            self._feature_deeply_nested_placement_and_paths,
             populate_placement_to_path_to_input_shape,
             self._feature_deeply_nested_placement_and_paths,
             input_shapes,
@@ -630,35 +647,40 @@ class DistributedEmbedding(keras.layers.Layer):
         """
         # Verify input structure.
         keras.tree.assert_same_structure(self._feature_configs, inputs)
+        if weights is not None:
+            keras.tree.assert_same_structure(self._feature_configs, weights)
 
         if not self.built:
-            input_shapes = keras.tree.map_structure_up_to(
-                self._feature_configs,
+            input_shapes = keras.tree.map_structure(
                 lambda array: backend.standardize_shape(array.shape),
                 inputs,
             )
             self.build(input_shapes)
 
-        # Go from deeply nested structure of inputs to flat inputs.
-        flat_inputs = keras.tree.flatten(inputs)
+        # Go from deeply nested to nested dict placement -> path -> input.
+        def to_placement_to_path(
+            tensors: types.Nested[types.Tensor],
+        ) -> dict[str, dict[str, types.Tensor]]:
+            result: dict[str, dict[str, types.Tensor]] = {
+                p: dict() for p in self._placement_to_path_to_feature_config
+            }
 
-        # Go from flat to nested dict placement -> path -> input.
-        placement_to_path_to_inputs = keras.tree.pack_sequence_as(
-            self._placement_to_path_to_feature_config, flat_inputs
-        )
+            def populate(pp: PlacementAndPath, x: types.Tensor) -> None:
+                result[pp.placement][pp.path] = x
 
-        if weights is not None:
-            # Same for weights if present.
-            keras.tree.assert_same_structure(self._feature_configs, weights)
-            flat_weights = keras.tree.flatten(weights)
-            placement_to_path_to_weights = keras.tree.pack_sequence_as(
-                self._placement_to_path_to_feature_config, flat_weights
+            keras.tree.map_structure(
+                populate,
+                self._feature_deeply_nested_placement_and_paths,
+                tensors,
             )
-        else:
-            # Populate keys for weights.
-            placement_to_path_to_weights = {
-                k: None for k in placement_to_path_to_inputs
-            }
+            return result
+
+        placement_to_path_to_inputs = to_placement_to_path(inputs)
+
+        # Same for weights if present.
+        placement_to_path_to_weights = (
+            to_placement_to_path(weights) if weights is not None else None
+        )
 
         placement_to_path_to_preprocessed: dict[
             str, dict[str, dict[str, types.Nested[types.Tensor]]]
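The refactor above replaces the `flatten` / `pack_sequence_as` round-trip with a single `keras.tree.map_structure` pass over a parallel structure whose leaves are `PlacementAndPath` objects. A standalone toy sketch of the technique; the class and structures here are illustrative stand-ins, not the layer's actual attributes:

```python
import dataclasses
import keras

@dataclasses.dataclass(frozen=True)
class PP:  # Stand-in for PlacementAndPath; a dataclass is a tree leaf.
    placement: str
    path: str

# A "deeply nested" input structure and a parallel structure that
# records each leaf's placement and unique path.
inputs = {"user": [10, 20], "item": {"id": 30}}
pps = {
    "user": [PP("sparsecore", "user/0"), PP("sparsecore", "user/1")],
    "item": {"id": PP("default_device", "item/id")},
}

result = {"sparsecore": {}, "default_device": {}}

def populate(pp, x):
    result[pp.placement][pp.path] = x

# map_structure walks both trees in lock-step, pairing each PP leaf
# with the corresponding input leaf; no flatten/pack round-trip.
keras.tree.map_structure(populate, pps, inputs)
assert result == {
    "sparsecore": {"user/0": 10, "user/1": 20},
    "default_device": {"item/id": 30},
}
```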
@@ -669,7 +691,9 @@ class DistributedEmbedding(keras.layers.Layer):
             placement_to_path_to_preprocessed["sparsecore"] = (
                 self._sparsecore_preprocess(
                     placement_to_path_to_inputs["sparsecore"],
-                    placement_to_path_to_weights["sparsecore"],
+                    placement_to_path_to_weights["sparsecore"]
+                    if placement_to_path_to_weights is not None
+                    else None,
                     training,
                 )
             )
@@ -679,7 +703,9 @@ class DistributedEmbedding(keras.layers.Layer):
             placement_to_path_to_preprocessed["default_device"] = (
                 self._default_device_preprocess(
                     placement_to_path_to_inputs["default_device"],
-                    placement_to_path_to_weights["default_device"],
+                    placement_to_path_to_weights["default_device"]
+                    if placement_to_path_to_weights is not None
+                    else None,
                     training,
                 )
             )
@@ -765,11 +791,13 @@ class DistributedEmbedding(keras.layers.Layer):
             placement_to_path_to_outputs,
         )
 
-        # Go from placement -> path -> output to flat outputs.
-        flat_outputs = keras.tree.flatten(placement_to_path_to_outputs)
+        # Go from placement -> path -> output to deeply nested structure.
+        def populate_output(pp: PlacementAndPath) -> types.Tensor:
+            return placement_to_path_to_outputs[pp.placement][pp.path]
 
-        # Go from flat outputs to deeply nested structure.
-        return keras.tree.pack_sequence_as(self._feature_configs, flat_outputs)
+        return keras.tree.map_structure(
+            populate_output, self._feature_deeply_nested_placement_and_paths
+        )
 
     def get_embedding_tables(self) -> dict[str, types.Tensor]:
         """Return the content of the embedding tables by table name.
@@ -794,13 +822,13 @@ class DistributedEmbedding(keras.layers.Layer):
         table_stacking: str | Sequence[Sequence[str]],
     ) -> None:
        del table_stacking
-        table_to_embedding_layer: dict[TableConfig, EmbedReduce] = {}
+        table_config_id_to_embedding_layer: dict[int, EmbedReduce] = {}
         self._default_device_embedding_layers: dict[str, EmbedReduce] = {}
 
         for path, feature_config in feature_configs.items():
-            if feature_config.table in table_to_embedding_layer:
+            if id(feature_config.table) in table_config_id_to_embedding_layer:
                 self._default_device_embedding_layers[path] = (
-                    table_to_embedding_layer[feature_config.table]
+                    table_config_id_to_embedding_layer[id(feature_config.table)]
                 )
             else:
                 embedding_layer = EmbedReduce(
@@ -810,7 +838,9 @@
                     embeddings_initializer=feature_config.table.initializer,
                     combiner=feature_config.table.combiner,
                 )
-                table_to_embedding_layer[feature_config.table] = embedding_layer
+                table_config_id_to_embedding_layer[id(feature_config.table)] = (
+                    embedding_layer
+                )
                 self._default_device_embedding_layers[path] = embedding_layer
 
     def _default_device_build(
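Every dict keyed by `TableConfig` becomes keyed by `id(...)` here because `TableConfig` stops being hashable after the decorator change further down. A toy sketch of the identity-keyed dedupe pattern; the class is a simplified stand-in:

```python
import dataclasses

@dataclasses.dataclass  # eq=True by default, so __hash__ is set to None.
class Table:
    name: str

t1, t2 = Table("users"), Table("users")

layers_by_id: dict[int, str] = {}
for table in (t1, t2, t1):
    if id(table) not in layers_by_id:
        # Keyed by identity, not equality: t1 and t2 compare equal but
        # are distinct objects, so each gets its own entry; a repeated
        # object reuses its entry.
        layers_by_id[id(table)] = f"embedding_layer_for_{table.name}"

assert len(layers_by_id) == 2
```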
@@ -985,8 +1015,8 @@
 
         # The serialized `TableConfig` objects.
         table_config_dicts: list[dict[str, Any]] = []
-        # Mapping from `TableConfig` to index in `table_config_dicts`.
-        table_config_indices: dict[TableConfig, int] = {}
+        # Mapping from `TableConfig` id to index in `table_config_dicts`.
+        table_config_id_to_index: dict[int, int] = {}
 
         def serialize_feature_config(
             feature_config: FeatureConfig,
@@ -996,17 +1026,17 @@
             # key.
             feature_config_dict = feature_config.get_config()
 
-            if feature_config.table not in table_config_indices:
+            if id(feature_config.table) not in table_config_id_to_index:
                 # Save the serialized `TableConfig` the first time we see it and
                 # remember its index.
-                table_config_indices[feature_config.table] = len(
+                table_config_id_to_index[id(feature_config.table)] = len(
                     table_config_dicts
                 )
                 table_config_dicts.append(feature_config_dict["table"])
 
             # Replace the serialized `TableConfig` with its index.
-            feature_config_dict["table"] = table_config_indices[
-                feature_config.table
+            feature_config_dict["table"] = table_config_id_to_index[
+                id(feature_config.table)
             ]
             return feature_config_dict
keras_rs/src/layers/embedding/distributed_embedding_config.py
@@ -10,7 +10,7 @@ from keras_rs.src.api_export import keras_rs_export
 
 
 @keras_rs_export("keras_rs.layers.TableConfig")
-@dataclasses.dataclass(eq=True, unsafe_hash=True, order=True)
+@dataclasses.dataclass(order=True)
 class TableConfig:
     """Configuration for one embedding table.
 
@@ -88,7 +88,7 @@ class TableConfig:
 
 
 @keras_rs_export("keras_rs.layers.FeatureConfig")
-@dataclasses.dataclass(eq=True, unsafe_hash=True, order=True)
+@dataclasses.dataclass(order=True)
 class FeatureConfig:
     """Configuration for one embedding feature.
 
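This is the change that forced the `id()` keying earlier in the diff: with `unsafe_hash` gone, `eq` still defaults to `True` (and `order=True` requires it), and a dataclass with `eq=True` and no explicit hash sets `__hash__ = None`, making instances unusable as dict keys. A minimal demonstration with a simplified stand-in:

```python
import dataclasses

@dataclasses.dataclass(order=True)
class TableConfig:  # Simplified stand-in for the real class.
    name: str

a, b = TableConfig("t1"), TableConfig("t1")
assert a == b              # Still compares field-wise.
assert a.__hash__ is None  # eq=True without a hash => unhashable.
try:
    {a: 1}
except TypeError as e:
    print(e)  # unhashable type: 'TableConfig'
```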
@@ -102,7 +102,10 @@ class FeatureConfig:
         input_shape: The input shape of the feature. The feature fed into the
             layer has to match the shape. Note that for ragged dimensions in the
             input, the dimension provided here presents the maximum value;
-            anything larger will be truncated.
+            anything larger will be truncated. Also note that the first
+            dimension represents the global batch size. For example, on TPU,
+            this represents the total number of samples that are dispatched to
+            all the TPUs connected to the current host.
         output_shape: The output shape of the feature activation. What is
             returned by the embedding layer has to match this shape.
     """