keras-rs-nightly 0.0.1.dev2025050103__py3-none-any.whl → 0.2.2.dev202506100336__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of keras-rs-nightly might be problematic.
- keras_rs/layers/__init__.py +12 -0
- keras_rs/src/layers/embedding/__init__.py +0 -0
- keras_rs/src/layers/embedding/base_distributed_embedding.py +1124 -0
- keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
- keras_rs/src/layers/embedding/distributed_embedding_config.py +129 -0
- keras_rs/src/layers/embedding/embed_reduce.py +309 -0
- keras_rs/src/layers/embedding/jax/__init__.py +0 -0
- keras_rs/src/layers/embedding/jax/config_conversion.py +398 -0
- keras_rs/src/layers/embedding/jax/distributed_embedding.py +892 -0
- keras_rs/src/layers/embedding/jax/embedding_lookup.py +255 -0
- keras_rs/src/layers/embedding/jax/embedding_utils.py +596 -0
- keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
- keras_rs/src/layers/embedding/tensorflow/config_conversion.py +323 -0
- keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +424 -0
- keras_rs/src/layers/feature_interaction/dot_interaction.py +2 -2
- keras_rs/src/layers/feature_interaction/feature_cross.py +14 -16
- keras_rs/src/layers/retrieval/brute_force_retrieval.py +5 -5
- keras_rs/src/layers/retrieval/retrieval.py +4 -4
- keras_rs/src/losses/pairwise_loss.py +2 -2
- keras_rs/src/losses/pairwise_mean_squared_error.py +1 -3
- keras_rs/src/metrics/dcg.py +2 -2
- keras_rs/src/metrics/ndcg.py +2 -2
- keras_rs/src/metrics/ranking_metric.py +4 -4
- keras_rs/src/metrics/ranking_metrics_utils.py +8 -8
- keras_rs/src/metrics/utils.py +2 -4
- keras_rs/src/types.py +43 -14
- keras_rs/src/utils/keras_utils.py +26 -6
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.0.1.dev2025050103.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/METADATA +6 -3
- keras_rs_nightly-0.2.2.dev202506100336.dist-info/RECORD +55 -0
- {keras_rs_nightly-0.0.1.dev2025050103.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/WHEEL +1 -1
- keras_rs_nightly-0.0.1.dev2025050103.dist-info/RECORD +0 -42
- {keras_rs_nightly-0.0.1.dev2025050103.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/top_level.txt +0 -0
keras_rs/src/layers/embedding/jax/embedding_lookup.py
@@ -0,0 +1,255 @@
"""Defines a differentiable `embedding_lookup` function.

Implementation details for use in JAX models.
"""

import functools
from typing import Any, Mapping, TypeAlias

import jax
import numpy as np
from jax.experimental import layout
from jax_tpu_embedding.sparsecore.lib.nn import embedding
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
from jax_tpu_embedding.sparsecore.utils import utils as jte_utils

from keras_rs.src.layers.embedding.jax import embedding_utils
from keras_rs.src.types import Nested

ShardedCooMatrix = embedding_utils.ShardedCooMatrix
shard_map = jax.experimental.shard_map.shard_map  # type: ignore[attr-defined]

ArrayLike: TypeAlias = jax.Array | np.ndarray[Any, Any]
JaxLayout: TypeAlias = jax.sharding.NamedSharding | layout.Layout


class EmbeddingLookupConfiguration:
    """Feature, mesh and sharding configuration for lookups."""

    mesh: jax.sharding.Mesh
    feature_specs: embedding.Nested[embedding_spec.FeatureSpec]
    table_sharding_strategy: str
    num_sc_per_device: int
    samples_partition: jax.sharding.PartitionSpec
    table_partition: jax.sharding.PartitionSpec
    table_layout: JaxLayout

    def __init__(
        self,
        feature_specs: embedding.Nested[embedding_spec.FeatureSpec],
        mesh: jax.sharding.Mesh | None = None,
        table_sharding_strategy: str = "MOD",
        num_sc_per_device: int | None = None,
        sharding_axis: str = "sparsecore_sharding",
        samples_partition: jax.sharding.PartitionSpec | None = None,
        samples_layout: JaxLayout | None = None,
        table_partition: jax.sharding.PartitionSpec | None = None,
        table_layout: JaxLayout | None = None,
    ):
        self.mesh = mesh or jax.sharding.Mesh(jax.devices(), sharding_axis)
        self.feature_specs = feature_specs
        self.table_sharding_strategy = table_sharding_strategy
        self.num_sc_per_device = (
            num_sc_per_device
            if num_sc_per_device is not None
            else jte_utils.num_sparsecores_per_device()
        )
        self.samples_partition = (
            samples_partition
            or jax.sharding.PartitionSpec(
                sharding_axis  # type: ignore[no-untyped-call]
            )
        )
        self.samples_layout = samples_layout or jax.sharding.NamedSharding(
            self.mesh, self.samples_partition
        )
        self.table_partition = table_partition or jax.sharding.PartitionSpec(
            sharding_axis,
            None,  # type: ignore[no-untyped-call]
        )
        self.table_layout = table_layout or jax.sharding.NamedSharding(
            self.mesh, self.table_partition
        )
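The constructor above mostly resolves defaults. As a rough sketch of what an all-defaults instance ends up holding (illustrative only, not part of the package; `my_feature_specs` is a hypothetical placeholder for the nested `embedding_spec.FeatureSpec` structure the caller builds elsewhere):

# Illustrative sketch only; `my_feature_specs` stands in for real FeatureSpec objects.
config = EmbeddingLookupConfiguration(my_feature_specs)

# With no explicit mesh or partitions, all local devices form a one-axis
# mesh named "sparsecore_sharding"; samples are sharded along that axis
# and tables are sharded on their first dimension:
#   config.mesh              -> jax.sharding.Mesh(jax.devices(), "sparsecore_sharding")
#   config.samples_partition -> PartitionSpec("sparsecore_sharding")
#   config.table_partition   -> PartitionSpec("sparsecore_sharding", None)
#   config.num_sc_per_device -> jte_utils.num_sparsecores_per_device()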
# Embedding lookup function with custom gradient.
@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
def embedding_lookup(
    config: EmbeddingLookupConfiguration,
    lookups: Mapping[str, ShardedCooMatrix],
    tables: Nested[jax.Array],
    step: jax.Array | None = None,
) -> Nested[jax.Array]:
    """Embedding lookup function with custom gradient.

    Args:
        config: Embedding lookup configuration.
        lookups: Embedding lookup stacked/sharded inputs.
        tables: Embedding lookup stacked/sharded tables.
        step: Current training step number.
    """
    del step  # Only used in backward pass.

    # Decompose COO matrices.
    row_pointers_raw = {}
    embedding_ids_raw = {}
    sample_ids_raw = {}
    gains_raw = {}
    for table_name, coo in lookups.items():
        row_pointers_raw[table_name] = coo.shard_ends
        embedding_ids_raw[table_name] = coo.col_ids
        sample_ids_raw[table_name] = coo.row_ids
        gains_raw[table_name] = coo.values

    sparse_dense_matmul_input = embedding.SparseDenseMatmulInput(
        row_pointers_raw,
        embedding_ids_raw,
        sample_ids_raw,
        gains_raw,
    )

    pd = config.samples_partition
    pt = config.table_partition
    sharded_matmul = jax.jit(
        shard_map(
            functools.partial(
                embedding.tpu_sparse_dense_matmul,
                global_device_count=config.mesh.shape[pd[0]],
                feature_specs=config.feature_specs,
                sharding_strategy=config.table_sharding_strategy,
            ),
            mesh=config.mesh,
            in_specs=(pd, pt),
            out_specs=pd,
            check_rep=False,
        ),
    )

    activations: Nested[jax.Array] = sharded_matmul(
        sparse_dense_matmul_input,
        tables,
    )

    return activations
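Both shard_map calls in this module follow the same pattern: inputs and activations are split across devices along the samples partition `pd`, tables along the table partition `pt`, and each device runs the kernel on its local shards. A standalone mental-model sketch of those `in_specs`/`out_specs` semantics (illustrative only, unrelated to the package; assumes the leading dimension divides the local device count):

# Illustrative sketch only, not part of the package file.
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map

mesh = jax.sharding.Mesh(jax.devices(), "x")
pd = jax.sharding.PartitionSpec("x")  # rows split across devices, like samples_partition

# Inside shard_map, each device sees only its row-shard of the input and
# produces the matching row-shard of the output (in_specs / out_specs = pd).
double_rows = jax.jit(
    shard_map(lambda s: s * 2.0, mesh=mesh, in_specs=(pd,), out_specs=pd)
)
x = jnp.arange(2.0 * jax.device_count()).reshape(-1, 1)
y = double_rows(x)  # doubled values; output keeps the same row sharding as x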
def embedding_lookup_fwd(
    config: EmbeddingLookupConfiguration,
    lookups: Mapping[str, ShardedCooMatrix],
    table: Nested[jax.Array],
    step: jax.Array | None = None,
) -> tuple[
    Nested[jax.Array],
    tuple[Nested[ShardedCooMatrix], Nested[jax.Array], jax.Array | None],
]:
    """Forward pass for embedding lookup."""
    return embedding_lookup(config, lookups, table, step), (
        lookups,
        table,
        step,
    )


def embedding_lookup_bwd(
    config: EmbeddingLookupConfiguration,
    res: tuple[
        Mapping[str, ShardedCooMatrix],  # Lookups.
        Mapping[str, Nested[jax.Array]],  # Tables.
        jax.Array | None,  # Step.
    ],
    gradients: Nested[jax.Array],
) -> tuple[None, Nested[jax.Array], jax.Array | None]:
    """Backward pass for embedding lookup.

    Args:
        config: Embedding lookup configuration.
        res: Tuple of embedding lookup (inputs, tables, step).
        gradients: Embedding lookup gradients.

    Returns:
        A tuple of gradients (None, table_grads, step + 1).
    """
    lookups, tables, step = res

    # Decompose COO matrices.
    row_pointers_raw = {}
    embedding_ids_raw = {}
    sample_ids_raw = {}
    gains_raw = {}
    for table_name, coo in lookups.items():
        row_pointers_raw[table_name] = coo.shard_ends
        embedding_ids_raw[table_name] = coo.col_ids
        sample_ids_raw[table_name] = coo.row_ids
        gains_raw[table_name] = coo.values

    sparse_dense_matmul_input = embedding.SparseDenseMatmulInput(
        row_pointers_raw,
        embedding_ids_raw,
        sample_ids_raw,
        gains_raw,
    )

    pt = config.table_partition
    pd = config.samples_partition
    # Replicate step count.
    preplicate = jax.sharding.PartitionSpec()  # type: ignore[no-untyped-call]

    def grad_func(
        gradients: Nested[jax.Array],
        sparse_input: embedding.SparseDenseMatmulInput,
        tables: Mapping[str, embedding.EmbeddingVariables],
        step: jax.Array | None,
    ) -> Mapping[str, embedding.EmbeddingVariables]:
        output: Mapping[str, embedding.EmbeddingVariables] = (
            embedding.tpu_sparse_dense_matmul_grad(
                gradients,
                sparse_input,
                tables,
                feature_specs=config.feature_specs,
                sharding_strategy=config.table_sharding_strategy,
                step=step,
            )
        )
        return output

    # activation_layout = jax.sharding.NamedSharding(config.mesh, pd)
    # step_layout = jax.sharding.NamedSharding(config.mesh, preplicate)
    sharded_matmul_grad = jax.jit(
        shard_map(
            grad_func,
            mesh=config.mesh,
            in_specs=(pd, pd, pt, preplicate),
            out_specs=pt,
            check_rep=False,
        ),
        # in_shardings=(
        #     activation_layout,
        #     config.samples_layout,
        #     config.table_layout,
        #     step_layout,
        # ),
        # out_shardings=(config.table_layout),
    )

    table_grads = sharded_matmul_grad(
        gradients,
        sparse_dense_matmul_input,
        tables,
        step,
    )

    # tpu_sparse_dense_matmul_grad returns a general Mapping (usually a dict).
    # It may not be the same type as the embedding table (e.g. FrozenDict).
    # Here we use flatten / unflatten to ensure the types are the same.
    table_grads = jax.tree.unflatten(
        jax.tree.structure(tables), jax.tree.leaves(table_grads)
    )

    return (
        None,
        table_grads,
        step + 1 if step is not None else None,  # Incremented step count.
    )


embedding_lookup.defvjp(embedding_lookup_fwd, embedding_lookup_bwd)
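The `defvjp` call at the end registers the forward and backward passes, so `embedding_lookup` can sit inside a differentiated training step. A minimal sketch of how a caller might drive it (illustrative only; `config`, `lookups`, `tables`, and `step` are hypothetical placeholders, built in the package by the JAX `DistributedEmbedding` layer). Per `embedding_lookup_bwd`, differentiation returns nothing for the lookups, whatever `tpu_sparse_dense_matmul_grad` produces for the tables, and an incremented step:

# Illustrative sketch only, not part of the package file; `config`, `lookups`,
# `tables`, and `step` are hypothetical placeholders built by the caller.
import jax
import jax.numpy as jnp

def loss_fn(tables):
    activations = embedding_lookup(config, lookups, tables, step)
    # Toy scalar loss over all activation leaves.
    return sum(jnp.sum(a) for a in jax.tree.leaves(activations))

# jax.grad triggers embedding_lookup_fwd / embedding_lookup_bwd; `config` is
# excluded from differentiation via nondiff_argnums=(0,).
table_updates = jax.grad(loss_fn)(tables)  # what the backward pass returns for the tables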