keras-rs-nightly 0.3.1.dev202510220333__tar.gz → 0.3.1.dev202510240328__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of keras-rs-nightly might be problematic. Click here for more details.
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/PKG-INFO +1 -1
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/jax/distributed_embedding.py +15 -17
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs_nightly.egg-info/PKG-INFO +1 -1
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/README.md +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/api/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/api/layers/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/api/losses/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/api/metrics/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/api_export.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/base_distributed_embedding.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/distributed_embedding.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/distributed_embedding_config.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/embed_reduce.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/jax/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/jax/checkpoint_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/jax/config_conversion.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/jax/embedding_lookup.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/jax/embedding_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/tensorflow/config_conversion.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/feature_interaction/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/feature_interaction/dot_interaction.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/feature_interaction/feature_cross.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/retrieval/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/retrieval/brute_force_retrieval.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/retrieval/hard_negative_mining.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/retrieval/remove_accidental_hits.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/retrieval/retrieval.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/layers/retrieval/sampling_probability_correction.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/losses/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/losses/pairwise_hinge_loss.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/losses/pairwise_logistic_loss.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/losses/pairwise_loss.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/losses/pairwise_loss_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/losses/pairwise_mean_squared_error.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/losses/pairwise_soft_zero_one_loss.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/metrics/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/metrics/dcg.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/metrics/mean_average_precision.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/metrics/mean_reciprocal_rank.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/metrics/ndcg.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/metrics/precision_at_k.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/metrics/ranking_metric.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/metrics/ranking_metrics_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/metrics/recall_at_k.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/metrics/utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/types.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/utils/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/utils/doc_string_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs/src/utils/keras_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs_nightly.egg-info/SOURCES.txt +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs_nightly.egg-info/dependency_links.txt +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs_nightly.egg-info/requires.txt +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/keras_rs_nightly.egg-info/top_level.txt +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/pyproject.toml +0 -0
- {keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/setup.cfg +0 -0
|
@@ -9,6 +9,7 @@ import keras
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
from jax import numpy as jnp
|
|
11
11
|
from jax.experimental import layout as jax_layout
|
|
12
|
+
from jax.experimental import multihost_utils
|
|
12
13
|
from jax_tpu_embedding.sparsecore.lib.nn import embedding
|
|
13
14
|
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
|
|
14
15
|
from jax_tpu_embedding.sparsecore.lib.nn import (
|
|
@@ -600,31 +601,26 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
|
|
|
600
601
|
# underlying stacked tables specs in the feature specs.
|
|
601
602
|
|
|
602
603
|
# Aggregate stats across all processes/devices via pmax.
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
x = np.array(x)
|
|
608
|
-
tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim)))
|
|
609
|
-
return jax.pmap(
|
|
610
|
-
lambda y: jax.lax.pmax(y, "all_cpus"), # type: ignore[no-untyped-call]
|
|
611
|
-
axis_name="all_cpus",
|
|
612
|
-
backend="cpu",
|
|
613
|
-
)(tiled_x)[0]
|
|
614
|
-
|
|
615
|
-
full_stats = jax.tree.map(pmax_aggregate, stats)
|
|
604
|
+
all_stats = multihost_utils.process_allgather(stats)
|
|
605
|
+
aggregated_stats = jax.tree.map(
|
|
606
|
+
lambda x: jnp.max(x, axis=0), all_stats
|
|
607
|
+
)
|
|
616
608
|
|
|
617
609
|
# Check if stats changed enough to warrant action.
|
|
618
610
|
stacked_table_specs = embedding.get_stacked_table_specs(
|
|
619
611
|
self._config.feature_specs
|
|
620
612
|
)
|
|
621
613
|
changed = any(
|
|
622
|
-
np.max(
|
|
614
|
+
np.max(aggregated_stats.max_ids_per_partition[stack_name])
|
|
623
615
|
> spec.max_ids_per_partition
|
|
624
|
-
or np.max(
|
|
616
|
+
or np.max(
|
|
617
|
+
aggregated_stats.max_unique_ids_per_partition[stack_name]
|
|
618
|
+
)
|
|
625
619
|
> spec.max_unique_ids_per_partition
|
|
626
620
|
or (
|
|
627
|
-
np.max(
|
|
621
|
+
np.max(
|
|
622
|
+
aggregated_stats.required_buffer_size_per_sc[stack_name]
|
|
623
|
+
)
|
|
628
624
|
* num_sc_per_device
|
|
629
625
|
)
|
|
630
626
|
> (spec.suggested_coo_buffer_size_per_device or 0)
|
|
@@ -634,7 +630,9 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
|
|
|
634
630
|
# Update configuration and repeat preprocessing if stats changed.
|
|
635
631
|
if changed:
|
|
636
632
|
embedding.update_preprocessing_parameters(
|
|
637
|
-
self._config.feature_specs,
|
|
633
|
+
self._config.feature_specs,
|
|
634
|
+
aggregated_stats,
|
|
635
|
+
num_sc_per_device,
|
|
638
636
|
)
|
|
639
637
|
|
|
640
638
|
# Re-execute preprocessing with consistent input statistics.
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{keras_rs_nightly-0.3.1.dev202510220333 → keras_rs_nightly-0.3.1.dev202510240328}/pyproject.toml
RENAMED
|
File without changes
|
|
File without changes
|