keras-rs-nightly 0.3.1.dev202510220333__py3-none-any.whl → 0.3.1.dev202510240328__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of keras-rs-nightly might be problematic. See this release's page on the package registry for more details.

@@ -9,6 +9,7 @@ import keras
9
9
  import numpy as np
10
10
  from jax import numpy as jnp
11
11
  from jax.experimental import layout as jax_layout
12
+ from jax.experimental import multihost_utils
12
13
  from jax_tpu_embedding.sparsecore.lib.nn import embedding
13
14
  from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
14
15
  from jax_tpu_embedding.sparsecore.lib.nn import (
@@ -600,31 +601,26 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
600
601
  # underlying stacked tables specs in the feature specs.
601
602
 
602
603
  # Aggregate stats across all processes/devices via pmax.
603
- num_local_cpu_devices = jax.local_device_count("cpu")
604
-
605
- def pmax_aggregate(x: Any) -> Any:
606
- if not hasattr(x, "ndim"):
607
- x = np.array(x)
608
- tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim)))
609
- return jax.pmap(
610
- lambda y: jax.lax.pmax(y, "all_cpus"), # type: ignore[no-untyped-call]
611
- axis_name="all_cpus",
612
- backend="cpu",
613
- )(tiled_x)[0]
614
-
615
- full_stats = jax.tree.map(pmax_aggregate, stats)
604
+ all_stats = multihost_utils.process_allgather(stats)
605
+ aggregated_stats = jax.tree.map(
606
+ lambda x: jnp.max(x, axis=0), all_stats
607
+ )
616
608
 
617
609
  # Check if stats changed enough to warrant action.
618
610
  stacked_table_specs = embedding.get_stacked_table_specs(
619
611
  self._config.feature_specs
620
612
  )
