keras-rs-nightly 0.0.1.dev2025043003__py3-none-any.whl → 0.2.2.dev202506100336__py3-none-any.whl

This diff shows the changes between publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their public registry.

Potentially problematic release.


This version of keras-rs-nightly might be problematic.

Files changed (37)
  1. keras_rs/layers/__init__.py +12 -0
  2. keras_rs/src/layers/embedding/__init__.py +0 -0
  3. keras_rs/src/layers/embedding/base_distributed_embedding.py +1124 -0
  4. keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
  5. keras_rs/src/layers/embedding/distributed_embedding_config.py +129 -0
  6. keras_rs/src/layers/embedding/embed_reduce.py +309 -0
  7. keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  8. keras_rs/src/layers/embedding/jax/config_conversion.py +398 -0
  9. keras_rs/src/layers/embedding/jax/distributed_embedding.py +892 -0
  10. keras_rs/src/layers/embedding/jax/embedding_lookup.py +255 -0
  11. keras_rs/src/layers/embedding/jax/embedding_utils.py +596 -0
  12. keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  13. keras_rs/src/layers/embedding/tensorflow/config_conversion.py +323 -0
  14. keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +424 -0
  15. keras_rs/src/layers/feature_interaction/dot_interaction.py +2 -2
  16. keras_rs/src/layers/feature_interaction/feature_cross.py +14 -16
  17. keras_rs/src/layers/retrieval/brute_force_retrieval.py +5 -5
  18. keras_rs/src/layers/retrieval/retrieval.py +4 -4
  19. keras_rs/src/losses/pairwise_loss.py +2 -2
  20. keras_rs/src/losses/pairwise_mean_squared_error.py +1 -3
  21. keras_rs/src/metrics/dcg.py +2 -2
  22. keras_rs/src/metrics/mean_average_precision.py +1 -1
  23. keras_rs/src/metrics/mean_reciprocal_rank.py +4 -4
  24. keras_rs/src/metrics/ndcg.py +2 -2
  25. keras_rs/src/metrics/precision_at_k.py +3 -3
  26. keras_rs/src/metrics/ranking_metric.py +11 -5
  27. keras_rs/src/metrics/ranking_metrics_utils.py +10 -10
  28. keras_rs/src/metrics/recall_at_k.py +2 -2
  29. keras_rs/src/metrics/utils.py +2 -4
  30. keras_rs/src/types.py +43 -14
  31. keras_rs/src/utils/keras_utils.py +26 -6
  32. keras_rs/src/version.py +1 -1
  33. {keras_rs_nightly-0.0.1.dev2025043003.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/METADATA +6 -3
  34. keras_rs_nightly-0.2.2.dev202506100336.dist-info/RECORD +55 -0
  35. {keras_rs_nightly-0.0.1.dev2025043003.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/WHEEL +1 -1
  36. keras_rs_nightly-0.0.1.dev2025043003.dist-info/RECORD +0 -42
  37. {keras_rs_nightly-0.0.1.dev2025043003.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/top_level.txt +0 -0
keras_rs/src/layers/embedding/jax/embedding_lookup.py
@@ -0,0 +1,255 @@
+"""Defines a differentiable `embedding_lookup` function.
+
+Implementation details for use in JAX models.
+"""
+
+import functools
+from typing import Any, Mapping, TypeAlias
+
+import jax
+import numpy as np
+from jax.experimental import layout
+from jax_tpu_embedding.sparsecore.lib.nn import embedding
+from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
+from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
+
+from keras_rs.src.layers.embedding.jax import embedding_utils
+from keras_rs.src.types import Nested
+
+ShardedCooMatrix = embedding_utils.ShardedCooMatrix
+shard_map = jax.experimental.shard_map.shard_map  # type: ignore[attr-defined]
+
+ArrayLike: TypeAlias = jax.Array | np.ndarray[Any, Any]
+JaxLayout: TypeAlias = jax.sharding.NamedSharding | layout.Layout
+
+
+class EmbeddingLookupConfiguration:
+    """Feature, mesh and sharding configuration for lookups."""
+
+    mesh: jax.sharding.Mesh
+    feature_specs: embedding.Nested[embedding_spec.FeatureSpec]
+    table_sharding_strategy: str
+    num_sc_per_device: int
+    samples_partition: jax.sharding.PartitionSpec
+    table_partition: jax.sharding.PartitionSpec
+    table_layout: JaxLayout
+
+    def __init__(
+        self,
+        feature_specs: embedding.Nested[embedding_spec.FeatureSpec],
+        mesh: jax.sharding.Mesh | None = None,
+        table_sharding_strategy: str = "MOD",
+        num_sc_per_device: int | None = None,
+        sharding_axis: str = "sparsecore_sharding",
+        samples_partition: jax.sharding.PartitionSpec | None = None,
+        samples_layout: JaxLayout | None = None,
+        table_partition: jax.sharding.PartitionSpec | None = None,
+        table_layout: JaxLayout | None = None,
+    ):
+        self.mesh = mesh or jax.sharding.Mesh(jax.devices(), sharding_axis)
+        self.feature_specs = feature_specs
+        self.table_sharding_strategy = table_sharding_strategy
+        self.num_sc_per_device = (
+            num_sc_per_device
+            if num_sc_per_device is not None
+            else jte_utils.num_sparsecores_per_device()
+        )
+        self.samples_partition = (
+            samples_partition
+            or jax.sharding.PartitionSpec(
+                sharding_axis  # type: ignore[no-untyped-call]
+            )
+        )
+        self.samples_layout = samples_layout or jax.sharding.NamedSharding(
+            self.mesh, self.samples_partition
+        )
+        self.table_partition = table_partition or jax.sharding.PartitionSpec(
+            sharding_axis,
+            None,  # type: ignore[no-untyped-call]
+        )
+        self.table_layout = table_layout or jax.sharding.NamedSharding(
+            self.mesh, self.table_partition
+        )
+
+
+# Embedding lookup function with custom gradient.
+@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
+def embedding_lookup(
+    config: EmbeddingLookupConfiguration,
+    lookups: Mapping[str, ShardedCooMatrix],
+    tables: Nested[jax.Array],
+    step: jax.Array | None = None,
+) -> Nested[jax.Array]:
+    """Embedding lookup function with custom gradient.
+
+    Args:
+        config: Embedding lookup configuration.
+        lookups: Embedding lookup stacked/sharded inputs.
+        tables: Embedding lookup stacked/sharded tables.
+        step: Current training step number.
+    """
+    del step  # Only used in backward pass.
+
+    # Decompose COO matrices.
+    row_pointers_raw = {}
+    embedding_ids_raw = {}
+    sample_ids_raw = {}
+    gains_raw = {}
+    for table_name, coo in lookups.items():
+        row_pointers_raw[table_name] = coo.shard_ends
+        embedding_ids_raw[table_name] = coo.col_ids
+        sample_ids_raw[table_name] = coo.row_ids
+        gains_raw[table_name] = coo.values
+
+    sparse_dense_matmul_input = embedding.SparseDenseMatmulInput(
+        row_pointers_raw,
+        embedding_ids_raw,
+        sample_ids_raw,
+        gains_raw,
+    )
+
+    pd = config.samples_partition
+    pt = config.table_partition
+    sharded_matmul = jax.jit(
+        shard_map(
+            functools.partial(
+                embedding.tpu_sparse_dense_matmul,
+                global_device_count=config.mesh.shape[pd[0]],
+                feature_specs=config.feature_specs,
+                sharding_strategy=config.table_sharding_strategy,
+            ),
+            mesh=config.mesh,
+            in_specs=(pd, pt),
+            out_specs=pd,
+            check_rep=False,
+        ),
+    )
+
+    activations: Nested[jax.Array] = sharded_matmul(
+        sparse_dense_matmul_input,
+        tables,
+    )
+
+    return activations
+
+
+def embedding_lookup_fwd(
+    config: EmbeddingLookupConfiguration,
+    lookups: Mapping[str, ShardedCooMatrix],
+    table: Nested[jax.Array],
+    step: jax.Array | None = None,
+) -> tuple[
+    Nested[jax.Array],
+    tuple[Nested[ShardedCooMatrix], Nested[jax.Array], jax.Array | None],
+]:
+    """Forward pass for embedding lookup."""
+    return embedding_lookup(config, lookups, table, step), (
+        lookups,
+        table,
+        step,
+    )
+
+
+def embedding_lookup_bwd(
+    config: EmbeddingLookupConfiguration,
+    res: tuple[
+        Mapping[str, ShardedCooMatrix],  # Lookups.
+        Mapping[str, Nested[jax.Array]],  # Tables.
+        jax.Array | None,  # Step.
+    ],
+    gradients: Nested[jax.Array],
+) -> tuple[None, Nested[jax.Array], jax.Array | None]:
+    """Backward pass for embedding lookup.
+
+    Args:
+        config: Embedding lookup configuration.
+        res: Tuple of embedding lookup (inputs, tables, step).
+        gradients: Embedding lookup gradients.
+
+    Returns:
+        A tuple of gradients (None, table_grads, step + 1).
+    """
+    lookups, tables, step = res
+
+    # Decompose COO matrices.
+    row_pointers_raw = {}
+    embedding_ids_raw = {}
+    sample_ids_raw = {}
+    gains_raw = {}
+    for table_name, coo in lookups.items():
+        row_pointers_raw[table_name] = coo.shard_ends
+        embedding_ids_raw[table_name] = coo.col_ids
+        sample_ids_raw[table_name] = coo.row_ids
+        gains_raw[table_name] = coo.values
+
+    sparse_dense_matmul_input = embedding.SparseDenseMatmulInput(
+        row_pointers_raw,
+        embedding_ids_raw,
+        sample_ids_raw,
+        gains_raw,
+    )
+
+    pt = config.table_partition
+    pd = config.samples_partition
+    # Replicate step count.
+    preplicate = jax.sharding.PartitionSpec()  # type: ignore[no-untyped-call]
+
+    def grad_func(
+        gradients: Nested[jax.Array],
+        sparse_input: embedding.SparseDenseMatmulInput,
+        tables: Mapping[str, embedding.EmbeddingVariables],
+        step: jax.Array | None,
+    ) -> Mapping[str, embedding.EmbeddingVariables]:
+        output: Mapping[str, embedding.EmbeddingVariables] = (
+            embedding.tpu_sparse_dense_matmul_grad(
+                gradients,
+                sparse_input,
+                tables,
+                feature_specs=config.feature_specs,
+                sharding_strategy=config.table_sharding_strategy,
+                step=step,
+            )
+        )
+        return output
+
+    # activation_layout = jax.sharding.NamedSharding(config.mesh, pd)
+    # step_layout = jax.sharding.NamedSharding(config.mesh, preplicate)
+    sharded_matmul_grad = jax.jit(
+        shard_map(
+            grad_func,
+            mesh=config.mesh,
+            in_specs=(pd, pd, pt, preplicate),
+            out_specs=pt,
+            check_rep=False,
+        ),
+        # in_shardings=(
+        #     activation_layout,
+        #     config.samples_layout,
+        #     config.table_layout,
+        #     step_layout,
+        # ),
+        # out_shardings=(config.table_layout),
+    )
+
+    table_grads = sharded_matmul_grad(
+        gradients,
+        sparse_dense_matmul_input,
+        tables,
+        step,
+    )
+
+    # tpu_sparse_dense_matmul_grad returns a general Mapping (usually a dict).
+    # It may not be the same type as the embedding table (e.g. FrozenDict).
+    # Here we use flatten / unflatten to ensure the types are the same.
+    table_grads = jax.tree.unflatten(
+        jax.tree.structure(tables), jax.tree.leaves(table_grads)
+    )
+
+    return (
+        None,
+        table_grads,
+        step + 1 if step is not None else None,  # Incremented step count.
+    )
+
+
+embedding_lookup.defvjp(embedding_lookup_fwd, embedding_lookup_bwd)
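
The new embedding_lookup.py follows the standard jax.custom_vjp pattern: the primal function is decorated with functools.partial(jax.custom_vjp, nondiff_argnums=(0,)) so the configuration argument is treated as non-differentiable, the forward rule returns the activations together with residuals, the backward rule turns incoming activation gradients into table gradients (and an incremented step), and defvjp registers the pair. The sketch below is a minimal, hypothetical illustration of that wiring using made-up names (LookupConfig, toy_lookup); it is not part of keras-rs, it only shows how the fwd rule saves residuals and how the bwd rule returns one cotangent per differentiable argument.

import functools
from typing import NamedTuple

import jax
import jax.numpy as jnp


class LookupConfig(NamedTuple):
    # Hypothetical stand-in for the non-differentiable config argument:
    # static, hashable lookup ids.
    ids: tuple[int, ...]


@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
def toy_lookup(config, table, gains):
    # Forward: gather rows and weight them by per-sample gains.
    ids = jnp.array(config.ids)
    return table[ids] * gains[:, None]


def toy_lookup_fwd(config, table, gains):
    # Return the primal output plus the residuals the backward pass needs.
    return toy_lookup(config, table, gains), (table, gains)


def toy_lookup_bwd(config, res, grad_out):
    table, gains = res
    ids = jnp.array(config.ids)
    # One cotangent per differentiable argument, in order: (table, gains).
    table_grad = jnp.zeros_like(table).at[ids].add(grad_out * gains[:, None])
    gains_grad = (grad_out * table[ids]).sum(axis=-1)
    return table_grad, gains_grad


toy_lookup.defvjp(toy_lookup_fwd, toy_lookup_bwd)

# Gradients flow through the custom backward rule, not through autodiff
# of the forward body.
table = jnp.ones((4, 3))
gains = jnp.array([1.0, 0.5])
config = LookupConfig(ids=(0, 2))
loss = lambda t, g: toy_lookup(config, t, g).sum()
print(jax.grad(loss, argnums=(0, 1))(table, gains))

In the actual layer the backward rule does not build dense gradients like this toy version; it delegates to embedding.tpu_sparse_dense_matmul_grad under a shard_map so that the table gradients stay sharded the same way as the tables themselves.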