keras-rs-nightly 0.3.1.dev202510280332__tar.gz → 0.3.1.dev202510300334__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of keras-rs-nightly might be problematic. Click here for more details.
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/PKG-INFO +1 -1
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/embedding/jax/distributed_embedding.py +25 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs_nightly.egg-info/PKG-INFO +1 -1
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/README.md +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/api/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/api/layers/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/api/losses/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/api/metrics/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/api_export.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/embedding/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/embedding/base_distributed_embedding.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/embedding/distributed_embedding.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/embedding/distributed_embedding_config.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/embedding/embed_reduce.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/embedding/jax/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/embedding/jax/checkpoint_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/embedding/jax/config_conversion.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/embedding/jax/embedding_lookup.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/embedding/jax/embedding_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/embedding/tensorflow/config_conversion.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/feature_interaction/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/feature_interaction/dot_interaction.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/feature_interaction/feature_cross.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/retrieval/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/retrieval/brute_force_retrieval.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/retrieval/hard_negative_mining.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/retrieval/remove_accidental_hits.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/retrieval/retrieval.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/layers/retrieval/sampling_probability_correction.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/losses/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/losses/pairwise_hinge_loss.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/losses/pairwise_logistic_loss.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/losses/pairwise_loss.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/losses/pairwise_loss_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/losses/pairwise_mean_squared_error.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/losses/pairwise_soft_zero_one_loss.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/metrics/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/metrics/dcg.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/metrics/mean_average_precision.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/metrics/mean_reciprocal_rank.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/metrics/ndcg.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/metrics/precision_at_k.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/metrics/ranking_metric.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/metrics/ranking_metrics_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/metrics/recall_at_k.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/metrics/utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/types.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/utils/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/utils/doc_string_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs/src/utils/keras_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs_nightly.egg-info/SOURCES.txt +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs_nightly.egg-info/dependency_links.txt +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs_nightly.egg-info/requires.txt +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/keras_rs_nightly.egg-info/top_level.txt +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/pyproject.toml +0 -0
- {keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/setup.cfg +0 -0
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""JAX implementation of the TPU embedding layer."""
|
|
2
2
|
|
|
3
|
+
import dataclasses
|
|
3
4
|
import math
|
|
4
5
|
import typing
|
|
5
6
|
from typing import Any, Mapping, Sequence, Union
|
|
@@ -445,6 +446,30 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
|
|
|
445
446
|
table_specs = embedding.get_table_specs(feature_specs)
|
|
446
447
|
table_stacks = jte_table_stacking.get_table_stacks(table_specs)
|
|
447
448
|
|
|
449
|
+
# Create new instances of StackTableSpec with updated values that are
|
|
450
|
+
# the maximum from stacked tables.
|
|
451
|
+
stacked_table_specs = embedding.get_stacked_table_specs(feature_specs)
|
|
452
|
+
stacked_table_specs = {
|
|
453
|
+
stack_name: dataclasses.replace(
|
|
454
|
+
stacked_table_spec,
|
|
455
|
+
max_ids_per_partition=max(
|
|
456
|
+
table.max_ids_per_partition
|
|
457
|
+
for table in table_stacks[stack_name]
|
|
458
|
+
),
|
|
459
|
+
max_unique_ids_per_partition=max(
|
|
460
|
+
table.max_unique_ids_per_partition
|
|
461
|
+
for table in table_stacks[stack_name]
|
|
462
|
+
),
|
|
463
|
+
)
|
|
464
|
+
for stack_name, stacked_table_spec in stacked_table_specs.items()
|
|
465
|
+
}
|
|
466
|
+
|
|
467
|
+
# Rewrite the stacked_table_spec in all TableSpecs.
|
|
468
|
+
for stack_name, table_specs in table_stacks.items():
|
|
469
|
+
stacked_table_spec = stacked_table_specs[stack_name]
|
|
470
|
+
for table_spec in table_specs:
|
|
471
|
+
table_spec.stacked_table_spec = stacked_table_spec
|
|
472
|
+
|
|
448
473
|
# Create variables for all stacked tables and slot variables.
|
|
449
474
|
with sparsecore_distribution.scope():
|
|
450
475
|
self._table_and_slot_variables = {
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{keras_rs_nightly-0.3.1.dev202510280332 → keras_rs_nightly-0.3.1.dev202510300334}/pyproject.toml
RENAMED
|
File without changes
|
|
File without changes
|