621
613
  changed = any(
622
- np.max(full_stats.max_ids_per_partition[stack_name])
614
+ np.max(aggregated_stats.max_ids_per_partition[stack_name])
623
615
  > spec.max_ids_per_partition
624
- or np.max(full_stats.max_unique_ids_per_partition[stack_name])
616
+ or np.max(
617
+ aggregated_stats.max_unique_ids_per_partition[stack_name]
618
+ )
625
619
  > spec.max_unique_ids_per_partition
626
620
  or (
627
- np.max(full_stats.required_buffer_size_per_sc[stack_name])
621
+ np.max(
622
+ aggregated_stats.required_buffer_size_per_sc[stack_name]
623
+ )
628
624
  * num_sc_per_device
629
625
  )
630
626
  > (spec.suggested_coo_buffer_size_per_device or 0)
@@ -634,7 +630,9 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
634
630
  # Update configuration and repeat preprocessing if stats changed.
635
631
  if changed:
636
632
  embedding.update_preprocessing_parameters(
637
- self._config.feature_specs, full_stats, num_sc_per_device
633
+ self._config.feature_specs,
634
+ aggregated_stats,
635
+ num_sc_per_device,
638
636
  )
639
637
 
640
638
  # Re-execute preprocessing with consistent input statistics.
keras_rs/src/version.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from keras_rs.src.api_export import keras_rs_export
2
2
 
3
3
  # Unique source of truth for the version number.
4
- __version__ = "0.3.1.dev202510220333"
4
+ __version__ = "0.3.1.dev202510240328"
5
5
 
6
6
 
7
7
  @keras_rs_export("keras_rs.version")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: keras-rs-nightly
3
- Version: 0.3.1.dev202510220333
3
+ Version: 0.3.1.dev202510240328
4
4
  Summary: Multi-backend recommender systems with Keras 3.
5
5
  Author-email: Keras team <keras-users@googlegroups.com>
6
6
  License: Apache License 2.0
@@ -5,7 +5,7 @@ keras_rs/metrics/__init__.py,sha256=Qxpf6OFooIL9TIn2l3WgOea3HFRG0hq02glPAxtMZ9c,
5
5
  keras_rs/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  keras_rs/src/api_export.py,sha256=RsmG-DvO-cdFeAF9W6LRzms0kvtm-Yp9BAA_d-952zI,510
7
7
  keras_rs/src/types.py,sha256=1A-oLRdX1-f2DsVZBcNl8qNsaH8pM-gnleLT9FWZWBw,1189
8
- keras_rs/src/version.py,sha256=5K8isi168MtN0VSp5_c1VYmRMw8pLpswFkARLUS9Z8Q,224
8
+ keras_rs/src/version.py,sha256=PT5g4Jbeo7CzD3xiVH7sJhFwaI8vEYnZJmnZMryylNo,224
9
9
  keras_rs/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
10
  keras_rs/src/layers/embedding/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  keras_rs/src/layers/embedding/base_distributed_embedding.py,sha256=RkXZ6notj3Cq6ryR9w30Wb8UlaWjLcUK2Os9ZUQvuhY,45568
@@ -15,7 +15,7 @@ keras_rs/src/layers/embedding/embed_reduce.py,sha256=c-MnEw1-KWs0jTf0JJ_ZBOY-9hR
15
15
  keras_rs/src/layers/embedding/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  keras_rs/src/layers/embedding/jax/checkpoint_utils.py,sha256=wZ4I5WZVNg5WnrD2j7nhAXgLzDc7xMrUEkSAOx5Sz5c,3495
17
17
  keras_rs/src/layers/embedding/jax/config_conversion.py,sha256=Di1UzRwLgGHd7RuWYJMj2mCOr1u9MseFEWaYKnwD9Bs,16742
18
- keras_rs/src/layers/embedding/jax/distributed_embedding.py,sha256=l3gpFXBAdkJw7yVntl0s25excCfC5jryyqBxUKZd2Fk,29820
18
+ keras_rs/src/layers/embedding/jax/distributed_embedding.py,sha256=O3G0AFRzukYdXPRyx7ZDqDvNgJrcbFwTCYTHigfdiKw,29628
19
19
  keras_rs/src/layers/embedding/jax/embedding_lookup.py,sha256=8LigXjPr7uQaUOdZM6yoLGoPYdRcbkXkFeL_sJoQ6uQ,8223
20
20
  keras_rs/src/layers/embedding/jax/embedding_utils.py,sha256=slJ0XwkI1z4vTAnRXQwm39LFnK9AL3CODuGRn5BufgE,8292
21
21
  keras_rs/src/layers/embedding/tensorflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -50,7 +50,7 @@ keras_rs/src/metrics/utils.py,sha256=fGTo8j0ykVE5Y3yQCS2orSFcHY20Uxt0NazyPsybUsw
50
50
  keras_rs/src/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
51
51
  keras_rs/src/utils/doc_string_utils.py,sha256=CmqomepmaYcvpACpXEXkrJb8DMnvIgmYK-lJ53lYarY,1675
52
52
  keras_rs/src/utils/keras_utils.py,sha256=dc-NFzs3a-qmRw0vBDiMslPLfrm9yymGduLWesXPhuY,2123
53
- keras_rs_nightly-0.3.1.dev202510220333.dist-info/METADATA,sha256=xiV9gYIXS_a1sbBEFQwlfLn4mZ3GyAmfBFVMYRhC4Jc,5324
54
- keras_rs_nightly-0.3.1.dev202510220333.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
55
- keras_rs_nightly-0.3.1.dev202510220333.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
56
- keras_rs_nightly-0.3.1.dev202510220333.dist-info/RECORD,,
53
+ keras_rs_nightly-0.3.1.dev202510240328.dist-info/METADATA,sha256=Ezi7EQGTudw0xaxAsoHu4thA66EKIeXFNOBvHTBkH1I,5324
54
+ keras_rs_nightly-0.3.1.dev202510240328.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
55
+ keras_rs_nightly-0.3.1.dev202510240328.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
56
+ keras_rs_nightly-0.3.1.dev202510240328.dist-info/RECORD,,