keras-rs-nightly 0.0.1.dev2025021903__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (56)
  1. keras_rs/__init__.py +9 -28
  2. keras_rs/layers/__init__.py +37 -0
  3. keras_rs/losses/__init__.py +19 -0
  4. keras_rs/metrics/__init__.py +16 -0
  5. keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
  6. keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
  7. keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
  8. keras_rs/src/layers/embedding/embed_reduce.py +309 -0
  9. keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  10. keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
  11. keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
  12. keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
  13. keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
  14. keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
  15. keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  16. keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
  17. keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
  18. keras_rs/src/layers/feature_interaction/__init__.py +0 -0
  19. keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
  20. keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
  21. keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
  22. keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
  23. keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
  24. keras_rs/src/layers/retrieval/retrieval.py +127 -0
  25. keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
  26. keras_rs/src/losses/__init__.py +0 -0
  27. keras_rs/src/losses/list_mle_loss.py +212 -0
  28. keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
  29. keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
  30. keras_rs/src/losses/pairwise_loss.py +165 -0
  31. keras_rs/src/losses/pairwise_loss_utils.py +39 -0
  32. keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
  33. keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
  34. keras_rs/src/metrics/__init__.py +0 -0
  35. keras_rs/src/metrics/dcg.py +161 -0
  36. keras_rs/src/metrics/mean_average_precision.py +130 -0
  37. keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
  38. keras_rs/src/metrics/ndcg.py +197 -0
  39. keras_rs/src/metrics/precision_at_k.py +117 -0
  40. keras_rs/src/metrics/ranking_metric.py +260 -0
  41. keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
  42. keras_rs/src/metrics/recall_at_k.py +108 -0
  43. keras_rs/src/metrics/utils.py +70 -0
  44. keras_rs/src/types.py +43 -14
  45. keras_rs/src/utils/doc_string_utils.py +53 -0
  46. keras_rs/src/utils/keras_utils.py +52 -3
  47. keras_rs/src/utils/tpu_test_utils.py +120 -0
  48. keras_rs/src/version.py +1 -1
  49. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
  50. keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
  51. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
  52. keras_rs/api/__init__.py +0 -9
  53. keras_rs/api/layers/__init__.py +0 -11
  54. keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
  55. /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
  56. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,276 @@
+ """Defines a differentiable `embedding_lookup` function.
+
+ Implementation details for use in JAX models.
+ """
+
+ import functools
+ from typing import Any, Mapping, TypeAlias
+
+ import jax
+ import numpy as np
+ from jax.experimental import layout as jax_layout
+ from jax_tpu_embedding.sparsecore.lib.nn import embedding
+ from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
+ from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
+
+ from keras_rs.src.layers.embedding.jax import embedding_utils
+ from keras_rs.src.types import Nested
+
+ if jax.__version_info__ >= (0, 8, 0):
+     from jax import shard_map
+ else:
+     from jax.experimental.shard_map import shard_map as exp_shard_map
+
+     def shard_map(  # type: ignore[misc]
+         f: Any = None,
+         /,
+         *,
+         out_specs: Any,
+         in_specs: Any,
+         mesh: Any = None,
+         check_vma: bool = True,
+     ) -> Any:
+         return exp_shard_map(
+             f,
+             mesh=mesh,
+             in_specs=in_specs,
+             out_specs=out_specs,
+             check_rep=check_vma,
+         )  # type: ignore[no-untyped-call]
+
+
+ ShardedCooMatrix = embedding_utils.ShardedCooMatrix
+ ArrayLike: TypeAlias = jax.Array | np.ndarray[Any, Any]
+ JaxLayout: TypeAlias = jax.sharding.NamedSharding | jax_layout.Format
+
+
+ class EmbeddingLookupConfiguration:
+     """Feature, mesh and sharding configuration for lookups."""
+
+     mesh: jax.sharding.Mesh
+     feature_specs: embedding.Nested[embedding_spec.FeatureSpec]
+     table_sharding_strategy: str
+     num_sc_per_device: int
+     samples_partition: jax.sharding.PartitionSpec
+     table_partition: jax.sharding.PartitionSpec
+     table_layout: JaxLayout
+
+     def __init__(
+         self,
+         feature_specs: embedding.Nested[embedding_spec.FeatureSpec],
+         mesh: jax.sharding.Mesh | None = None,
+         table_sharding_strategy: str = "MOD",
+         num_sc_per_device: int | None = None,
+         sharding_axis: str = "sparsecore_sharding",
+         samples_partition: jax.sharding.PartitionSpec | None = None,
+         samples_layout: JaxLayout | None = None,
+         table_partition: jax.sharding.PartitionSpec | None = None,
+         table_layout: JaxLayout | None = None,
+     ):
+         self.mesh = mesh or jax.sharding.Mesh(jax.devices(), sharding_axis)
+         self.feature_specs = feature_specs
+         self.table_sharding_strategy = table_sharding_strategy
+         self.num_sc_per_device = (
+             num_sc_per_device
+             if num_sc_per_device is not None
+             else jte_utils.num_sparsecores_per_device()
+         )
+         self.samples_partition = (
+             samples_partition
+             or jax.sharding.PartitionSpec(
+                 sharding_axis  # type: ignore[no-untyped-call]
+             )
+         )
+         self.samples_layout = samples_layout or jax.sharding.NamedSharding(
+             self.mesh, self.samples_partition
+         )
+         self.table_partition = table_partition or jax.sharding.PartitionSpec(
+             sharding_axis,
+             None,  # type: ignore[no-untyped-call]
+         )
+         self.table_layout = table_layout or jax.sharding.NamedSharding(
+             self.mesh, self.table_partition
+         )
+
+
+ # Embedding lookup function with custom gradient.
+ @functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
+ def embedding_lookup(
+     config: EmbeddingLookupConfiguration,
+     lookups: Mapping[str, ShardedCooMatrix],
+     tables: Nested[jax.Array],
+     step: jax.Array | None = None,
+ ) -> Nested[jax.Array]:
+     """Embedding lookup function with custom gradient.
+
+     Args:
+         config: Embedding lookup configuration.
+         lookups: Embedding lookup stacked/sharded inputs.
+         tables: Embedding lookup stacked/sharded tables.
+         step: Current training step number.
+     """
+     del step  # Only used in backward pass.
+
+     # Decompose COO matrices.
+     row_pointers_raw = {}
+     embedding_ids_raw = {}
+     sample_ids_raw = {}
+     gains_raw = {}
+     for table_name, coo in lookups.items():
+         row_pointers_raw[table_name] = coo.shard_ends
+         embedding_ids_raw[table_name] = coo.col_ids
+         sample_ids_raw[table_name] = coo.row_ids
+         gains_raw[table_name] = coo.values
+
+     sparse_dense_matmul_input = embedding.SparseDenseMatmulInput(
+         row_pointers_raw,
+         embedding_ids_raw,
+         sample_ids_raw,
+         gains_raw,
+     )
+
+     pd = config.samples_partition
+     pt = config.table_partition
+     sharded_matmul = jax.jit(
+         shard_map(
+             functools.partial(
+                 embedding.tpu_sparse_dense_matmul,
+                 global_device_count=config.mesh.shape[pd[0]],
+                 feature_specs=config.feature_specs,
+                 sharding_strategy=config.table_sharding_strategy,
+             ),
+             mesh=config.mesh,
+             in_specs=(pd, pt),
+             out_specs=pd,
+             check_vma=False,
+         ),
+     )
+
+     activations: Nested[jax.Array] = sharded_matmul(
+         sparse_dense_matmul_input,
+         tables,
+     )
+
+     return activations
+
+
+ def embedding_lookup_fwd(
+     config: EmbeddingLookupConfiguration,
+     lookups: Mapping[str, ShardedCooMatrix],
+     table: Nested[jax.Array],
+     step: jax.Array | None = None,
+ ) -> tuple[
+     Nested[jax.Array],
+     tuple[Nested[ShardedCooMatrix], Nested[jax.Array], jax.Array | None],
+ ]:
+     """Forward pass for embedding lookup."""
+     return embedding_lookup(config, lookups, table, step), (
+         lookups,
+         table,
+         step,
+     )
+
+
+ def embedding_lookup_bwd(
+     config: EmbeddingLookupConfiguration,
+     res: tuple[
+         Mapping[str, ShardedCooMatrix],  # Lookups.
+         Mapping[str, Nested[jax.Array]],  # Tables.
+         jax.Array | None,  # Step.
+     ],
+     gradients: Nested[jax.Array],
+ ) -> tuple[None, Nested[jax.Array], jax.Array | None]:
+     """Backward pass for embedding lookup.
+
+     Args:
+         config: Embedding lookup configuration.
+         res: Tuple of embedding lookup (inputs, tables, step).
+         gradients: Embedding lookup gradients.
+
+     Returns:
+         A tuple of gradients (None, table_grads, step + 1).
+     """
+     lookups, tables, step = res
+
+     # Decompose COO matrices.
+     row_pointers_raw = {}
+     embedding_ids_raw = {}
+     sample_ids_raw = {}
+     gains_raw = {}
+     for table_name, coo in lookups.items():
+         row_pointers_raw[table_name] = coo.shard_ends
+         embedding_ids_raw[table_name] = coo.col_ids
+         sample_ids_raw[table_name] = coo.row_ids
+         gains_raw[table_name] = coo.values
+
+     sparse_dense_matmul_input = embedding.SparseDenseMatmulInput(
+         row_pointers_raw,
+         embedding_ids_raw,
+         sample_ids_raw,
+         gains_raw,
+     )
+
+     pt = config.table_partition
+     pd = config.samples_partition
+     # Replicate step count.
+     preplicate = jax.sharding.PartitionSpec()  # type: ignore[no-untyped-call]
+
+     def grad_func(
+         gradients: Nested[jax.Array],
+         sparse_input: embedding.SparseDenseMatmulInput,
+         tables: Mapping[str, embedding.EmbeddingVariables],
+         step: jax.Array | None,
+     ) -> Mapping[str, embedding.EmbeddingVariables]:
+         output: Mapping[str, embedding.EmbeddingVariables] = (
+             embedding.tpu_sparse_dense_matmul_grad(
+                 gradients,
+                 sparse_input,
+                 tables,
+                 feature_specs=config.feature_specs,
+                 sharding_strategy=config.table_sharding_strategy,
+                 step=step,
+             )
+         )
+         return output
+
+     # activation_layout = jax.sharding.NamedSharding(config.mesh, pd)
+     # step_layout = jax.sharding.NamedSharding(config.mesh, preplicate)
+     sharded_matmul_grad = jax.jit(
+         shard_map(
+             grad_func,
+             mesh=config.mesh,
+             in_specs=(pd, pd, pt, preplicate),
+             out_specs=pt,
+             check_vma=False,
+         ),
+         # in_shardings=(
+         #     activation_layout,
+         #     config.samples_layout,
+         #     config.table_layout,
+         #     step_layout,
+         # ),
+         # out_shardings=(config.table_layout),
+     )
+
+     table_grads = sharded_matmul_grad(
+         gradients,
+         sparse_dense_matmul_input,
+         tables,
+         step,
+     )
+
+     # tpu_sparse_dense_matmul_grad returns a general Mapping (usually a dict).
+     # It may not be the same type as the embedding table (e.g. FrozenDict).
+     # Here we use flatten / unflatten to ensure the types are the same.
+     table_grads = jax.tree.unflatten(
+         jax.tree.structure(tables), jax.tree.leaves(table_grads)
+     )
+
+     return (
+         None,
+         table_grads,
+         step + 1 if step is not None else None,  # Incremented step count.
+     )
+
+
+ embedding_lookup.defvjp(embedding_lookup_fwd, embedding_lookup_bwd)
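
The module above registers `embedding_lookup_fwd` and `embedding_lookup_bwd` on `embedding_lookup` via `defvjp`, so gradients with respect to the embedding tables flow through `embedding.tpu_sparse_dense_matmul_grad`. As rough orientation only, the sketch below shows one way this could be exercised; it is not part of the diff, and it assumes TPU SparseCore hardware, a `feature_specs` pytree built with `jax_tpu_embedding`, and `lookups`/`tables` already prepared as in the `embedding_utils` module shown in the next hunk.

import jax
import jax.numpy as jnp

from keras_rs.src.layers.embedding.jax import embedding_lookup as el


def table_gradients(config, lookups, tables):
    """Differentiate a toy scalar loss w.r.t. the embedding tables (sketch)."""

    def loss_fn(tables_):
        # Forward pass: sharded tpu_sparse_dense_matmul under the hood.
        activations = el.embedding_lookup(config, lookups, tables_)
        # Reduce all per-feature activations to a scalar for value_and_grad.
        return sum(jnp.sum(a) for a in jax.tree.leaves(activations))

    # The backward pass runs embedding.tpu_sparse_dense_matmul_grad via the
    # custom VJP registered by `defvjp` above.
    return jax.value_and_grad(loss_fn)(tables)


# Hypothetical usage (requires real FeatureSpecs, preprocessed inputs and
# sharded tables):
# config = el.EmbeddingLookupConfiguration(feature_specs)
# loss, table_grads = table_gradients(config, lookups, tables)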
@@ -0,0 +1,217 @@
+ """Utility functions for manipulating JAX embedding tables and inputs."""
+
+ import collections
+ from typing import Any, Mapping, NamedTuple, Sequence, TypeAlias, TypeVar
+
+ import jax
+ import numpy as np
+ from jax_tpu_embedding.sparsecore.lib.nn import embedding
+ from jax_tpu_embedding.sparsecore.lib.nn import table_stacking
+ from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import FeatureSpec
+
+ from keras_rs.src.types import Nested
+
+ T = TypeVar("T")
+
+ # Any to support tf.Ragged without needing an explicit TF dependency.
+ ArrayLike: TypeAlias = jax.Array | np.ndarray | Any  # type: ignore
+ Shape: TypeAlias = tuple[int, ...]
+
+
+ class FeatureSamples(NamedTuple):
+     tokens: ArrayLike
+     weights: ArrayLike | None
+
+
+ class ShardedCooMatrix(NamedTuple):
+     shard_starts: ArrayLike
+     shard_ends: ArrayLike
+     col_ids: ArrayLike
+     row_ids: ArrayLike
+     values: ArrayLike
+
+
+ def convert_to_numpy(
+     ragged_or_dense: np.ndarray[Any, Any] | Sequence[Sequence[Any]] | Any,
+     dtype: Any,
+ ) -> np.ndarray[Any, Any]:
+     """Converts a ragged or dense list of inputs to a ragged/dense numpy array.
+
+     The output is adjusted to be 2D.
+
+     Args:
+         ragged_or_dense: Input that is either already a numpy array, or nested
+             sequence.
+         dtype: Numpy dtype of output array.
+
+     Returns:
+         Corresponding numpy array.
+     """
+     if hasattr(ragged_or_dense, "numpy"):
+         # Support tf.RaggedTensor and other TF input dtypes.
+         if callable(getattr(ragged_or_dense, "numpy")):
+             ragged_or_dense = ragged_or_dense.numpy()
+
+     if isinstance(ragged_or_dense, jax.Array):
+         ragged_or_dense = np.asarray(ragged_or_dense)
+
+     if isinstance(ragged_or_dense, np.ndarray):
+         # Convert 1D to 2D.
+         if ragged_or_dense.dtype != np.ndarray and ragged_or_dense.ndim == 1:
+             return ragged_or_dense.reshape(-1, 1).astype(dtype)
+
+         # If dense, return converted dense type.
+         if ragged_or_dense.dtype != np.ndarray:
+             return ragged_or_dense.astype(dtype)
+
+         # Ragged numpy array.
+         return ragged_or_dense
+
+     # Handle 1D sequence input.
+     if not isinstance(ragged_or_dense[0], collections.abc.Sequence):
+         return np.asarray(ragged_or_dense, dtype=dtype).reshape(-1, 1)
+
+     # Assemble elements into an nd-array.
+     counts = [len(vals) for vals in ragged_or_dense]
+     if all([count == counts[0] for count in counts]):
+         # Dense input.
+         return np.asarray(ragged_or_dense, dtype=dtype)
+     else:
+         # Ragged input, convert to ragged numpy arrays.
+         return np.array(
+             [np.array(row, dtype=dtype) for row in ragged_or_dense],
+             dtype=np.ndarray,
+         )
+
+
+ def create_feature_samples(
+     feature_structure: Nested[T],
+     feature_ids: Nested[ArrayLike | Sequence[int] | Sequence[Sequence[int]]],
+     feature_weights: None
+     | (Nested[ArrayLike | Sequence[float] | Sequence[Sequence[float]]]),
+ ) -> Nested[FeatureSamples]:
+     """Constructs a collection of sample tuples from provided IDs and weights.
+
+     Args:
+         feature_structure: The nested structure of the inputs (typically
+             `FeatureSpec`s).
+         feature_ids: The feature IDs to use for the samples.
+         feature_weights: The feature weights to use for the samples. Defaults
+             to ones if not provided.
+
+     Returns:
+         A nested collection of `FeatureSamples` corresponding to the input IDs
+         and weights, for use in embedding lookups.
+     """
+     # Create numpy arrays from inputs.
+     feature_ids = jax.tree.map(
+         lambda _, ids: convert_to_numpy(ids, np.int32),
+         feature_structure,
+         feature_ids,
+     )
+
+     if feature_weights is None:
+         return jax.tree.map(  # type: ignore[no-any-return]
+             lambda _, ids: FeatureSamples(ids, None),
+             feature_structure,
+             feature_ids,
+         )
+
+     feature_weights = jax.tree.map(
+         lambda _, wgts: convert_to_numpy(wgts, np.float32),
+         feature_structure,
+         feature_weights,
+     )
+
+     # Assemble.
+     def _create_feature_samples(
+         sample_ids: np.ndarray[Any, Any],
+         sample_weights: np.ndarray[Any, Any],
+     ) -> FeatureSamples:
+         return FeatureSamples(sample_ids, sample_weights)
+
+     output: Nested[FeatureSamples] = jax.tree.map(
+         lambda _, sample_ids, sample_weights: _create_feature_samples(
+             sample_ids, sample_weights
+         ),
+         feature_structure,
+         feature_ids,
+         feature_weights,
+     )
+     return output
+
+
+ def stack_and_shard_samples(
+     feature_specs: Nested[FeatureSpec],
+     feature_samples: Nested[FeatureSamples],
+     local_device_count: int,
+     global_device_count: int,
+     num_sc_per_device: int,
+     static_buffer_size: int | Mapping[str, int] | None = None,
+ ) -> tuple[dict[str, ShardedCooMatrix], embedding.SparseDenseMatmulInputStats]:
+     """Prepares input samples for use in embedding lookups.
+
+     Args:
+         feature_specs: Nested collection of feature specifications.
+         feature_samples: Nested collection of feature samples.
+         local_device_count: Number of local JAX devices.
+         global_device_count: Number of global JAX devices.
+         num_sc_per_device: Number of sparsecores per device.
+         static_buffer_size: The static buffer size to use for the samples.
+             Defaults to None, in which case an upper-bound for the buffer size
+             will be automatically determined.
+
+     Returns:
+         The preprocessed inputs, and statistics useful for updating FeatureSpecs
+         based on the provided input data.
+     """
+     del static_buffer_size  # Currently ignored.
+     flat_feature_specs, _ = jax.tree.flatten(feature_specs)
+
+     feature_tokens = []
+     collected_weights = []
+
+     def collect_tokens_and_weights(
+         feature_spec: FeatureSpec, samples: FeatureSamples
+     ) -> None:
+         del feature_spec
+         feature_tokens.append(samples.tokens)
+         collected_weights.append(samples.weights)
+
+     jax.tree.map(collect_tokens_and_weights, feature_specs, feature_samples)
+
+     feature_weights = (
+         None if all(w is None for w in collected_weights) else collected_weights
+     )
+
+     preprocessed_inputs, stats = embedding.preprocess_sparse_dense_matmul_input(
+         feature_tokens,
+         feature_weights,
+         flat_feature_specs,
+         local_device_count=local_device_count,
+         global_device_count=global_device_count,
+         num_sc_per_device=num_sc_per_device,
+         sharding_strategy="MOD",
+         has_leading_dimension=False,
+         allow_id_dropping=True,
+     )
+
+     out: dict[str, ShardedCooMatrix] = {}
+     tables_names = preprocessed_inputs.lhs_row_pointers.keys()
+     for table_name in tables_names:
+         shard_ends = preprocessed_inputs.lhs_row_pointers[table_name]
+         shard_starts = np.concatenate(
+             [
+                 np.asarray([0]),
+                 table_stacking._next_largest_multiple(shard_ends[:-1], 8),
+             ]
+         )
+         out[table_name] = ShardedCooMatrix(
+             shard_starts=shard_starts,
+             shard_ends=shard_ends,
+             col_ids=preprocessed_inputs.lhs_embedding_ids[table_name],
+             row_ids=preprocessed_inputs.lhs_sample_ids[table_name],
+             values=preprocessed_inputs.lhs_gains[table_name],
+         )
+
+     return out, stats
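
In this module, `create_feature_samples` and `stack_and_shard_samples` turn raw (possibly ragged) ID lists into the per-table `ShardedCooMatrix` inputs consumed by `embedding_lookup` in the previous hunk. A minimal sketch of that path follows; it is not part of the diff and assumes a `feature_specs` pytree from `jax_tpu_embedding` with a single hypothetical "movie_id" feature, made-up IDs, and an assumed SparseCore count.

import jax

from keras_rs.src.layers.embedding.jax import embedding_utils


def build_lookups(feature_specs):
    """Turn ragged IDs into ShardedCooMatrix lookup inputs (sketch)."""
    # Hypothetical ragged per-sample ID lists for a single "movie_id" feature;
    # `feature_specs` must be a matching pytree of jax_tpu_embedding
    # FeatureSpecs keyed the same way.
    feature_ids = {"movie_id": [[3, 7], [1], [4, 4, 9]]}

    # Weights default to ones when not provided.
    samples = embedding_utils.create_feature_samples(
        feature_specs, feature_ids, feature_weights=None
    )

    # Stack and shard into per-table ShardedCooMatrix inputs plus input stats.
    return embedding_utils.stack_and_shard_samples(
        feature_specs,
        samples,
        local_device_count=jax.local_device_count(),
        global_device_count=jax.device_count(),
        num_sc_per_device=4,  # Assumed value; query it from the TPU in practice.
    )


# lookups, stats = build_lookups(feature_specs)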