keras-rs-nightly 0.0.1.dev2025021903__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl

This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (56)
  1. keras_rs/__init__.py +9 -28
  2. keras_rs/layers/__init__.py +37 -0
  3. keras_rs/losses/__init__.py +19 -0
  4. keras_rs/metrics/__init__.py +16 -0
  5. keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
  6. keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
  7. keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
  8. keras_rs/src/layers/embedding/embed_reduce.py +309 -0
  9. keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  10. keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
  11. keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
  12. keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
  13. keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
  14. keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
  15. keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  16. keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
  17. keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
  18. keras_rs/src/layers/feature_interaction/__init__.py +0 -0
  19. keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
  20. keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
  21. keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
  22. keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
  23. keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
  24. keras_rs/src/layers/retrieval/retrieval.py +127 -0
  25. keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
  26. keras_rs/src/losses/__init__.py +0 -0
  27. keras_rs/src/losses/list_mle_loss.py +212 -0
  28. keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
  29. keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
  30. keras_rs/src/losses/pairwise_loss.py +165 -0
  31. keras_rs/src/losses/pairwise_loss_utils.py +39 -0
  32. keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
  33. keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
  34. keras_rs/src/metrics/__init__.py +0 -0
  35. keras_rs/src/metrics/dcg.py +161 -0
  36. keras_rs/src/metrics/mean_average_precision.py +130 -0
  37. keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
  38. keras_rs/src/metrics/ndcg.py +197 -0
  39. keras_rs/src/metrics/precision_at_k.py +117 -0
  40. keras_rs/src/metrics/ranking_metric.py +260 -0
  41. keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
  42. keras_rs/src/metrics/recall_at_k.py +108 -0
  43. keras_rs/src/metrics/utils.py +70 -0
  44. keras_rs/src/types.py +43 -14
  45. keras_rs/src/utils/doc_string_utils.py +53 -0
  46. keras_rs/src/utils/keras_utils.py +52 -3
  47. keras_rs/src/utils/tpu_test_utils.py +120 -0
  48. keras_rs/src/version.py +1 -1
  49. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
  50. keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
  51. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
  52. keras_rs/api/__init__.py +0 -9
  53. keras_rs/api/layers/__init__.py +0 -11
  54. keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
  55. /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
  56. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
keras_rs/src/layers/embedding/jax/config_conversion.py
@@ -0,0 +1,468 @@
+ """Conversion utilities for Keras DistributedEmbeddingConfig to JAX."""
+
+ import inspect
+ import math
+ import random as python_random
+ from typing import Any, Callable
+
+ import jax
+ import jax.numpy as jnp
+ import keras
+ import numpy as np
+ from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
+
+ from keras_rs.src import types
+ from keras_rs.src.layers.embedding import distributed_embedding_config as config
+ from keras_rs.src.types import Nested
+
+
+ class WrappedKerasInitializer(jax.nn.initializers.Initializer):
+     """Wraps a Keras initializer for use in JAX."""
+
+     def __init__(self, initializer: keras.initializers.Initializer):
+         if isinstance(initializer, str):
+             initializer = keras.initializers.get(initializer)
+         self.initializer = initializer
+
+     def key(self) -> jax.Array | None:
+         """Extract a key from the underlying keras initializer."""
+         # All built-in keras initializers have a `seed` attribute.
+         # Extract this and turn it into a key for use with JAX.
+         if hasattr(self.initializer, "seed"):
+             output: jax.Array = keras.src.backend.jax.random.jax_draw_seed(
+                 self.initializer.seed
+             )
+             return output
+         return None
+
+     def __call__(
+         self,
+         key: Any,
+         shape: Any,
+         dtype: Any = jnp.float_,
+         out_sharding: Any = None,
+     ) -> jax.Array:
+         del out_sharding
+         # Force use of provided key. The JAX backend for random initializers
+         # forwards the `seed` attribute to the underlying JAX random functions.
+         if key is not None and hasattr(self.initializer, "seed"):
+             old_seed = self.initializer.seed
+             self.initializer.seed = key
+             out: jax.Array = self.initializer(shape, dtype)
+             self.initializer.seed = old_seed
+             return out
+
+         output: jax.Array = self.initializer(shape, dtype)
+         return output
+
+
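
A minimal usage sketch for the wrapper above (not part of the packaged file; names are illustrative, and it assumes Keras is running on its JAX backend so a JAX key can stand in for the `seed` attribute, as the comment in `__call__` describes):

import jax
import jax.numpy as jnp
import keras

# Wrap a built-in Keras initializer so it satisfies the JAX
# (key, shape, dtype) initializer protocol.
init = WrappedKerasInitializer(keras.initializers.TruncatedNormal(seed=42))

# The provided key temporarily overrides the initializer's own seed.
table = init(jax.random.key(0), (1024, 16), jnp.float32)
assert table.shape == (1024, 16)
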
+ # pylint: disable-next=g-classes-have-attributes
+ class WrappedJaxInitializer(keras.initializers.Initializer):
+     """Wraps a JAX initializer for use in Keras.
+
+     Attributes:
+         initializer: The wrapped JAX initializer.
+
+     Args:
+         initializer: The JAX initializer to wrap.
+         seed: Optional Keras seed for use with random JAX initializers.
+     """
+
+     def __init__(
+         self,
+         initializer: jax.nn.initializers.Initializer,
+         seed: int | keras.random.SeedGenerator | None = None,
+     ):
+         self.initializer = initializer
+         if seed is None:
+             # Consistency with keras.random.make_default_seed().
+             seed = python_random.randint(1, int(1e9))
+         self.seed = seed
+
+     def key(self) -> jax.Array:
+         """Converts the internal seed to a JAX random key."""
+         seed = self.seed
+         if isinstance(seed, int):
+             return jax.random.key(self.seed)
+         elif isinstance(seed, keras.random.SeedGenerator):
+             return jax.random.key(seed.next())
+         elif isinstance(seed, jax.Array):
+             return seed
+         else:
+             raise ValueError(f"Unknown seed {seed} of type {type(seed)}.")
+
+     def __call__(
+         self,
+         shape: types.Shape,
+         dtype: types.DType | None = None,
+         **kwargs: Any,
+     ) -> jax.Array:
+         del kwargs  # Unused.
+         return self.initializer(self.key(), shape, dtype)
+
+
+ def keras_to_jax_initializer(
+     initializer: str | keras.initializers.Initializer,
+ ) -> jax.nn.initializers.Initializer:
+     """Converts a Keras initializer to a JAX initializer.
+
+     Args:
+         initializer: Keras initializer to convert.
+
+     Returns:
+         A JAX-compatible equivalent initializer.
+     """
+     if isinstance(initializer, WrappedJaxInitializer):
+         return initializer.initializer
+     return WrappedKerasInitializer(initializer)
+
+
+ def jax_to_keras_initializer(
+     initializer: jax.nn.initializers.Initializer,
+ ) -> keras.initializers.Initializer:
+     """Converts a JAX initializer to a Keras initializer.
+
+     Args:
+         initializer: JAX initializer to convert.
+
+     Returns:
+         An equivalent Keras initializer.
+     """
+     if isinstance(initializer, WrappedKerasInitializer):
+         return initializer.initializer
+     return WrappedJaxInitializer(initializer)
+
+
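
A quick round trip through the two converters above; each one unwraps its own wrapper type rather than double-wrapping:

jax_init = keras_to_jax_initializer("glorot_uniform")
# `jax_init` is a WrappedKerasInitializer holding a GlorotUniform instance.
keras_init = jax_to_keras_initializer(jax_init)
# Converting back unwraps to the original Keras initializer.
assert isinstance(keras_init, keras.initializers.GlorotUniform)
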
+ def keras_to_jte_learning_rate(
+     learning_rate: keras.Variable | float | Callable[..., float],
+ ) -> float | Callable[..., float]:
+     """Converts a Keras learning rate to a JAX TPU Embedding learning rate.
+
+     Args:
+         learning_rate: Any Keras-compatible learning-rate type. If a Callable,
+             must either take no parameters, or the step size as a single
+             argument.
+
+     Returns:
+         A JAX TPU Embedding learning rate.
+
+     Raises:
+         ValueError if the learning rate is not supported.
+     """
+
+     # Supported keras optimizer general options.
+     if isinstance(learning_rate, keras.Variable):
+         # Extract the first (and only) element of the variable.
+         learning_rate = np.array(learning_rate.value, dtype=float)
+         assert learning_rate.size == 1
+         lr_float: float = learning_rate.item(0)
+         return lr_float
+     elif callable(learning_rate):
+         # Callable learning rate functions are expected to take a singular step
+         # count argument, or no arguments.
+         args = inspect.getfullargspec(learning_rate).args
+         # If not a function, then it's an object instance with `self` as the
+         # first argument.
+         num_args = (
+             len(args) if inspect.isfunction(learning_rate) else len(args) - 1
+         )
+         if num_args <= 1:
+             return learning_rate
+     elif isinstance(learning_rate, float):
+         return learning_rate
+
+     raise ValueError(
+         f"Unsupported learning rate: {learning_rate} of type"
+         f" {type(learning_rate)}."
+     )
+
+
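
The two pass-through forms, sketched for illustration (`linear_decay` is a made-up schedule, not part of the package):

# A plain float passes through unchanged.
assert keras_to_jte_learning_rate(0.01) == 0.01

# A schedule taking the step count as its single argument is kept as-is.
def linear_decay(step):
    return 0.1 / (1.0 + step)

assert keras_to_jte_learning_rate(linear_decay) is linear_decay
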
+ def jte_to_keras_learning_rate(
+     learning_rate: float | Callable[..., float],
+ ) -> float | Callable[..., float]:
+     """Converts a JAX TPU Embedding learning rate to a Keras learning rate.
+
+     Args:
+         learning_rate: The learning rate value or function. If a Callable, must
+             either take no parameters, or the step size as a single argument.
+
+     Returns:
+         A Keras-compatible learning rate.
+
+     Raises:
+         ValueError if the learning rate is not supported.
+     """
+     if callable(learning_rate):
+         # Callable learning rate functions are expected to take a singular step
+         # count argument, or no arguments.
+         args = inspect.getfullargspec(learning_rate).args
+         # If not a function, then it's an object instance with `self` as the
+         # first argument.
+         num_args = (
+             len(args) if inspect.isfunction(learning_rate) else len(args) - 1
+         )
+         if num_args <= 1:
+             return learning_rate
+     elif isinstance(learning_rate, float):
+         return learning_rate
+
+     raise ValueError(f"Unknown learning rate {learning_rate}")
+
+
+ def keras_to_jte_optimizer(
+     optimizer: keras.optimizers.Optimizer | str,
+ ) -> embedding_spec.OptimizerSpec:
+     """Converts a Keras optimizer to a JAX TPU Embedding optimizer.
+
+     Args:
+         optimizer: Any Keras-compatible optimizer.
+
+     Returns:
+         A JAX TPU Embedding optimizer.
+     """
+     if isinstance(optimizer, str):
+         optimizer = keras.optimizers.get(optimizer)
+
+     # We need to extract the actual internal learning_rate function.
+     # Unfortunately, the optimizer.learning_rate attribute tries to be smart,
+     # and evaluates the learning rate at the current iteration step, which is
+     # not what we want.
+     # pylint: disable-next=protected-access
+     learning_rate = keras_to_jte_learning_rate(optimizer._learning_rate)
+
+     # Unsupported keras optimizer general options.
+     if optimizer.clipnorm is not None:
+         raise ValueError("Unsupported optimizer option `clipnorm`.")
+     if optimizer.global_clipnorm is not None:
+         raise ValueError("Unsupported optimizer option `global_clipnorm`.")
+     if optimizer.use_ema:
+         raise ValueError("Unsupported optimizer option `use_ema`.")
+     if optimizer.loss_scale_factor is not None:
+         raise ValueError("Unsupported optimizer option `loss_scale_factor`.")
+
+     # Supported optimizers.
+     if isinstance(optimizer, keras.optimizers.SGD):
+         if getattr(optimizer, "nesterov", False):
+             raise ValueError("Unsupported optimizer option `nesterov`.")
+         if getattr(optimizer, "momentum", 0.0) != 0.0:
+             raise ValueError("Unsupported optimizer option `momentum`.")
+         return embedding_spec.SGDOptimizerSpec(learning_rate=learning_rate)
+     elif isinstance(optimizer, keras.optimizers.Adagrad):
+         if getattr(optimizer, "epsilon", 1e-7) != 1e-7:
+             raise ValueError("Unsupported optimizer option `epsilon`.")
+         return embedding_spec.AdagradOptimizerSpec(
+             learning_rate=learning_rate,
+             initial_accumulator_value=optimizer.initial_accumulator_value,
+         )
+     elif isinstance(optimizer, keras.optimizers.Adam):
+         if getattr(optimizer, "amsgrad", False):
+             raise ValueError("Unsupported optimizer option `amsgrad`.")
+
+         return embedding_spec.AdamOptimizerSpec(
+             learning_rate=learning_rate,
+             beta_1=optimizer.beta_1,
+             beta_2=optimizer.beta_2,
+             epsilon=optimizer.epsilon,
+         )
+     elif isinstance(optimizer, keras.optimizers.Ftrl):
+         if (
+             getattr(optimizer, "l2_shrinkage_regularization_strength", 0.0)
+             != 0.0
+         ):
+             raise ValueError(
+                 "Unsupported optimizer option "
+                 "`l2_shrinkage_regularization_strength`."
+             )
+
+         return embedding_spec.FTRLOptimizerSpec(
+             learning_rate=learning_rate,
+             learning_rate_power=optimizer.learning_rate_power,
+             l1_regularization_strength=optimizer.l1_regularization_strength,
+             l2_regularization_strength=optimizer.l2_regularization_strength,
+             beta=optimizer.beta,
+             initial_accumulator_value=optimizer.initial_accumulator_value,
+         )
+
+     raise ValueError(
+         f"Unsupported optimizer type {type(optimizer)}. Optimizer must be "
+         f"one of [Adagrad, Adam, Ftrl, SGD]."
+     )
+
+
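
A minimal conversion sketch for the function above, with illustrative hyperparameters (Adagrad's default `epsilon=1e-7` passes the compatibility checks):

opt = keras.optimizers.Adagrad(learning_rate=0.05, initial_accumulator_value=0.1)
spec = keras_to_jte_optimizer(opt)
assert isinstance(spec, embedding_spec.AdagradOptimizerSpec)

# Unsupported general options fail loudly instead of being silently dropped:
# keras_to_jte_optimizer(keras.optimizers.SGD(clipnorm=1.0))  # raises ValueError
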
+ def jte_to_keras_optimizer(
+     optimizer: embedding_spec.OptimizerSpec,
+ ) -> keras.optimizers.Optimizer:
+     """Converts a JAX TPU Embedding optimizer to a Keras optimizer.
+
+     Args:
+         optimizer: The JAX TPU Embedding optimizer.
+
+     Returns:
+         A corresponding Keras optimizer.
+     """
+     learning_rate = jte_to_keras_learning_rate(optimizer.learning_rate)
+     if isinstance(optimizer, embedding_spec.SGDOptimizerSpec):
+         return keras.optimizers.SGD(learning_rate=learning_rate)
+     elif isinstance(optimizer, embedding_spec.AdagradOptimizerSpec):
+         return keras.optimizers.Adagrad(
+             learning_rate=learning_rate,
+             initial_accumulator_value=optimizer.initial_accumulator_value,
+         )
+     elif isinstance(optimizer, embedding_spec.AdamOptimizerSpec):
+         return keras.optimizers.Adam(
+             learning_rate=learning_rate,
+             beta_1=optimizer.beta_1,
+             beta_2=optimizer.beta_2,
+             epsilon=optimizer.epsilon,
+         )
+     elif isinstance(optimizer, embedding_spec.FTRLOptimizerSpec):
+         if getattr(optimizer, "initial_linear_value", 0.0) != 0.0:
+             raise ValueError(
+                 "Unsupported optimizer option `initial_linear_value`."
+             )
+         if getattr(optimizer, "multiply_linear_by_learning_rate", False):
+             raise ValueError(
+                 "Unsupported optimizer option "
+                 "`multiply_linear_by_learning_rate`."
+             )
+         return keras.optimizers.Ftrl(
+             learning_rate=learning_rate,
+             learning_rate_power=optimizer.learning_rate_power,
+             initial_accumulator_value=optimizer.initial_accumulator_value,
+             l1_regularization_strength=optimizer.l1_regularization_strength,
+             l2_regularization_strength=optimizer.l2_regularization_strength,
+             beta=optimizer.beta,
+         )
+
+     raise ValueError(f"Unknown optimizer spec {type(optimizer)}.")
+
+
+ def _keras_to_jte_table_config(
+     table_config: config.TableConfig,
+ ) -> embedding_spec.TableSpec:
+     # Initializer could be None. Default to truncated normal.
+     initializer = table_config.initializer
+     if initializer is None:
+         initializer = keras.initializers.TruncatedNormal(
+             mean=0.0, stddev=1.0 / math.sqrt(float(table_config.embedding_dim))
+         )
+     return embedding_spec.TableSpec(
+         name=table_config.name,
+         vocabulary_size=table_config.vocabulary_size,
+         embedding_dim=table_config.embedding_dim,
+         initializer=keras_to_jax_initializer(initializer),
+         optimizer=keras_to_jte_optimizer(table_config.optimizer),
+         combiner=table_config.combiner,
+         max_ids_per_partition=table_config.max_ids_per_partition,
+         max_unique_ids_per_partition=table_config.max_unique_ids_per_partition,
+     )
+
+
+ def keras_to_jte_table_configs(
+     table_configs: Nested[config.TableConfig],
+ ) -> Nested[embedding_spec.TableSpec]:
+     """Converts Keras RS `TableConfig`s to JAX TPU Embedding `TableSpec`s."""
+     return keras.tree.map_structure(
+         _keras_to_jte_table_config,
+         table_configs,
+     )
+
+
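
A table-conversion sketch, assuming the `TableConfig` attributes read above (`name`, `vocabulary_size`, `embedding_dim`, `optimizer`, ...) are also constructor arguments; the table itself is hypothetical:

movie_table = config.TableConfig(
    name="movie_id",
    vocabulary_size=10_000,
    embedding_dim=32,
    optimizer="adam",  # resolved via keras.optimizers.get in the converter
)
# A single config is a valid (degenerate) nest for map_structure.
movie_spec = keras_to_jte_table_configs(movie_table)
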
+ def _jte_to_keras_table_config(
+     table_spec: embedding_spec.TableSpec,
+ ) -> config.TableConfig:
+     return config.TableConfig(
+         name=table_spec.name,
+         vocabulary_size=table_spec.vocabulary_size,
+         embedding_dim=table_spec.embedding_dim,
+         initializer=jax_to_keras_initializer(table_spec.initializer),
+         optimizer=jte_to_keras_optimizer(table_spec.optimizer),
+         combiner=table_spec.combiner,
+         max_ids_per_partition=table_spec.max_ids_per_partition,
+         max_unique_ids_per_partition=table_spec.max_unique_ids_per_partition,
+     )
+
+
+ def jte_to_keras_table_configs(
+     table_specs: Nested[embedding_spec.TableSpec],
+ ) -> Nested[config.TableConfig]:
+     """Converts JAX TPU Embedding `TableSpec`s to Keras RS `TableConfig`s."""
+     output: Nested[config.TableConfig] = keras.tree.map_structure(
+         _jte_to_keras_table_config,
+         table_specs,
+     )
+     return output
+
+
+ def _keras_to_jte_feature_config(
+     feature_config: config.FeatureConfig,
+     table_spec_map: dict[str, embedding_spec.TableSpec],
+ ) -> embedding_spec.FeatureSpec:
+     table_spec = table_spec_map.get(feature_config.table.name, None)
+     if table_spec is None:
+         table_spec = _keras_to_jte_table_config(feature_config.table)
+         table_spec_map[feature_config.table.name] = table_spec
+
+     return embedding_spec.FeatureSpec(
+         name=feature_config.name,
+         table_spec=table_spec,
+         input_shape=feature_config.input_shape,
+         output_shape=feature_config.output_shape,
+     )
+
+
+ def keras_to_jte_feature_configs(
+     feature_configs: Nested[config.FeatureConfig],
+ ) -> Nested[embedding_spec.FeatureSpec]:
+     """Converts Keras RS `FeatureConfig`s to JAX TPU Embedding `FeatureSpec`s.
+
+     Args:
+         feature_configs: Keras RS feature configurations.
+
+     Returns:
+         JAX TPU Embedding feature specifications.
+     """
+     table_spec_map: dict[str, embedding_spec.TableSpec] = {}
+     return keras.tree.map_structure(
+         lambda feature_config: _keras_to_jte_feature_config(
+             feature_config, table_spec_map
+         ),
+         feature_configs,
+     )
+
+
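
The `table_spec_map` bookkeeping above guarantees that features sharing one Keras table also share a single `TableSpec`. A sketch with made-up feature names, assuming `FeatureConfig` accepts the attributes read above as constructor arguments:

shared = config.TableConfig(
    name="ids", vocabulary_size=1_000, embedding_dim=16, optimizer="sgd"
)
features = {
    "query_ids": config.FeatureConfig(
        name="query_ids", table=shared,
        input_shape=(32, 1), output_shape=(32, 16),
    ),
    "candidate_ids": config.FeatureConfig(
        name="candidate_ids", table=shared,
        input_shape=(32, 1), output_shape=(32, 16),
    ),
}
specs = keras_to_jte_feature_configs(features)
# Both feature specs reference the identical TableSpec instance.
assert specs["query_ids"].table_spec is specs["candidate_ids"].table_spec
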
+ def _jte_to_keras_feature_config(
+     feature_spec: embedding_spec.FeatureSpec,
+     table_config_map: dict[str, config.TableConfig],
+ ) -> config.FeatureConfig:
+     table_config = table_config_map.get(feature_spec.table_spec.name, None)
+     if table_config is None:
+         table_config = _jte_to_keras_table_config(feature_spec.table_spec)
+         table_config_map[feature_spec.table_spec.name] = table_config
+
+     return config.FeatureConfig(
+         name=feature_spec.name,
+         table=table_config,
+         input_shape=feature_spec.input_shape,
+         output_shape=feature_spec.output_shape,
+     )
+
+
+ def jte_to_keras_feature_configs(
+     feature_specs: Nested[embedding_spec.FeatureSpec],
+ ) -> Nested[config.FeatureConfig]:
+     """Converts JAX TPU Embedding `FeatureSpec`s to Keras RS `FeatureConfig`s.
+
+     Args:
+         feature_specs: JAX TPU Embedding feature specifications.
+
+     Returns:
+         Keras RS feature configurations.
+     """
+     table_config_map: dict[str, config.TableConfig] = {}
+     output: Nested[config.FeatureConfig] = keras.tree.map_structure(
+         lambda feature_spec: _jte_to_keras_feature_config(
+             feature_spec, table_config_map
+         ),
+         feature_specs,
+     )
+     return output