keras-rs-nightly 0.0.1.dev2025021903__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. keras_rs/__init__.py +9 -28
  2. keras_rs/layers/__init__.py +37 -0
  3. keras_rs/losses/__init__.py +19 -0
  4. keras_rs/metrics/__init__.py +16 -0
  5. keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
  6. keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
  7. keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
  8. keras_rs/src/layers/embedding/embed_reduce.py +309 -0
  9. keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  10. keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
  11. keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
  12. keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
  13. keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
  14. keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
  15. keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  16. keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
  17. keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
  18. keras_rs/src/layers/feature_interaction/__init__.py +0 -0
  19. keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
  20. keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
  21. keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
  22. keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
  23. keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
  24. keras_rs/src/layers/retrieval/retrieval.py +127 -0
  25. keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
  26. keras_rs/src/losses/__init__.py +0 -0
  27. keras_rs/src/losses/list_mle_loss.py +212 -0
  28. keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
  29. keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
  30. keras_rs/src/losses/pairwise_loss.py +165 -0
  31. keras_rs/src/losses/pairwise_loss_utils.py +39 -0
  32. keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
  33. keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
  34. keras_rs/src/metrics/__init__.py +0 -0
  35. keras_rs/src/metrics/dcg.py +161 -0
  36. keras_rs/src/metrics/mean_average_precision.py +130 -0
  37. keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
  38. keras_rs/src/metrics/ndcg.py +197 -0
  39. keras_rs/src/metrics/precision_at_k.py +117 -0
  40. keras_rs/src/metrics/ranking_metric.py +260 -0
  41. keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
  42. keras_rs/src/metrics/recall_at_k.py +108 -0
  43. keras_rs/src/metrics/utils.py +70 -0
  44. keras_rs/src/types.py +43 -14
  45. keras_rs/src/utils/doc_string_utils.py +53 -0
  46. keras_rs/src/utils/keras_utils.py +52 -3
  47. keras_rs/src/utils/tpu_test_utils.py +120 -0
  48. keras_rs/src/version.py +1 -1
  49. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
  50. keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
  51. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
  52. keras_rs/api/__init__.py +0 -9
  53. keras_rs/api/layers/__init__.py +0 -11
  54. keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
  55. /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
  56. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
keras_rs/src/layers/embedding/tensorflow/config_conversion.py (new file)
@@ -0,0 +1,363 @@
+ import collections
+ from typing import Any, Sequence, TypeAlias
+
+ import keras
+ import tensorflow as tf
+
+ from keras_rs.src import types
+ from keras_rs.src.layers.embedding import distributed_embedding_config
+
+ FeatureConfig = distributed_embedding_config.FeatureConfig
+ TableConfig = distributed_embedding_config.TableConfig
+
+ # Placeholder for tf.tpu.experimental.embedding._Optimizer, which is not exposed.
+ TfTpuOptimizer: TypeAlias = Any
+
+
+ OptimizerMapping = collections.namedtuple(
+     "OptimizerMapping",
+     ["tpu_optimizer_class", "supported_kwargs", "unsupported_kwargs"],
+ )
+
+
+ OPTIMIZER_MAPPINGS = {
+     keras.optimizers.Adagrad: OptimizerMapping(
+         tpu_optimizer_class=tf.tpu.experimental.embedding.Adagrad,
+         supported_kwargs=["initial_accumulator_value"],
+         unsupported_kwargs={"epsilon": 1e-07},
+     ),
+     keras.optimizers.Adam: OptimizerMapping(
+         tpu_optimizer_class=tf.tpu.experimental.embedding.Adam,
+         supported_kwargs=["beta_1", "beta_2", "epsilon"],
+         unsupported_kwargs={"amsgrad": False},
+     ),
+     keras.optimizers.Ftrl: OptimizerMapping(
+         tpu_optimizer_class=tf.tpu.experimental.embedding.FTRL,
+         supported_kwargs=[
+             "learning_rate_power",
+             "initial_accumulator_value",
+             "l1_regularization_strength",
+             "l2_regularization_strength",
+             "beta",
+         ],
+         unsupported_kwargs={"l2_shrinkage_regularization_strength": 0.0},
+     ),
+     keras.optimizers.SGD: OptimizerMapping(
+         tpu_optimizer_class=tf.tpu.experimental.embedding.SGD,
+         supported_kwargs=[],
+         unsupported_kwargs={"momentum": 0.0, "nesterov": False},
+     ),
+ }
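Each `OptimizerMapping` entry records which Keras constructor arguments carry over to the TPU embedding optimizer (`supported_kwargs`) and which must stay at the listed default because they have no TPU equivalent (`unsupported_kwargs`). A minimal sketch of what this means in practice, using the Adam row (the conversion function itself, `keras_to_tf_tpu_optimizer`, is defined further down in this file):

    # beta_1 is in supported_kwargs, so it is forwarded to the TPU optimizer.
    keras_to_tf_tpu_optimizer(keras.optimizers.Adam(beta_1=0.85))   # OK
    # amsgrad must remain at its default (False); anything else raises.
    keras_to_tf_tpu_optimizer(keras.optimizers.Adam(amsgrad=True))  # ValueError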
+
+
+ # KerasRS to TensorFlow
+
+
+ def keras_to_tf_tpu_configuration(
+     feature_configs: types.Nested[FeatureConfig],
+     table_stacking: str | Sequence[str] | Sequence[Sequence[str]] | None,
+     num_replicas_in_sync: int,
+ ) -> tuple[
+     types.Nested[tf.tpu.experimental.embedding.FeatureConfig],
+     tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig,
+ ]:
+     """Translates a Keras RS configuration to a TensorFlow TPU configuration.
+
+     Args:
+         feature_configs: The nested Keras RS feature configs.
+         table_stacking: The Keras RS table stacking.
+         num_replicas_in_sync: The number of replicas in sync from the strategy.
+
+     Returns:
+         A tuple containing the TensorFlow TPU feature configs and the
+         TensorFlow TPU sparse core embedding config.
+     """
+     tables: dict[int, tf.tpu.experimental.embedding.TableConfig] = {}
+     feature_configs = keras.tree.map_structure(
+         lambda f: keras_to_tf_tpu_feature_config(
+             f, tables, num_replicas_in_sync
+         ),
+         feature_configs,
+     )
+
+     # max_ids_per_chip_per_sample
+     # max_ids_per_table
+     # max_unique_ids_per_table
+
+     if table_stacking is None:
+         disable_table_stacking = True
+     elif table_stacking == "auto":
+         disable_table_stacking = False
+     else:
+         raise ValueError(
+             f"Unsupported table stacking for TensorFlow {table_stacking}, "
+             "must be 'auto' or None."
+         )
+
+     # Find alternative.
+     # `initialize_tables_on_host` is set to False. Otherwise, if the
+     # `TPUEmbedding` layer is built within Keras' `compute_output_spec`
+     # (meaning within `call`), the tables are created within a `FuncGraph`
+     # and the resulting tables are destroyed at the end of it.
+     sparse_core_embedding_config = (
+         tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig(
+             disable_table_stacking=disable_table_stacking,
+             initialize_tables_on_host=False,
+         )
+     )
+
+     return feature_configs, sparse_core_embedding_config
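A minimal usage sketch for this entry point, assuming `FeatureConfig` and `TableConfig` accept their attributes as keyword arguments (only the attribute accesses, not the constructors, are visible in this file):

    table = TableConfig(
        name="movie_id",
        vocabulary_size=100_000,
        embedding_dim=32,
        optimizer="adagrad",
    )
    feature_configs = {
        "movie_id": FeatureConfig(
            name="movie_id",
            table=table,
            output_shape=(256, 32),  # (global batch size, embedding dim)
        )
    }
    tf_configs, sparse_core_config = keras_to_tf_tpu_configuration(
        feature_configs, table_stacking="auto", num_replicas_in_sync=8
    )
    # The resulting TF feature config has output_shape [32]: the trailing
    # embedding dimension is dropped and the batch of 256 is divided
    # across the 8 replicas.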
+
+
+ def keras_to_tf_tpu_feature_config(
+     feature_config: FeatureConfig,
+     tables: dict[int, tf.tpu.experimental.embedding.TableConfig],
+     num_replicas_in_sync: int,
+ ) -> tf.tpu.experimental.embedding.FeatureConfig:
+     """Translates a Keras RS feature config to a TensorFlow TPU feature config.
+
+     This creates the table config and adds it to the `tables` mapping if it
+     is not already there.
+
+     Args:
+         feature_config: The Keras RS feature config to translate.
+         tables: A mapping of Keras RS table config ids to TF TPU table
+             configs.
+         num_replicas_in_sync: The number of replicas in sync from the
+             strategy.
+
+     Returns:
+         The TensorFlow TPU feature config.
+     """
+     if num_replicas_in_sync <= 0:
+         raise ValueError(
+             "`num_replicas_in_sync` must be positive, "
+             f"but got {num_replicas_in_sync}."
+         )
+
+     table = tables.get(id(feature_config.table), None)
+     if table is None:
+         table = keras_to_tf_tpu_table_config(feature_config.table)
+         tables[id(feature_config.table)] = table
+
+     if len(feature_config.output_shape) < 2:
+         raise ValueError(
+             f"Invalid `output_shape` {feature_config.output_shape} in "
+             f"`FeatureConfig` {feature_config}. It must have at least 2 "
+             "dimensions: a batch dimension and an embedding dimension."
+         )
+
+     # Exclude the last dimension; TensorFlow's TPUEmbedding doesn't want it.
+     output_shape = list(feature_config.output_shape[0:-1])
+
+     batch_size = output_shape[0]
+     per_replica_batch_size: int | None = None
+     if batch_size is not None:
+         if batch_size % num_replicas_in_sync != 0:
+             raise ValueError(
+                 f"Invalid `output_shape` {feature_config.output_shape} in "
+                 f"`FeatureConfig` {feature_config}. Batch size {batch_size} "
+                 "is not a multiple of the number of TPUs "
+                 f"{num_replicas_in_sync}."
+             )
+         per_replica_batch_size = batch_size // num_replicas_in_sync
+
+     # TensorFlow's TPUEmbedding wants the per-replica batch size.
+     output_shape = [per_replica_batch_size] + output_shape[1:]
+
+     # max_sequence_length
+     return tf.tpu.experimental.embedding.FeatureConfig(
+         name=feature_config.name,
+         table=table,
+         output_shape=output_shape,
+     )
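To make the shape handling concrete, a worked example with a sequence dimension, reusing the hypothetical `table` (embedding dim 32) from the earlier sketch:

    fc = FeatureConfig(
        name="history", table=table, output_shape=(64, 10, 32)
    )
    tf_fc = keras_to_tf_tpu_feature_config(fc, {}, num_replicas_in_sync=8)
    # tf_fc.output_shape is [8, 10]: the trailing embedding dimension 32
    # is dropped, and the batch dimension becomes 64 // 8 = 8 per replica;
    # middle dimensions pass through unchanged.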
+
+
+ def keras_to_tf_tpu_table_config(
+     table_config: TableConfig,
+ ) -> tf.tpu.experimental.embedding.TableConfig:
+     """Translates a Keras RS table config to a TensorFlow TPU table config."""
+     initializer = table_config.initializer
+     if isinstance(initializer, str):
+         initializer = keras.initializers.get(initializer)
+
+     return tf.tpu.experimental.embedding.TableConfig(
+         vocabulary_size=table_config.vocabulary_size,
+         dim=table_config.embedding_dim,
+         initializer=initializer,
+         optimizer=to_tf_tpu_optimizer(table_config.optimizer),
+         combiner=table_config.combiner,
+         name=table_config.name,
+     )
+
+
+ def keras_to_tf_tpu_optimizer(
+     optimizer: keras.optimizers.Optimizer,
+ ) -> TfTpuOptimizer:
+     """Translates a Keras optimizer to a TensorFlow TPU `_Optimizer`.
+
+     Args:
+         optimizer: The Keras optimizer to translate.
+
+     Returns:
+         The TensorFlow TPU `_Optimizer`.
+     """
+     tpu_optimizer_kwargs: dict[str, Any] = {}
+
+     # Supported Keras optimizer general options.
+     learning_rate = optimizer._learning_rate  # pylint: disable=protected-access
+     if isinstance(
+         learning_rate, keras.optimizers.schedules.LearningRateSchedule
+     ):
+         # Note: the learning rate requires incrementing `iterations` in the
+         # optimizer.
+         tpu_optimizer_kwargs["learning_rate"] = lambda: optimizer.learning_rate
+     elif callable(learning_rate):
+         tpu_optimizer_kwargs["learning_rate"] = learning_rate
+     else:
+         learning_rate = optimizer.get_config()["learning_rate"]
+         if isinstance(learning_rate, float):
+             tpu_optimizer_kwargs["learning_rate"] = learning_rate
+         else:
+             raise ValueError(
+                 f"Unsupported learning rate: {learning_rate} of type"
+                 f" {type(learning_rate)}."
+             )
+
+     if optimizer.weight_decay is not None:
+         tpu_optimizer_kwargs["weight_decay_factor"] = optimizer.weight_decay
+     if optimizer.clipvalue is not None:
+         tpu_optimizer_kwargs["clipvalue"] = optimizer.clipvalue
+     if optimizer.gradient_accumulation_steps is not None:
+         tpu_optimizer_kwargs["use_gradient_accumulation"] = True
+
+     # Unsupported Keras optimizer general options.
+     if optimizer.clipnorm is not None:
+         raise ValueError("Unsupported optimizer option `Optimizer.clipnorm`.")
+     if optimizer.global_clipnorm is not None:
+         raise ValueError(
+             "Unsupported optimizer option `Optimizer.global_clipnorm`."
+         )
+     if optimizer.use_ema:
+         raise ValueError("Unsupported optimizer option `Optimizer.use_ema`.")
+     if optimizer.loss_scale_factor is not None:
+         raise ValueError(
+             "Unsupported optimizer option `Optimizer.loss_scale_factor`."
+         )
+
+     optimizer_mapping = None
+     for optimizer_class, mapping in OPTIMIZER_MAPPINGS.items():
+         # Handle subclasses of the main optimizer class.
+         if isinstance(optimizer, optimizer_class):
+             optimizer_mapping = mapping
+             break
+     if optimizer_mapping is None:
+         raise ValueError(
+             f"Unsupported optimizer type {type(optimizer)}. Optimizer must be "
+             f"one of {list(OPTIMIZER_MAPPINGS.keys())}."
+         )
+
+     for argname in optimizer_mapping.supported_kwargs:
+         tpu_optimizer_kwargs[argname] = getattr(optimizer, argname)
+
+     for argname, disabled_value in optimizer_mapping.unsupported_kwargs.items():
+         if disabled_value is None:
+             if getattr(optimizer, argname) is not None:
+                 raise ValueError(f"Unsupported optimizer option {argname}.")
+         elif getattr(optimizer, argname) != disabled_value:
+             raise ValueError(f"Unsupported optimizer option {argname}.")
+
+     return optimizer_mapping.tpu_optimizer_class(**tpu_optimizer_kwargs)
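The learning-rate handling above accepts a float, a zero-argument callable, or a Keras `LearningRateSchedule`; a schedule is wrapped in a lambda that re-reads `optimizer.learning_rate`, so the TPU embedding picks up the decayed value as `iterations` advances. A minimal sketch:

    schedule = keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=0.1, decay_steps=1000, decay_rate=0.9
    )
    tpu_opt = keras_to_tf_tpu_optimizer(keras.optimizers.Adagrad(schedule))
    # tpu_opt is a tf.tpu.experimental.embedding.Adagrad whose learning_rate
    # is a callable that tracks the schedule.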
+
+
+ def to_tf_tpu_optimizer(
+     optimizer: str | keras.optimizers.Optimizer | TfTpuOptimizer | None,
+ ) -> TfTpuOptimizer:
+     """Translates a Keras optimizer into a TensorFlow TPU `_Optimizer`.
+
+     Args:
+         optimizer: The optimizer to translate.
+
+     Returns:
+         The equivalent TensorFlow TPU `_Optimizer`.
+
+     Raises:
+         ValueError: If the optimizer or one of its arguments is not supported.
+     """
+     if optimizer is None:
+         return None
+     elif isinstance(
+         optimizer,
+         (
+             tf.tpu.experimental.embedding.SGD,
+             tf.tpu.experimental.embedding.Adagrad,
+             tf.tpu.experimental.embedding.Adam,
+             tf.tpu.experimental.embedding.FTRL,
+         ),
+     ):
+         return optimizer
+     elif isinstance(optimizer, str):
+         if optimizer == "sgd":
+             return tf.tpu.experimental.embedding.SGD()
+         elif optimizer == "adagrad":
+             return tf.tpu.experimental.embedding.Adagrad()
+         elif optimizer == "adam":
+             return tf.tpu.experimental.embedding.Adam()
+         elif optimizer == "ftrl":
+             return tf.tpu.experimental.embedding.FTRL()
+         else:
+             raise ValueError(
+                 f"Unknown optimizer name '{optimizer}'. Please use one of "
+                 "'sgd', 'adagrad', 'adam', or 'ftrl'."
+             )
+     elif isinstance(optimizer, keras.optimizers.Optimizer):
+         return keras_to_tf_tpu_optimizer(optimizer)
+     else:
+         raise ValueError(
+             f"Unknown optimizer type {type(optimizer)}. Please pass an "
+             "optimizer name as a string, an instance of a Keras optimizer "
+             "subclass, or an instance of one of the optimizer parameter "
+             "classes in `tf.tpu.experimental.embedding`."
+         )
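For reference, the four accepted input kinds in one place (a sketch):

    to_tf_tpu_optimizer("adam")   # fresh tf.tpu.experimental.embedding.Adam()
    to_tf_tpu_optimizer(None)     # passed through as None
    to_tf_tpu_optimizer(tf.tpu.experimental.embedding.SGD(0.01))  # returned as-is
    to_tf_tpu_optimizer(keras.optimizers.SGD())  # delegated to keras_to_tf_tpu_optimizer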
+
+
+ # TensorFlow to TensorFlow
+
+
+ def clone_tf_tpu_feature_configs(
+     feature_configs: types.Nested[tf.tpu.experimental.embedding.FeatureConfig],
+ ) -> types.Nested[tf.tpu.experimental.embedding.FeatureConfig]:
+     """Clones and resolves TensorFlow TPU feature configs.
+
+     This function clones the feature configs and resolves the table configs.
+
+     Args:
+         feature_configs: The TensorFlow TPU feature configs to clone and
+             resolve.
+
+     Returns:
+         The cloned and resolved TensorFlow TPU feature configs.
+     """
+     table_configs_dict = {}
+
+     def clone_and_resolve_tf_tpu_feature_config(
+         fc: tf.tpu.experimental.embedding.FeatureConfig,
+     ) -> tf.tpu.experimental.embedding.FeatureConfig:
+         if fc.table not in table_configs_dict:
+             table_configs_dict[fc.table] = (
+                 tf.tpu.experimental.embedding.TableConfig(
+                     vocabulary_size=fc.table.vocabulary_size,
+                     dim=fc.table.dim,
+                     initializer=fc.table.initializer,
+                     optimizer=to_tf_tpu_optimizer(fc.table.optimizer),
+                     combiner=fc.table.combiner,
+                     name=fc.table.name,
+                     quantization_config=fc.table.quantization_config,
+                     layout=fc.table.layout,
+                 )
+             )
+         return tf.tpu.experimental.embedding.FeatureConfig(
+             table=table_configs_dict[fc.table],
+             max_sequence_length=fc.max_sequence_length,
+             validate_weights_and_indices=fc.validate_weights_and_indices,
+             output_shape=fc.output_shape,
+             name=fc.name,
+         )
+
+     return keras.tree.map_structure(
+         clone_and_resolve_tf_tpu_feature_config, feature_configs
+     )
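A short sketch of the table de-duplication above: features that share a `TableConfig` object before cloning still share a single table afterwards, because `table_configs_dict` memoizes on the original table object. It assumes a TF version whose `TableConfig` exposes `quantization_config` and `layout`, as the code above requires:

    shared = tf.tpu.experimental.embedding.TableConfig(
        vocabulary_size=1000, dim=16, name="shared"
    )
    fcs = {
        "a": tf.tpu.experimental.embedding.FeatureConfig(table=shared, name="a"),
        "b": tf.tpu.experimental.embedding.FeatureConfig(table=shared, name="b"),
    }
    cloned = clone_tf_tpu_feature_configs(fcs)
    assert cloned["a"].table is cloned["b"].table  # still a single table
    assert cloned["a"].table is not shared         # but a new, resolved copy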