keras-rs-nightly 0.2.2.dev202506190335__tar.gz → 0.2.2.dev202506200334__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of keras-rs-nightly might be problematic. Click here for more details.
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/PKG-INFO +1 -1
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/embedding/jax/distributed_embedding.py +3 -3
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/embedding/jax/embedding_lookup.py +2 -2
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs_nightly.egg-info/PKG-INFO +1 -1
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/README.md +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/api/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/api/layers/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/api/losses/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/api/metrics/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/api_export.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/embedding/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/embedding/base_distributed_embedding.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/embedding/distributed_embedding.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/embedding/distributed_embedding_config.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/embedding/embed_reduce.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/embedding/jax/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/embedding/jax/checkpoint_utils.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/embedding/jax/config_conversion.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/embedding/jax/embedding_utils.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/embedding/tensorflow/config_conversion.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/feature_interaction/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/feature_interaction/dot_interaction.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/feature_interaction/feature_cross.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/retrieval/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/retrieval/brute_force_retrieval.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/retrieval/hard_negative_mining.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/retrieval/remove_accidental_hits.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/retrieval/retrieval.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/layers/retrieval/sampling_probability_correction.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/losses/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/losses/pairwise_hinge_loss.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/losses/pairwise_logistic_loss.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/losses/pairwise_loss.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/losses/pairwise_loss_utils.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/losses/pairwise_mean_squared_error.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/losses/pairwise_soft_zero_one_loss.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/metrics/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/metrics/dcg.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/metrics/mean_average_precision.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/metrics/mean_reciprocal_rank.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/metrics/ndcg.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/metrics/precision_at_k.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/metrics/ranking_metric.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/metrics/ranking_metrics_utils.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/metrics/recall_at_k.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/metrics/utils.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/types.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/utils/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/utils/doc_string_utils.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs/src/utils/keras_utils.py +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs_nightly.egg-info/SOURCES.txt +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs_nightly.egg-info/dependency_links.txt +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs_nightly.egg-info/requires.txt +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/keras_rs_nightly.egg-info/top_level.txt +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/pyproject.toml +0 -0
- {keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/setup.cfg +0 -0
|
@@ -36,7 +36,7 @@ shard_map = jax.experimental.shard_map.shard_map # type: ignore[attr-defined]
|
|
|
36
36
|
def _get_partition_spec(
|
|
37
37
|
layout: (
|
|
38
38
|
keras.distribution.TensorLayout
|
|
39
|
-
| jax_layout.
|
|
39
|
+
| jax_layout.Format
|
|
40
40
|
| jax.sharding.NamedSharding
|
|
41
41
|
| jax.sharding.PartitionSpec
|
|
42
42
|
),
|
|
@@ -45,7 +45,7 @@ def _get_partition_spec(
|
|
|
45
45
|
if isinstance(layout, keras.distribution.TensorLayout):
|
|
46
46
|
layout = layout.backend_layout
|
|
47
47
|
|
|
48
|
-
if isinstance(layout, jax_layout.
|
|
48
|
+
if isinstance(layout, jax_layout.Format):
|
|
49
49
|
layout = layout.sharding
|
|
50
50
|
|
|
51
51
|
if isinstance(layout, jax.sharding.NamedSharding):
|
|
@@ -217,7 +217,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
|
|
|
217
217
|
sparsecore_layout = keras.distribution.TensorLayout(axes, device_mesh)
|
|
218
218
|
# Custom sparsecore layout with tiling.
|
|
219
219
|
# pylint: disable-next=protected-access
|
|
220
|
-
sparsecore_layout._backend_layout = jax_layout.
|
|
220
|
+
sparsecore_layout._backend_layout = jax_layout.Format(
|
|
221
221
|
jax_layout.DeviceLocalLayout(
|
|
222
222
|
major_to_minor=(0, 1),
|
|
223
223
|
_tiling=((8,),),
|
|
@@ -8,7 +8,7 @@ from typing import Any, Mapping, TypeAlias
|
|
|
8
8
|
|
|
9
9
|
import jax
|
|
10
10
|
import numpy as np
|
|
11
|
-
from jax.experimental import layout
|
|
11
|
+
from jax.experimental import layout as jax_layout
|
|
12
12
|
from jax_tpu_embedding.sparsecore.lib.nn import embedding
|
|
13
13
|
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
|
|
14
14
|
from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
|
|
@@ -20,7 +20,7 @@ ShardedCooMatrix = embedding_utils.ShardedCooMatrix
|
|
|
20
20
|
shard_map = jax.experimental.shard_map.shard_map # type: ignore[attr-defined]
|
|
21
21
|
|
|
22
22
|
ArrayLike: TypeAlias = jax.Array | np.ndarray[Any, Any]
|
|
23
|
-
JaxLayout: TypeAlias = jax.sharding.NamedSharding |
|
|
23
|
+
JaxLayout: TypeAlias = jax.sharding.NamedSharding | jax_layout.Format
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
class EmbeddingLookupConfiguration:
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{keras_rs_nightly-0.2.2.dev202506190335 → keras_rs_nightly-0.2.2.dev202506200334}/pyproject.toml
RENAMED
|
File without changes
|
|
File without changes
|