keras-rs-nightly 0.3.1.dev202510220333__tar.gz → 0.3.1.dev202510240328__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of keras-rs-nightly might be problematic; see the release's page on the package registry for more details.

Files changed (61)
  1. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/PKG-INFO +1 -1
  2. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/jax/distributed_embedding.py +15 -17
  3. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/version.py +1 -1
  4. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs_nightly.egg-info/PKG-INFO +1 -1
  5. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/README.md +0 -0
  6. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/api/__init__.py +0 -0
  7. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/api/layers/__init__.py +0 -0
  8. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/api/losses/__init__.py +0 -0
  9. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/api/metrics/__init__.py +0 -0
  10. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/__init__.py +0 -0
  11. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/api_export.py +0 -0
  12. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/__init__.py +0 -0
  13. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/__init__.py +0 -0
  14. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/base_distributed_embedding.py +0 -0
  15. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/distributed_embedding.py +0 -0
  16. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/distributed_embedding_config.py +0 -0
  17. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/embed_reduce.py +0 -0
  18. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  19. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/jax/checkpoint_utils.py +0 -0
  20. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/jax/config_conversion.py +0 -0
  21. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/jax/embedding_lookup.py +0 -0
  22. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/jax/embedding_utils.py +0 -0
  23. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  24. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/tensorflow/config_conversion.py +0 -0
  25. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +0 -0
  26. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/feature_interaction/__init__.py +0 -0
  27. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/feature_interaction/dot_interaction.py +0 -0
  28. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/feature_interaction/feature_cross.py +0 -0
  29. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/retrieval/__init__.py +0 -0
  30. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/retrieval/brute_force_retrieval.py +0 -0
  31. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/retrieval/hard_negative_mining.py +0 -0
  32. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/retrieval/remove_accidental_hits.py +0 -0
  33. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/retrieval/retrieval.py +0 -0
  34. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/retrieval/sampling_probability_correction.py +0 -0
  35. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/losses/__init__.py +0 -0
  36. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/losses/pairwise_hinge_loss.py +0 -0
  37. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/losses/pairwise_logistic_loss.py +0 -0
  38. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/losses/pairwise_loss.py +0 -0
  39. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/losses/pairwise_loss_utils.py +0 -0
  40. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/losses/pairwise_mean_squared_error.py +0 -0
  41. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/losses/pairwise_soft_zero_one_loss.py +0 -0
  42. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/metrics/__init__.py +0 -0
  43. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/metrics/dcg.py +0 -0
  44. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/metrics/mean_average_precision.py +0 -0
  45. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/metrics/mean_reciprocal_rank.py +0 -0
  46. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/metrics/ndcg.py +0 -0
  47. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/metrics/precision_at_k.py +0 -0
  48. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/metrics/ranking_metric.py +0 -0
  49. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/metrics/ranking_metrics_utils.py +0 -0
  50. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/metrics/recall_at_k.py +0 -0
  51. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/metrics/utils.py +0 -0
  52. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/types.py +0 -0
  53. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/utils/__init__.py +0 -0
  54. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/utils/doc_string_utils.py +0 -0
  55. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/utils/keras_utils.py +0 -0
  56. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs_nightly.egg-info/SOURCES.txt +0 -0
  57. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs_nightly.egg-info/dependency_links.txt +0 -0
  58. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs_nightly.egg-info/requires.txt +0 -0
  59. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs_nightly.egg-info/top_level.txt +0 -0
  60. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/pyproject.toml +0 -0
  61. {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: keras-rs-nightly
3
- Version: 0.3.1.dev202510220333
3
+ Version: 0.3.1.dev202510240328
4
4
  Summary: Multi-backend recommender systems with Keras 3.
5
5
  Author-email: Keras team <keras-users@googlegroups.com>
6
6
  License: Apache License 2.0
@@ -9,6 +9,7 @@ import keras
9
9
  import numpy as np
10
10
  from jax import numpy as jnp
11
11
  from jax.experimental import layout as jax_layout
12
+ from jax.experimental import multihost_utils
12
13
  from jax_tpu_embedding.sparsecore.lib.nn import embedding
13
14
  from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
14
15
  from jax_tpu_embedding.sparsecore.lib.nn import (
@@ -600,31 +601,26 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
600
601
  # underlying stacked tables specs in the feature specs.
601
602
 
602
603
  # Aggregate stats across all processes/devices via pmax.
603
- num_local_cpu_devices = jax.local_device_count("cpu")
604
-
605
- def pmax_aggregate(x: Any) -> Any:
606
- if not hasattr(x, "ndim"):
607
- x = np.array(x)
608
- tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim)))
609
- return jax.pmap(
610
- lambda y: jax.lax.pmax(y, "all_cpus"), # type: ignore[no-untyped-call]
611
- axis_name="all_cpus",
612
- backend="cpu",
613
- )(tiled_x)[0]
614
-
615
- full_stats = jax.tree.map(pmax_aggregate, stats)
604
+ all_stats = multihost_utils.process_allgather(stats)
605
+ aggregated_stats = jax.tree.map(
606
+ lambda x: jnp.max(x, axis=0), all_stats
607
+ )
616
608
 
617
609
  # Check if stats changed enough to warrant action.
618
610
  stacked_table_specs = embedding.get_stacked_table_specs(
619
611
  self._config.feature_specs
620
612
  )
621
613
  changed = any(
622
- np.max(full_stats.max_ids_per_partition[stack_name])
614
+ np.max(aggregated_stats.max_ids_per_partition[stack_name])
623
615
  > spec.max_ids_per_partition
624
- or np.max(full_stats.max_unique_ids_per_partition[stack_name])
616
+ or np.max(
617
+ aggregated_stats.max_unique_ids_per_partition[stack_name]
618
+ )
625
619
  > spec.max_unique_ids_per_partition
626
620
  or (
627
- np.max(full_stats.required_buffer_size_per_sc[stack_name])
621
+ np.max(
622
+ aggregated_stats.required_buffer_size_per_sc[stack_name]
623
+ )
628
624
  * num_sc_per_device
629
625
  )
630
626
  > (spec.suggested_coo_buffer_size_per_device or 0)
@@ -634,7 +630,9 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
634
630
  # Update configuration and repeat preprocessing if stats changed.
635
631
  if changed:
636
632
  embedding.update_preprocessing_parameters(
637
- self._config.feature_specs, full_stats, num_sc_per_device
633
+ self._config.feature_specs,
634
+ aggregated_stats,
635
+ num_sc_per_device,
638
636
  )
639
637
 
640
638
  # Re-execute preprocessing with consistent input statistics.
@@ -1,7 +1,7 @@
1
1
  from keras_rs.src.api_export import keras_rs_export
2
2
 
3
3
  # Unique source of truth for the version number.
4
- __version__ = "0.3.1.dev202510220333"
4
+ __version__ = "0.3.1.dev202510240328"
5
5
 
6
6
 
7
7
  @keras_rs_export("keras_rs.version")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: keras-rs-nightly
3
- Version: 0.3.1.dev202510220333
3
+ Version: 0.3.1.dev202510240328
4
4
  Summary: Multi-backend recommender systems with Keras 3.
5
5
  Author-email: Keras team <keras-users@googlegroups.com>
6
6
  License: Apache License 2.0