keras-rs-nightly 0.2.2.dev202509030321__tar.gz → 0.2.2.dev202509170322__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (61)
  1. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/PKG-INFO +4 -3
  2. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/embedding/base_distributed_embedding.py +12 -10
  3. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/embedding/distributed_embedding_config.py +2 -2
  4. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/embedding/jax/distributed_embedding.py +41 -174
  5. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/embedding/jax/embedding_utils.py +68 -22
  6. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/embedding/tensorflow/config_conversion.py +26 -19
  7. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +15 -5
  8. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/version.py +1 -1
  9. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs_nightly.egg-info/PKG-INFO +4 -3
  10. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/pyproject.toml +3 -2
  11. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/README.md +0 -0
  12. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/api/__init__.py +0 -0
  13. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/api/layers/__init__.py +0 -0
  14. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/api/losses/__init__.py +0 -0
  15. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/api/metrics/__init__.py +0 -0
  16. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/__init__.py +0 -0
  17. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/api_export.py +0 -0
  18. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/__init__.py +0 -0
  19. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/embedding/__init__.py +0 -0
  20. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/embedding/distributed_embedding.py +0 -0
  21. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/embedding/embed_reduce.py +0 -0
  22. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  23. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/embedding/jax/checkpoint_utils.py +0 -0
  24. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/embedding/jax/config_conversion.py +0 -0
  25. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/embedding/jax/embedding_lookup.py +0 -0
  26. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  27. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/feature_interaction/__init__.py +0 -0
  28. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/feature_interaction/dot_interaction.py +0 -0
  29. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/feature_interaction/feature_cross.py +0 -0
  30. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/retrieval/__init__.py +0 -0
  31. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/retrieval/brute_force_retrieval.py +0 -0
  32. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/retrieval/hard_negative_mining.py +0 -0
  33. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/retrieval/remove_accidental_hits.py +0 -0
  34. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/retrieval/retrieval.py +0 -0
  35. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/layers/retrieval/sampling_probability_correction.py +0 -0
  36. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/losses/__init__.py +0 -0
  37. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/losses/pairwise_hinge_loss.py +0 -0
  38. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/losses/pairwise_logistic_loss.py +0 -0
  39. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/losses/pairwise_loss.py +0 -0
  40. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/losses/pairwise_loss_utils.py +0 -0
  41. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/losses/pairwise_mean_squared_error.py +0 -0
  42. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/losses/pairwise_soft_zero_one_loss.py +0 -0
  43. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/metrics/__init__.py +0 -0
  44. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/metrics/dcg.py +0 -0
  45. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/metrics/mean_average_precision.py +0 -0
  46. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/metrics/mean_reciprocal_rank.py +0 -0
  47. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/metrics/ndcg.py +0 -0
  48. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/metrics/precision_at_k.py +0 -0
  49. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/metrics/ranking_metric.py +0 -0
  50. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/metrics/ranking_metrics_utils.py +0 -0
  51. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/metrics/recall_at_k.py +0 -0
  52. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/metrics/utils.py +0 -0
  53. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/types.py +0 -0
  54. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/utils/__init__.py +0 -0
  55. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/utils/doc_string_utils.py +0 -0
  56. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs/src/utils/keras_utils.py +0 -0
  57. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs_nightly.egg-info/SOURCES.txt +0 -0
  58. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs_nightly.egg-info/dependency_links.txt +0 -0
  59. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs_nightly.egg-info/requires.txt +0 -0
  60. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/keras_rs_nightly.egg-info/top_level.txt +0 -0
  61. {keras_rs_nightly-0.2.2.dev202509030321 → keras_rs_nightly-0.2.2.dev202509170322}/setup.cfg +0 -0

PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: keras-rs-nightly
- Version: 0.2.2.dev202509030321
+ Version: 0.2.2.dev202509170322
  Summary: Multi-backend recommender systems with Keras 3.
  Author-email: Keras team <keras-users@googlegroups.com>
  License: Apache License 2.0
@@ -8,8 +8,9 @@ Project-URL: Home, https://keras.io/keras_rs
  Project-URL: Repository, https://github.com/keras-team/keras-rs
  Classifier: Development Status :: 3 - Alpha
  Classifier: Programming Language :: Python :: 3
- Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
  Classifier: Programming Language :: Python :: 3 :: Only
  Classifier: Operating System :: Unix
  Classifier: Operating System :: Microsoft :: Windows
@@ -17,7 +18,7 @@ Classifier: Operating System :: MacOS
  Classifier: Intended Audience :: Science/Research
  Classifier: Topic :: Scientific/Engineering
  Classifier: Topic :: Software Development
- Requires-Python: >=3.10
+ Requires-Python: >=3.11
  Description-Content-Type: text/markdown
  Requires-Dist: keras
  Requires-Dist: ml-dtypes

keras_rs/src/layers/embedding/base_distributed_embedding.py
@@ -822,13 +822,13 @@ class DistributedEmbedding(keras.layers.Layer):
  table_stacking: str | Sequence[Sequence[str]],
  ) -> None:
  del table_stacking
- table_to_embedding_layer: dict[TableConfig, EmbedReduce] = {}
+ table_config_id_to_embedding_layer: dict[int, EmbedReduce] = {}
  self._default_device_embedding_layers: dict[str, EmbedReduce] = {}

  for path, feature_config in feature_configs.items():
- if feature_config.table in table_to_embedding_layer:
+ if id(feature_config.table) in table_config_id_to_embedding_layer:
  self._default_device_embedding_layers[path] = (
- table_to_embedding_layer[feature_config.table]
+ table_config_id_to_embedding_layer[id(feature_config.table)]
  )
  else:
  embedding_layer = EmbedReduce(
@@ -838,7 +838,9 @@ class DistributedEmbedding(keras.layers.Layer):
  embeddings_initializer=feature_config.table.initializer,
  combiner=feature_config.table.combiner,
  )
- table_to_embedding_layer[feature_config.table] = embedding_layer
+ table_config_id_to_embedding_layer[id(feature_config.table)] = (
+ embedding_layer
+ )
  self._default_device_embedding_layers[path] = embedding_layer

  def _default_device_build(
@@ -1013,8 +1015,8 @@ class DistributedEmbedding(keras.layers.Layer):

  # The serialized `TableConfig` objects.
  table_config_dicts: list[dict[str, Any]] = []
- # Mapping from `TableConfig` to index in `table_config_dicts`.
- table_config_indices: dict[TableConfig, int] = {}
+ # Mapping from `TableConfig` id to index in `table_config_dicts`.
+ table_config_id_to_index: dict[int, int] = {}

  def serialize_feature_config(
  feature_config: FeatureConfig,
@@ -1024,17 +1026,17 @@ class DistributedEmbedding(keras.layers.Layer):
  # key.
  feature_config_dict = feature_config.get_config()

- if feature_config.table not in table_config_indices:
+ if id(feature_config.table) not in table_config_id_to_index:
  # Save the serialized `TableConfig` the first time we see it and
  # remember its index.
- table_config_indices[feature_config.table] = len(
+ table_config_id_to_index[id(feature_config.table)] = len(
  table_config_dicts
  )
  table_config_dicts.append(feature_config_dict["table"])

  # Replace the serialized `TableConfig` with its index.
- feature_config_dict["table"] = table_config_indices[
- feature_config.table
+ feature_config_dict["table"] = table_config_id_to_index[
+ id(feature_config.table)
  ]
  return feature_config_dict


keras_rs/src/layers/embedding/distributed_embedding_config.py
@@ -10,7 +10,7 @@ from keras_rs.src.api_export import keras_rs_export


  @keras_rs_export("keras_rs.layers.TableConfig")
- @dataclasses.dataclass(eq=True, unsafe_hash=True, order=True)
+ @dataclasses.dataclass(order=True)
  class TableConfig:
  """Configuration for one embedding table.

@@ -88,7 +88,7 @@ class TableConfig:


  @keras_rs_export("keras_rs.layers.FeatureConfig")
- @dataclasses.dataclass(eq=True, unsafe_hash=True, order=True)
+ @dataclasses.dataclass(order=True)
  class FeatureConfig:
  """Configuration for one embedding feature.

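Note: with `eq=True, unsafe_hash=True` removed, `dataclasses.dataclass(order=True)` keeps `eq=True` at its default, which sets `__hash__` to `None`, so `TableConfig` and `FeatureConfig` instances can no longer be used as dict keys. That is consistent with the dicts elsewhere in this release switching to keys of `id(...)`. A minimal sketch of the pattern, using a simplified stand-in class rather than the real `keras_rs.layers.TableConfig`:

import dataclasses

@dataclasses.dataclass(order=True)
class TableConfig:  # simplified stand-in, not the real keras_rs class
    name: str
    vocabulary_size: int
    embedding_dim: int

config = TableConfig("movies", 1000, 16)

# eq=True (the default) without unsafe_hash=True sets __hash__ to None,
# so using the instance as a dict key raises TypeError.
try:
    {config: "layer"}
except TypeError as e:
    print(e)  # unhashable type: 'TableConfig'

# Keying on id(config) dedupes by object identity instead, which is the
# approach the embedding layers now use to share one layer per config.
layers_by_config_id = {id(config): "shared_embed_reduce_layer"}
print(id(config) in layers_by_config_id)  # True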

keras_rs/src/layers/embedding/jax/distributed_embedding.py
@@ -15,7 +15,6 @@ from jax_tpu_embedding.sparsecore.lib.nn import (
  table_stacking as jte_table_stacking,
  )
  from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
- from keras.src import backend

  from keras_rs.src import types
  from keras_rs.src.layers.embedding import base_distributed_embedding
@@ -247,23 +246,6 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
  )
  return sparsecore_distribution, sparsecore_layout

- def _create_cpu_distribution(
- self, cpu_axis_name: str = "cpu"
- ) -> tuple[
- keras.distribution.ModelParallel, keras.distribution.TensorLayout
- ]:
- """Share a variable across all CPU processes."""
- cpu_devices = jax.devices("cpu")
- device_mesh = keras.distribution.DeviceMesh(
- (len(cpu_devices),), [cpu_axis_name], cpu_devices
- )
- replicated_layout = keras.distribution.TensorLayout([], device_mesh)
- layout_map = keras.distribution.LayoutMap(device_mesh=device_mesh)
- cpu_distribution = keras.distribution.ModelParallel(
- layout_map=layout_map
- )
- return cpu_distribution, replicated_layout
-
  def _add_sparsecore_weight(
  self,
  name: str,
@@ -405,11 +387,6 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
  self._sparsecore_layout = sparsecore_layout
  self._sparsecore_distribution = sparsecore_distribution

- # Distribution for CPU operations.
- cpu_distribution, cpu_layout = self._create_cpu_distribution()
- self._cpu_distribution = cpu_distribution
- self._cpu_layout = cpu_layout
-
  mesh = sparsecore_distribution.device_mesh.backend_mesh
  global_device_count = mesh.devices.size
  num_sc_per_device = jte_utils.num_sparsecores_per_device(
@@ -466,10 +443,6 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
  # Collect all stacked tables.
  table_specs = embedding_utils.get_table_specs(feature_specs)
  table_stacks = embedding_utils.get_table_stacks(table_specs)
- stacked_table_specs = {
- stack_name: stack[0].stacked_table_spec
- for stack_name, stack in table_stacks.items()
- }

  # Create variables for all stacked tables and slot variables.
  with sparsecore_distribution.scope():
@@ -502,50 +475,6 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
  )
  self._iterations.overwrite_with_gradient = True

- with cpu_distribution.scope():
- # Create variables to track static buffer size and max IDs for each
- # table during preprocessing. These variables are shared across all
- # processes on CPU. We don't add these via `add_weight` because we
- # can't have them passed to the training function.
- replicated_zeros_initializer = ShardedInitializer(
- "zeros", cpu_layout
- )
-
- with backend.name_scope(self.name, caller=self):
- self._preprocessing_buffer_size = {
- table_name: backend.Variable(
- initializer=replicated_zeros_initializer,
- shape=(),
- dtype=backend.standardize_dtype("int32"),
- trainable=False,
- name=table_name + ":preprocessing:buffer_size",
- )
- for table_name in stacked_table_specs.keys()
- }
- self._preprocessing_max_unique_ids_per_partition = {
- table_name: backend.Variable(
- shape=(),
- name=table_name
- + ":preprocessing:max_unique_ids_per_partition",
- initializer=replicated_zeros_initializer,
- dtype=backend.standardize_dtype("int32"),
- trainable=False,
- )
- for table_name in stacked_table_specs.keys()
- }
-
- self._preprocessing_max_ids_per_partition = {
- table_name: backend.Variable(
- shape=(),
- name=table_name
- + ":preprocessing:max_ids_per_partition",
- initializer=replicated_zeros_initializer,
- dtype=backend.standardize_dtype("int32"),
- trainable=False,
- )
- for table_name in stacked_table_specs.keys()
- }
-
  self._config = jte_embedding_lookup.EmbeddingLookupConfiguration(
  feature_specs,
  mesh=mesh,
@@ -660,76 +589,35 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
  mesh.devices.item(0)
  )

- # Get current buffer size/max_ids.
- previous_max_ids_per_partition = keras.tree.map_structure(
- lambda max_ids_per_partition: max_ids_per_partition.value.item(),
- self._preprocessing_max_ids_per_partition,
- )
- previous_max_unique_ids_per_partition = keras.tree.map_structure(
- lambda max_unique_ids_per_partition: (
- max_unique_ids_per_partition.value.item()
- ),
- self._preprocessing_max_unique_ids_per_partition,
- )
- previous_buffer_size = keras.tree.map_structure(
- lambda buffer_size: buffer_size.value.item(),
- self._preprocessing_buffer_size,
- )
-
  preprocessed, stats = embedding_utils.stack_and_shard_samples(
  self._config.feature_specs,
  samples,
  local_device_count,
  global_device_count,
  num_sc_per_device,
- static_buffer_size=previous_buffer_size,
  )

- # Extract max unique IDs and buffer sizes.
- # We need to replicate this value across all local CPU devices.
  if training:
+ # Synchronize input statistics across all devices and update the
+ # underlying stacked tables specs in the feature specs.
+ prev_stats = embedding_utils.get_stacked_table_stats(
+ self._config.feature_specs
+ )
+
+ # Take the maximum with existing stats.
+ stats = keras.tree.map_structure(max, prev_stats, stats)
+
+ # Flatten the stats so we can more efficiently transfer them
+ # between hosts. We use jax.tree because we will later need to
+ # unflatten.
+ flat_stats, stats_treedef = jax.tree.flatten(stats)
+
+ # In the case of multiple local CPU devices per host, we need to
+ # replicate the stats to placate JAX collectives.
  num_local_cpu_devices = jax.local_device_count("cpu")
- local_max_ids_per_partition = {
- table_name: np.repeat(
- # Maximum across all partitions and previous max.
- np.maximum(
- np.max(elems),
- previous_max_ids_per_partition[table_name],
- ),
- num_local_cpu_devices,
- )
- for table_name, elems in stats.max_ids_per_partition.items()
- }
- local_max_unique_ids_per_partition = {
- name: np.repeat(
- # Maximum across all partitions and previous max.
- np.maximum(
- np.max(elems),
- previous_max_unique_ids_per_partition[name],
- ),
- num_local_cpu_devices,
- )
- for name, elems in stats.max_unique_ids_per_partition.items()
- }
- local_buffer_size = {
- table_name: np.repeat(
- np.maximum(
- np.max(
- # Round values up to the next multiple of 8.
- # Currently using this as a proxy for the actual
- # required buffer size.
- ((elems + 7) // 8) * 8
- )
- * global_device_count
- * num_sc_per_device
- * local_device_count
- * num_sc_per_device,
- previous_buffer_size[table_name],
- ),
- num_local_cpu_devices,
- )
- for table_name, elems in stats.max_ids_per_partition.items()
- }
+ tiled_stats = np.tile(
+ np.array(flat_stats, dtype=np.int32), (num_local_cpu_devices, 1)
+ )

  # Aggregate variables across all processes/devices.
  max_across_cpus = jax.pmap(
@@ -737,48 +625,24 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
  x, "all_cpus"
  ),
  axis_name="all_cpus",
- devices=self._cpu_layout.device_mesh.backend_mesh.devices,
+ backend="cpu",
  )
- new_max_ids_per_partition = max_across_cpus(
- local_max_ids_per_partition
- )
- new_max_unique_ids_per_partition = max_across_cpus(
- local_max_unique_ids_per_partition
- )
- new_buffer_size = max_across_cpus(local_buffer_size)
-
- # Assign new preprocessing parameters.
- with self._cpu_distribution.scope():
- # For each process, all max ids/buffer sizes are replicated
- # across all local devices. Take the value from the first
- # device.
- keras.tree.map_structure(
- lambda var, values: var.assign(values[0]),
- self._preprocessing_max_ids_per_partition,
- new_max_ids_per_partition,
- )
- keras.tree.map_structure(
- lambda var, values: var.assign(values[0]),
- self._preprocessing_max_unique_ids_per_partition,
- new_max_unique_ids_per_partition,
- )
- keras.tree.map_structure(
- lambda var, values: var.assign(values[0]),
- self._preprocessing_buffer_size,
- new_buffer_size,
- )
- # Update parameters in the underlying feature specs.
- int_max_ids_per_partition = keras.tree.map_structure(
- lambda varray: varray.item(), new_max_ids_per_partition
- )
- int_max_unique_ids_per_partition = keras.tree.map_structure(
- lambda varray: varray.item(),
- new_max_unique_ids_per_partition,
+ flat_stats = max_across_cpus(tiled_stats)[0].tolist()
+ stats = jax.tree.unflatten(stats_treedef, flat_stats)
+
+ # Update configuration and repeat preprocessing if stats changed.
+ if stats != prev_stats:
+ embedding_utils.update_stacked_table_stats(
+ self._config.feature_specs, stats
  )
- embedding_utils.update_stacked_table_specs(
+
+ # Re-execute preprocessing with consistent input statistics.
+ preprocessed, _ = embedding_utils.stack_and_shard_samples(
  self._config.feature_specs,
- int_max_ids_per_partition,
- int_max_unique_ids_per_partition,
+ samples,
+ local_device_count,
+ global_device_count,
+ num_sc_per_device,
  )

  return {"inputs": preprocessed}
@@ -826,19 +690,22 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
  raise ValueError("Layer must first be built before setting tables.")

  if "default_device" in self._placement_to_path_to_feature_config:
- table_to_embedding_layer = {}
+ table_name_to_embedding_layer = {}
  for (
  path,
  feature_config,
  ) in self._placement_to_path_to_feature_config[
  "default_device"
  ].items():
- table_to_embedding_layer[feature_config.table] = (
+ table_name_to_embedding_layer[feature_config.table.name] = (
  self._default_device_embedding_layers[path]
  )

- for table, embedding_layer in table_to_embedding_layer.items():
- table_values = tables.get(table.name, None)
+ for (
+ table_name,
+ embedding_layer,
+ ) in table_name_to_embedding_layer.items():
+ table_values = tables.get(table_name, None)
  if table_values is not None:
  if embedding_layer.lora_enabled:
  raise ValueError("Cannot set table if LoRA is enabled.")

keras_rs/src/layers/embedding/jax/embedding_utils.py
@@ -35,6 +35,12 @@ class ShardedCooMatrix(NamedTuple):
  values: ArrayLike


+ class InputStatsPerTable(NamedTuple):
+ max_ids_per_partition: int
+ max_unique_ids_per_partition: int
+ required_buffer_size_per_device: int
+
+
  def _round_up_to_multiple(value: int, multiple: int) -> int:
  return ((value + multiple - 1) // multiple) * multiple

@@ -335,19 +341,47 @@ def get_table_stacks(
  return stacked_table_specs


- def update_stacked_table_specs(
+ def get_stacked_table_stats(
  feature_specs: Nested[FeatureSpec],
- max_ids_per_partition: Mapping[str, int],
- max_unique_ids_per_partition: Mapping[str, int],
+ ) -> dict[str, InputStatsPerTable]:
+ """Extracts the stacked-table input statistics from the feature specs.
+
+ Args:
+ feature_specs: Feature specs from which to extracts the statistics.
+
+ Returns:
+ A mapping of stacked table names to input statistics per table.
+ """
+ stacked_table_specs: dict[str, StackedTableSpec] = {}
+ for feature_spec in jax.tree.flatten(feature_specs)[0]:
+ feature_spec = typing.cast(FeatureSpec, feature_spec)
+ stacked_table_spec = typing.cast(
+ StackedTableSpec, feature_spec.table_spec.stacked_table_spec
+ )
+ stacked_table_specs[stacked_table_spec.stack_name] = stacked_table_spec
+
+ stats: dict[str, InputStatsPerTable] = {}
+ for stacked_table_spec in stacked_table_specs.values():
+ buffer_size = stacked_table_spec.suggested_coo_buffer_size_per_device
+ buffer_size = buffer_size or 0
+ stats[stacked_table_spec.stack_name] = InputStatsPerTable(
+ max_ids_per_partition=stacked_table_spec.max_ids_per_partition,
+ max_unique_ids_per_partition=stacked_table_spec.max_unique_ids_per_partition,
+ required_buffer_size_per_device=buffer_size,
+ )
+
+ return stats
+
+
+ def update_stacked_table_stats(
+ feature_specs: Nested[FeatureSpec],
+ stats: Mapping[str, InputStatsPerTable],
  ) -> None:
- """Updates properties in the supplied feature specs.
+ """Updates stacked-table input properties in the supplied feature specs.

  Args:
  feature_specs: Feature specs to update in-place.
- max_ids_per_partition: Mapping of table stack name to
- new `max_ids_per_partition` for the stack.
- max_unique_ids_per_partition: Mapping of table stack name to
- new `max_unique_ids_per_partition` for the stack.
+ stats: Per-stacked-table input statistics.
  """
  # Collect table specs and stacked table specs.
  table_specs: dict[str, TableSpec] = {}
@@ -363,18 +397,17 @@ def update_stacked_table_specs(
  stacked_table_specs[stacked_table_spec.stack_name] = stacked_table_spec

  # Replace fields in the stacked_table_specs.
- stacked_table_specs = {
- stack_name: dataclasses.replace(
+ stack_names = stacked_table_specs.keys()
+ for stack_name in stack_names:
+ stack_stats = stats[stack_name]
+ stacked_table_spec = stacked_table_specs[stack_name]
+ buffer_size = stack_stats.required_buffer_size_per_device or None
+ stacked_table_specs[stack_name] = dataclasses.replace(
  stacked_table_spec,
- max_ids_per_partition=max_ids_per_partition[
- stacked_table_spec.stack_name
- ],
- max_unique_ids_per_partition=max_unique_ids_per_partition[
- stacked_table_spec.stack_name
- ],
+ max_ids_per_partition=stack_stats.max_ids_per_partition,
+ max_unique_ids_per_partition=stack_stats.max_unique_ids_per_partition,
+ suggested_coo_buffer_size_per_device=buffer_size,
  )
- for stack_name, stacked_table_spec in stacked_table_specs.items()
- }

  # Insert new stacked tables into tables.
  for table_spec in table_specs.values():
@@ -534,7 +567,7 @@ def stack_and_shard_samples(
  global_device_count: int,
  num_sc_per_device: int,
  static_buffer_size: int | Mapping[str, int] | None = None,
- ) -> tuple[dict[str, ShardedCooMatrix], embedding.SparseDenseMatmulInputStats]:
+ ) -> tuple[dict[str, ShardedCooMatrix], dict[str, InputStatsPerTable]]:
  """Prepares input samples for use in embedding lookups.

  Args:
@@ -544,8 +577,8 @@ def stack_and_shard_samples(
  global_device_count: Number of global JAX devices.
  num_sc_per_device: Number of sparsecores per device.
  static_buffer_size: The static buffer size to use for the samples.
- Defaults to None, in which case an upper-bound for the buffer size
- will be automatically determined.
+ Defaults to None, in which case an upper-bound for the buffer size
+ will be automatically determined.

  Returns:
  The preprocessed inputs, and statistics useful for updating FeatureSpecs
@@ -579,6 +612,7 @@ def stack_and_shard_samples(
  )

  out: dict[str, ShardedCooMatrix] = {}
+ out_stats: dict[str, InputStatsPerTable] = {}
  tables_names = preprocessed_inputs.lhs_row_pointers.keys()
  for table_name in tables_names:
  shard_ends = preprocessed_inputs.lhs_row_pointers[table_name]
@@ -592,5 +626,17 @@ def stack_and_shard_samples(
  row_ids=preprocessed_inputs.lhs_sample_ids[table_name],
  values=preprocessed_inputs.lhs_gains[table_name],
  )
+ out_stats[table_name] = InputStatsPerTable(
+ max_ids_per_partition=np.max(
+ stats.max_ids_per_partition[table_name]
+ ),
+ max_unique_ids_per_partition=np.max(
+ stats.max_unique_ids_per_partition[table_name]
+ ),
+ required_buffer_size_per_device=np.max(
+ stats.required_buffer_size_per_sc[table_name]
+ )
+ * num_sc_per_device,
+ )

- return out, stats
+ return out, out_stats
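
Note: the new `InputStatsPerTable` tuple condenses the per-partition arrays returned by the underlying preprocessing into three per-table scalars. A rough sketch of that reduction, with invented partition arrays in place of the real `SparseDenseMatmulInputStats` values:

from typing import NamedTuple
import numpy as np

class InputStatsPerTable(NamedTuple):
    max_ids_per_partition: int
    max_unique_ids_per_partition: int
    required_buffer_size_per_device: int

num_sc_per_device = 4  # illustrative value

# Invented per-partition statistics for one stacked table.
max_ids = np.array([12, 31, 7, 18])
max_unique_ids = np.array([9, 22, 6, 14])
buffer_size_per_sc = np.array([256, 512, 256, 384])

stats = InputStatsPerTable(
    # Collapse per-partition arrays to a single per-table maximum.
    max_ids_per_partition=int(np.max(max_ids)),
    max_unique_ids_per_partition=int(np.max(max_unique_ids)),
    # Per-sparsecore buffer size scaled up to a per-device requirement.
    required_buffer_size_per_device=int(np.max(buffer_size_per_sc)) * num_sc_per_device,
)
print(stats)  # (31, 22, 2048)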

keras_rs/src/layers/embedding/tensorflow/config_conversion.py
@@ -53,7 +53,7 @@ OPTIMIZER_MAPPINGS = {
  # KerasRS to TensorFlow


- def translate_keras_rs_configuration(
+ def keras_to_tf_tpu_configuration(
  feature_configs: types.Nested[FeatureConfig],
  table_stacking: str | Sequence[str] | Sequence[Sequence[str]],
  num_replicas_in_sync: int,
@@ -66,14 +66,15 @@ def translate_keras_rs_configuration(
  Args:
  feature_configs: The nested Keras RS feature configs.
  table_stacking: The Keras RS table stacking.
+ num_replicas_in_sync: The number of replicas in sync from the strategy.

  Returns:
  A tuple containing the TensorFlow TPU feature configs and the TensorFlow
  TPU sparse core embedding config.
  """
- tables: dict[TableConfig, tf.tpu.experimental.embedding.TableConfig] = {}
+ tables: dict[int, tf.tpu.experimental.embedding.TableConfig] = {}
  feature_configs = keras.tree.map_structure(
- lambda f: translate_keras_rs_feature_config(
+ lambda f: keras_to_tf_tpu_feature_config(
  f, tables, num_replicas_in_sync
  ),
  feature_configs,
@@ -108,9 +109,9 @@ def translate_keras_rs_configuration(
  return feature_configs, sparse_core_embedding_config


- def translate_keras_rs_feature_config(
+ def keras_to_tf_tpu_feature_config(
  feature_config: FeatureConfig,
- tables: dict[TableConfig, tf.tpu.experimental.embedding.TableConfig],
+ tables: dict[int, tf.tpu.experimental.embedding.TableConfig],
  num_replicas_in_sync: int,
  ) -> tf.tpu.experimental.embedding.FeatureConfig:
  """Translates a Keras RS feature config to a TensorFlow TPU feature config.
@@ -120,7 +121,8 @@ def translate_keras_rs_feature_config(

  Args:
  feature_config: The Keras RS feature config to translate.
- tables: A mapping of KerasRS table configs to TF TPU table configs.
+ tables: A mapping of KerasRS table config ids to TF TPU table configs.
+ num_replicas_in_sync: The number of replicas in sync from the strategy.

  Returns:
  The TensorFlow TPU feature config.
@@ -131,10 +133,10 @@ def translate_keras_rs_feature_config(
  f"but got {num_replicas_in_sync}."
  )

- table = tables.get(feature_config.table, None)
+ table = tables.get(id(feature_config.table), None)
  if table is None:
- table = translate_keras_rs_table_config(feature_config.table)
- tables[feature_config.table] = table
+ table = keras_to_tf_tpu_table_config(feature_config.table)
+ tables[id(feature_config.table)] = table

  if len(feature_config.output_shape) < 2:
  raise ValueError(
@@ -168,7 +170,7 @@ def translate_keras_rs_feature_config(
  )


- def translate_keras_rs_table_config(
+ def keras_to_tf_tpu_table_config(
  table_config: TableConfig,
  ) -> tf.tpu.experimental.embedding.TableConfig:
  initializer = table_config.initializer
@@ -179,13 +181,13 @@ def translate_keras_rs_table_config(
  vocabulary_size=table_config.vocabulary_size,
  dim=table_config.embedding_dim,
  initializer=initializer,
- optimizer=translate_optimizer(table_config.optimizer),
+ optimizer=to_tf_tpu_optimizer(table_config.optimizer),
  combiner=table_config.combiner,
  name=table_config.name,
  )


- def translate_keras_optimizer(
+ def keras_to_tf_tpu_optimizer(
  optimizer: keras.optimizers.Optimizer,
  ) -> TfTpuOptimizer:
  """Translates a Keras optimizer to a TensorFlow TPU `_Optimizer`.
@@ -238,7 +240,12 @@ def translate_keras_optimizer(
  "Unsupported optimizer option `Optimizer.loss_scale_factor`."
  )

- optimizer_mapping = OPTIMIZER_MAPPINGS.get(type(optimizer), None)
+ optimizer_mapping = None
+ for optimizer_class, mapping in OPTIMIZER_MAPPINGS.items():
+ # Handle subclasses of the main optimizer class.
+ if isinstance(optimizer, optimizer_class):
+ optimizer_mapping = mapping
+ break
  if optimizer_mapping is None:
  raise ValueError(
  f"Unsupported optimizer type {type(optimizer)}. Optimizer must be "
@@ -258,7 +265,7 @@ def translate_keras_optimizer(
  return optimizer_mapping.tpu_optimizer_class(**tpu_optimizer_kwargs)


- def translate_optimizer(
+ def to_tf_tpu_optimizer(
  optimizer: str | keras.optimizers.Optimizer | TfTpuOptimizer | None,
  ) -> TfTpuOptimizer:
  """Translates a Keras optimizer into a TensorFlow TPU `_Optimizer`.
@@ -299,7 +306,7 @@ def translate_optimizer(
  "'sgd', 'adagrad', 'adam', or 'ftrl'"
  )
  elif isinstance(optimizer, keras.optimizers.Optimizer):
- return translate_keras_optimizer(optimizer)
+ return keras_to_tf_tpu_optimizer(optimizer)
  else:
  raise ValueError(
  f"Unknown optimizer type {type(optimizer)}. Please pass an "
@@ -312,7 +319,7 @@ def translate_optimizer(
  # TensorFlow to TensorFlow


- def clone_tf_feature_configs(
+ def clone_tf_tpu_feature_configs(
  feature_configs: types.Nested[tf.tpu.experimental.embedding.FeatureConfig],
  ) -> types.Nested[tf.tpu.experimental.embedding.FeatureConfig]:
  """Clones and resolves TensorFlow TPU feature configs.
@@ -327,7 +334,7 @@ def clone_tf_feature_configs(
  """
  table_configs_dict = {}

- def clone_and_resolve_tf_feature_config(
+ def clone_and_resolve_tf_tpu_feature_config(
  fc: tf.tpu.experimental.embedding.FeatureConfig,
  ) -> tf.tpu.experimental.embedding.FeatureConfig:
  if fc.table not in table_configs_dict:
@@ -336,7 +343,7 @@ def clone_tf_feature_configs(
  vocabulary_size=fc.table.vocabulary_size,
  dim=fc.table.dim,
  initializer=fc.table.initializer,
- optimizer=translate_optimizer(fc.table.optimizer),
+ optimizer=to_tf_tpu_optimizer(fc.table.optimizer),
  combiner=fc.table.combiner,
  name=fc.table.name,
  quantization_config=fc.table.quantization_config,
@@ -352,5 +359,5 @@ def clone_tf_feature_configs(
  )

  return keras.tree.map_structure(
- clone_and_resolve_tf_feature_config, feature_configs
+ clone_and_resolve_tf_tpu_feature_config, feature_configs
  )
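
Note: the optimizer lookup now walks `OPTIMIZER_MAPPINGS` with `isinstance` instead of an exact `type()` match, so subclasses of the supported Keras optimizers resolve to the same mapping. A minimal sketch of the difference, using a placeholder mapping dict and a hypothetical subclass rather than the real `OPTIMIZER_MAPPINGS`:

import keras

# Placeholder stand-in: base optimizer class -> mapping handle.
MAPPINGS = {
    keras.optimizers.SGD: "sgd_mapping",
    keras.optimizers.Adam: "adam_mapping",
}

class WarmupAdam(keras.optimizers.Adam):
    """A hypothetical Adam subclass a user might pass in."""

optimizer = WarmupAdam(learning_rate=1e-3)

# Old behaviour: exact type lookup misses subclasses.
print(MAPPINGS.get(type(optimizer)))  # None

# New behaviour: first isinstance match wins, so subclasses are handled.
mapping = None
for optimizer_class, candidate in MAPPINGS.items():
    if isinstance(optimizer, optimizer_class):
        mapping = candidate
        break
print(mapping)  # adam_mapping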

keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py
@@ -106,7 +106,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
  "for the configuration."
  )
  self._tpu_feature_configs, self._sparse_core_embedding_config = (
- config_conversion.translate_keras_rs_configuration(
+ config_conversion.keras_to_tf_tpu_configuration(
  feature_configs,
  table_stacking,
  strategy.num_replicas_in_sync,
@@ -135,10 +135,10 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
  "supported with this TPU generation."
  )
  self._tpu_feature_configs = (
- config_conversion.clone_tf_feature_configs(feature_configs)
+ config_conversion.clone_tf_tpu_feature_configs(feature_configs)
  )

- self._tpu_optimizer = config_conversion.translate_optimizer(
+ self._tpu_optimizer = config_conversion.to_tf_tpu_optimizer(
  self._optimizer
  )

@@ -281,8 +281,18 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
  def _sparsecore_get_embedding_tables(self) -> dict[str, types.Tensor]:
  tables: dict[str, types.Tensor] = {}
  strategy = tf.distribute.get_strategy()
- # 4 is the number of sparsecores per chip
- num_shards = strategy.num_replicas_in_sync * 4
+ if not self._is_tpu_strategy(strategy):
+ raise RuntimeError(
+ "`DistributedEmbedding.get_embedding_tables` needs to be "
+ "called under the TPUStrategy that DistributedEmbedding was "
+ f"created with, but is being called under strategy {strategy}. "
+ "Please use `with strategy.scope()` when calling "
+ "`get_embedding_tables`."
+ )
+
+ tpu_hardware = strategy.extended.tpu_hardware_feature
+ num_sc_per_device = tpu_hardware.num_embedding_devices_per_chip
+ num_shards = strategy.num_replicas_in_sync * num_sc_per_device

  def populate_table(
  feature_config: tf.tpu.experimental.embedding.FeatureConfig,
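
Note: the shard count for reassembling embedding tables is no longer hardcoded to 4 sparsecores per chip; it is read from the TPU hardware features reported by the strategy. A hedged sketch of that lookup as a standalone helper; it only produces a value when called with a real `tf.distribute.TPUStrategy` on TPU hardware, and the helper name is invented here:

import tensorflow as tf

def sparsecore_shard_count(strategy: tf.distribute.TPUStrategy) -> int:
    """Shard count: replicas in sync times sparsecores per chip."""
    if not isinstance(strategy, tf.distribute.TPUStrategy):
        raise RuntimeError(
            "Must be called under the TPUStrategy the layer was created with."
        )
    # Hardware feature reported by the TPU system, instead of a hardcoded 4.
    tpu_hardware = strategy.extended.tpu_hardware_feature
    num_sc_per_device = tpu_hardware.num_embedding_devices_per_chip
    return strategy.num_replicas_in_sync * num_sc_per_device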

keras_rs/src/version.py
@@ -1,7 +1,7 @@
  from keras_rs.src.api_export import keras_rs_export

  # Unique source of truth for the version number.
- __version__ = "0.2.2.dev202509030321"
+ __version__ = "0.2.2.dev202509170322"


  @keras_rs_export("keras_rs.version")

keras_rs_nightly.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: keras-rs-nightly
- Version: 0.2.2.dev202509030321
+ Version: 0.2.2.dev202509170322
  Summary: Multi-backend recommender systems with Keras 3.
  Author-email: Keras team <keras-users@googlegroups.com>
  License: Apache License 2.0
@@ -8,8 +8,9 @@ Project-URL: Home, https://keras.io/keras_rs
  Project-URL: Repository, https://github.com/keras-team/keras-rs
  Classifier: Development Status :: 3 - Alpha
  Classifier: Programming Language :: Python :: 3
- Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
  Classifier: Programming Language :: Python :: 3 :: Only
  Classifier: Operating System :: Unix
  Classifier: Operating System :: Microsoft :: Windows
@@ -17,7 +18,7 @@ Classifier: Operating System :: MacOS
  Classifier: Intended Audience :: Science/Research
  Classifier: Topic :: Scientific/Engineering
  Classifier: Topic :: Software Development
- Requires-Python: >=3.10
+ Requires-Python: >=3.11
  Description-Content-Type: text/markdown
  Requires-Dist: keras
  Requires-Dist: ml-dtypes

pyproject.toml
@@ -9,14 +9,15 @@ authors = [
  ]
  description = "Multi-backend recommender systems with Keras 3."
  readme = "README.md"
- requires-python = ">=3.10"
+ requires-python = ">=3.11"
  license = {text = "Apache License 2.0"}
  dynamic = ["version"]
  classifiers = [
  "Development Status :: 3 - Alpha",
  "Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.10",
  "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
  "Programming Language :: Python :: 3 :: Only",
  "Operating System :: Unix",
  "Operating System :: Microsoft :: Windows",