keras-rs-nightly 0.0.1.dev2025050103__py3-none-any.whl → 0.2.2.dev202506100336__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

This version of keras-rs-nightly has been flagged as potentially problematic.

Files changed (33)
  1. keras_rs/layers/__init__.py +12 -0
  2. keras_rs/src/layers/embedding/__init__.py +0 -0
  3. keras_rs/src/layers/embedding/base_distributed_embedding.py +1124 -0
  4. keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
  5. keras_rs/src/layers/embedding/distributed_embedding_config.py +129 -0
  6. keras_rs/src/layers/embedding/embed_reduce.py +309 -0
  7. keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  8. keras_rs/src/layers/embedding/jax/config_conversion.py +398 -0
  9. keras_rs/src/layers/embedding/jax/distributed_embedding.py +892 -0
  10. keras_rs/src/layers/embedding/jax/embedding_lookup.py +255 -0
  11. keras_rs/src/layers/embedding/jax/embedding_utils.py +596 -0
  12. keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  13. keras_rs/src/layers/embedding/tensorflow/config_conversion.py +323 -0
  14. keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +424 -0
  15. keras_rs/src/layers/feature_interaction/dot_interaction.py +2 -2
  16. keras_rs/src/layers/feature_interaction/feature_cross.py +14 -16
  17. keras_rs/src/layers/retrieval/brute_force_retrieval.py +5 -5
  18. keras_rs/src/layers/retrieval/retrieval.py +4 -4
  19. keras_rs/src/losses/pairwise_loss.py +2 -2
  20. keras_rs/src/losses/pairwise_mean_squared_error.py +1 -3
  21. keras_rs/src/metrics/dcg.py +2 -2
  22. keras_rs/src/metrics/ndcg.py +2 -2
  23. keras_rs/src/metrics/ranking_metric.py +4 -4
  24. keras_rs/src/metrics/ranking_metrics_utils.py +8 -8
  25. keras_rs/src/metrics/utils.py +2 -4
  26. keras_rs/src/types.py +43 -14
  27. keras_rs/src/utils/keras_utils.py +26 -6
  28. keras_rs/src/version.py +1 -1
  29. {keras_rs_nightly-0.0.1.dev2025050103.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/METADATA +6 -3
  30. keras_rs_nightly-0.2.2.dev202506100336.dist-info/RECORD +55 -0
  31. {keras_rs_nightly-0.0.1.dev2025050103.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/WHEEL +1 -1
  32. keras_rs_nightly-0.0.1.dev2025050103.dist-info/RECORD +0 -42
  33. {keras_rs_nightly-0.0.1.dev2025050103.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/top_level.txt +0 -0
keras_rs/src/layers/embedding/tensorflow/config_conversion.py
@@ -0,0 +1,323 @@
+ import collections
+ from typing import Any, Sequence, TypeAlias
+
+ import keras
+ import tensorflow as tf
+
+ from keras_rs.src import types
+ from keras_rs.src.layers.embedding import distributed_embedding_config
+
+ FeatureConfig = distributed_embedding_config.FeatureConfig
+ TableConfig = distributed_embedding_config.TableConfig
+
+ # Placeholder for tf.tpu.experimental.embedding._Optimizer, which is not exposed.
+ TfTpuOptimizer: TypeAlias = Any
+
+
+ OptimizerMapping = collections.namedtuple(
+     "OptimizerMapping",
+     ["tpu_optimizer_class", "supported_kwargs", "unsupported_kwargs"],
+ )
+
+
+ OPTIMIZER_MAPPINGS = {
+     keras.optimizers.Adagrad: OptimizerMapping(
+         tpu_optimizer_class=tf.tpu.experimental.embedding.Adagrad,
+         supported_kwargs=["initial_accumulator_value"],
+         unsupported_kwargs={"epsilon": 1e-07},
+     ),
+     keras.optimizers.Adam: OptimizerMapping(
+         tpu_optimizer_class=tf.tpu.experimental.embedding.Adam,
+         supported_kwargs=["beta_1", "beta_2", "epsilon"],
+         unsupported_kwargs={"amsgrad": False},
+     ),
+     keras.optimizers.Ftrl: OptimizerMapping(
+         tpu_optimizer_class=tf.tpu.experimental.embedding.FTRL,
+         supported_kwargs=[
+             "learning_rate_power",
+             "initial_accumulator_value",
+             "l1_regularization_strength",
+             "l2_regularization_strength",
+             "beta",
+         ],
+         unsupported_kwargs={"l2_shrinkage_regularization_strength": 0.0},
+     ),
+     keras.optimizers.SGD: OptimizerMapping(
+         tpu_optimizer_class=tf.tpu.experimental.embedding.SGD,
+         supported_kwargs=[],
+         unsupported_kwargs={"momentum": 0.0, "nesterov": False},
+     ),
+ }
+
+
+ # KerasRS to TensorFlow
+
+
+ def translate_keras_rs_configuration(
+     feature_configs: types.Nested[FeatureConfig],
+     table_stacking: str | Sequence[str] | Sequence[Sequence[str]] | None,
+ ) -> tuple[
+     types.Nested[tf.tpu.experimental.embedding.FeatureConfig],
+     tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig,
+ ]:
+     """Translates a Keras RS configuration to a TensorFlow TPU configuration.
+
+     Args:
+         feature_configs: The nested Keras RS feature configs.
+         table_stacking: The Keras RS table stacking; only `"auto"` and `None`
+             are supported on TensorFlow.
+
+     Returns:
+         A tuple containing the TensorFlow TPU feature configs and the
+         TensorFlow TPU sparse core embedding config.
+     """
+     tables: dict[TableConfig, tf.tpu.experimental.embedding.TableConfig] = {}
+     feature_configs = keras.tree.map_structure(
+         lambda f: translate_keras_rs_feature_config(f, tables), feature_configs
+     )
+
+     # max_ids_per_chip_per_sample
+     # max_ids_per_table
+     # max_unique_ids_per_table
+
+     if table_stacking is None:
+         disable_table_stacking = True
+     elif table_stacking == "auto":
+         disable_table_stacking = False
+     else:
+         raise ValueError(
+             f"Unsupported table stacking for TensorFlow: {table_stacking}. "
+             "Must be 'auto' or None."
+         )
+
+     # Find alternative.
+     # `initialize_tables_on_host` is set to False. Otherwise, if the
+     # `TPUEmbedding` layer is built within Keras' `compute_output_spec`
+     # (meaning within `call`), the tables are created within a `FuncGraph` and
+     # the resulting tables are destroyed at the end of it.
+     sparse_core_embedding_config = (
+         tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig(
+             disable_table_stacking=disable_table_stacking,
+             initialize_tables_on_host=False,
+         )
+     )
+
+     return feature_configs, sparse_core_embedding_config
+
+
+ def translate_keras_rs_feature_config(
+     feature_config: FeatureConfig,
+     tables: dict[TableConfig, tf.tpu.experimental.embedding.TableConfig],
+ ) -> tf.tpu.experimental.embedding.FeatureConfig:
+     """Translates a Keras RS feature config to a TensorFlow TPU feature config.
+
+     This creates the table config and adds it to the `tables` mapping if it
+     doesn't already exist there.
+
+     Args:
+         feature_config: The Keras RS feature config to translate.
+         tables: A mapping of Keras RS table configs to TF TPU table configs.
+
+     Returns:
+         The TensorFlow TPU feature config.
+     """
+     table = tables.get(feature_config.table, None)
+     if table is None:
+         table = translate_keras_rs_table_config(feature_config.table)
+         tables[feature_config.table] = table
+
+     # max_sequence_length
+     return tf.tpu.experimental.embedding.FeatureConfig(
+         name=feature_config.name,
+         table=table,
+         output_shape=feature_config.output_shape[:-1],  # exclude last dim
+     )
+
+
+ def translate_keras_rs_table_config(
+     table_config: TableConfig,
+ ) -> tf.tpu.experimental.embedding.TableConfig:
+     """Translates a Keras RS table config to a TensorFlow TPU table config."""
+     initializer = table_config.initializer
+     if isinstance(initializer, str):
+         initializer = keras.initializers.get(initializer)
+
+     return tf.tpu.experimental.embedding.TableConfig(
+         vocabulary_size=table_config.vocabulary_size,
+         dim=table_config.embedding_dim,
+         initializer=initializer,
+         optimizer=translate_optimizer(table_config.optimizer),
+         combiner=table_config.combiner,
+         name=table_config.name,
+     )
+
+
+ def translate_keras_optimizer(
+     optimizer: keras.optimizers.Optimizer,
+ ) -> TfTpuOptimizer:
+     """Translates a Keras optimizer to a TensorFlow TPU `_Optimizer`.
+
+     Args:
+         optimizer: The Keras optimizer to translate.
+
+     Returns:
+         The TensorFlow TPU `_Optimizer`.
+     """
+     tpu_optimizer_kwargs: dict[str, Any] = {}
+
+     # Supported Keras optimizer general options.
+     learning_rate = optimizer._learning_rate  # pylint: disable=protected-access
+     if isinstance(
+         learning_rate, keras.optimizers.schedules.LearningRateSchedule
+     ):
+         # Note: the learning rate requires incrementing iterations in the
+         # optimizer.
+         tpu_optimizer_kwargs["learning_rate"] = lambda: optimizer.learning_rate
+     elif callable(learning_rate):
+         tpu_optimizer_kwargs["learning_rate"] = learning_rate
+     else:
+         learning_rate = optimizer.get_config()["learning_rate"]
+         if isinstance(learning_rate, float):
+             tpu_optimizer_kwargs["learning_rate"] = learning_rate
+         else:
+             raise ValueError(
+                 f"Unsupported learning rate: {learning_rate} of type"
+                 f" {type(learning_rate)}."
+             )
+
+     if optimizer.weight_decay is not None:
+         tpu_optimizer_kwargs["weight_decay_factor"] = optimizer.weight_decay
+     if optimizer.clipvalue is not None:
+         tpu_optimizer_kwargs["clipvalue"] = optimizer.clipvalue
+     if optimizer.gradient_accumulation_steps is not None:
+         tpu_optimizer_kwargs["use_gradient_accumulation"] = True
+
+     # Unsupported Keras optimizer general options.
+     if optimizer.clipnorm is not None:
+         raise ValueError("Unsupported optimizer option `Optimizer.clipnorm`.")
+     if optimizer.global_clipnorm is not None:
+         raise ValueError(
+             "Unsupported optimizer option `Optimizer.global_clipnorm`."
+         )
+     if optimizer.use_ema:
+         raise ValueError("Unsupported optimizer option `Optimizer.use_ema`.")
+     if optimizer.loss_scale_factor is not None:
+         raise ValueError(
+             "Unsupported optimizer option `Optimizer.loss_scale_factor`."
+         )
+
+     optimizer_mapping = OPTIMIZER_MAPPINGS.get(type(optimizer), None)
+     if optimizer_mapping is None:
+         raise ValueError(
+             f"Unsupported optimizer type {type(optimizer)}. Optimizer must be "
+             f"one of {list(OPTIMIZER_MAPPINGS.keys())}."
+         )
+
+     for argname in optimizer_mapping.supported_kwargs:
+         tpu_optimizer_kwargs[argname] = getattr(optimizer, argname)
+
+     for argname, disabled_value in optimizer_mapping.unsupported_kwargs.items():
+         if disabled_value is None:
+             if getattr(optimizer, argname) is not None:
+                 raise ValueError(f"Unsupported optimizer option {argname}.")
+         elif getattr(optimizer, argname) != disabled_value:
+             raise ValueError(f"Unsupported optimizer option {argname}.")
+
+     return optimizer_mapping.tpu_optimizer_class(**tpu_optimizer_kwargs)
+
+
+ def translate_optimizer(
+     optimizer: str | keras.optimizers.Optimizer | TfTpuOptimizer | None,
+ ) -> TfTpuOptimizer:
+     """Translates a Keras optimizer into a TensorFlow TPU `_Optimizer`.
+
+     Args:
+         optimizer: The optimizer to translate.
+
+     Returns:
+         The equivalent TensorFlow TPU `_Optimizer`.
+
+     Raises:
+         ValueError: If the optimizer or one of its arguments is not supported.
+     """
+     if optimizer is None:
+         return None
+     elif isinstance(
+         optimizer,
+         (
+             tf.tpu.experimental.embedding.SGD,
+             tf.tpu.experimental.embedding.Adagrad,
+             tf.tpu.experimental.embedding.Adam,
+             tf.tpu.experimental.embedding.FTRL,
+         ),
+     ):
+         return optimizer
+     elif isinstance(optimizer, str):
+         if optimizer == "sgd":
+             return tf.tpu.experimental.embedding.SGD()
+         elif optimizer == "adagrad":
+             return tf.tpu.experimental.embedding.Adagrad()
+         elif optimizer == "adam":
+             return tf.tpu.experimental.embedding.Adam()
+         elif optimizer == "ftrl":
+             return tf.tpu.experimental.embedding.FTRL()
+         else:
+             raise ValueError(
+                 f"Unknown optimizer name '{optimizer}'. Please use one of "
+                 "'sgd', 'adagrad', 'adam', or 'ftrl'."
+             )
+     elif isinstance(optimizer, keras.optimizers.Optimizer):
+         return translate_keras_optimizer(optimizer)
+     else:
+         raise ValueError(
+             f"Unknown optimizer type {type(optimizer)}. Please pass an "
+             "optimizer name as a string, an instance of a Keras optimizer, or "
+             "an instance of one of the optimizer classes in "
+             "`tf.tpu.experimental.embedding`."
+         )
+
+
+ # TensorFlow to TensorFlow
+
+
+ def clone_tf_feature_configs(
+     feature_configs: types.Nested[tf.tpu.experimental.embedding.FeatureConfig],
+ ) -> types.Nested[tf.tpu.experimental.embedding.FeatureConfig]:
+     """Clones and resolves TensorFlow TPU feature configs.
+
+     This function clones the feature configs and resolves their table
+     configs, preserving tables that are shared between features.
+
+     Args:
+         feature_configs: The TensorFlow TPU feature configs to clone and
+             resolve.
+
+     Returns:
+         The cloned and resolved TensorFlow TPU feature configs.
+     """
+     table_configs_dict: dict[
+         tf.tpu.experimental.embedding.TableConfig,
+         tf.tpu.experimental.embedding.TableConfig,
+     ] = {}
+
+     def clone_and_resolve_tf_feature_config(
+         fc: tf.tpu.experimental.embedding.FeatureConfig,
+     ) -> tf.tpu.experimental.embedding.FeatureConfig:
+         if fc.table not in table_configs_dict:
+             table_configs_dict[fc.table] = (
+                 tf.tpu.experimental.embedding.TableConfig(
+                     vocabulary_size=fc.table.vocabulary_size,
+                     dim=fc.table.dim,
+                     initializer=fc.table.initializer,
+                     optimizer=translate_optimizer(fc.table.optimizer),
+                     combiner=fc.table.combiner,
+                     name=fc.table.name,
+                     quantization_config=fc.table.quantization_config,
+                     layout=fc.table.layout,
+                 )
+             )
+         return tf.tpu.experimental.embedding.FeatureConfig(
+             table=table_configs_dict[fc.table],
+             max_sequence_length=fc.max_sequence_length,
+             validate_weights_and_indices=fc.validate_weights_and_indices,
+             output_shape=fc.output_shape,
+             name=fc.name,
+         )
+
+     return keras.tree.map_structure(
+         clone_and_resolve_tf_feature_config, feature_configs
+     )
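
The diff above centers on optimizer translation. As a quick orientation (not part of the diff), here is a minimal usage sketch of `translate_optimizer`, assuming a TensorFlow build that exposes `tf.tpu.experimental.embedding`:

    import keras
    import tensorflow as tf

    from keras_rs.src.layers.embedding.tensorflow import config_conversion

    # A string name maps to a default TF TPU optimizer instance.
    opt = config_conversion.translate_optimizer("adagrad")
    assert isinstance(opt, tf.tpu.experimental.embedding.Adagrad)

    # A Keras optimizer instance is translated argument by argument via
    # OPTIMIZER_MAPPINGS (learning_rate, beta_1, beta_2, epsilon for Adam).
    opt = config_conversion.translate_optimizer(
        keras.optimizers.Adam(learning_rate=1e-3)
    )
    assert isinstance(opt, tf.tpu.experimental.embedding.Adam)

    # Options with no TPU-embedding equivalent are rejected: SGD momentum is
    # listed in unsupported_kwargs with a disabled value of 0.0.
    try:
        config_conversion.translate_optimizer(
            keras.optimizers.SGD(momentum=0.9)
        )
    except ValueError as err:
        print(err)  # Unsupported optimizer option momentum.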
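
Similarly, `translate_keras_rs_configuration` is the entry point that maps a nested Keras RS config to its TF TPU equivalents. A hedged sketch follows; the `TableConfig`/`FeatureConfig` keyword names are assumptions mirroring the attributes the translation code reads (`vocabulary_size`, `embedding_dim`, `initializer`, `optimizer`, `combiner`, `name`, `table`, `output_shape`), not a confirmed constructor signature:

    from keras_rs.src.layers.embedding import distributed_embedding_config
    from keras_rs.src.layers.embedding.tensorflow import config_conversion

    # Hypothetical table and feature configs for a single embedding feature.
    movie_table = distributed_embedding_config.TableConfig(
        name="movies",
        vocabulary_size=100_000,
        embedding_dim=32,
        initializer="uniform",
        optimizer="adagrad",
        combiner="mean",
    )
    feature_configs = {
        "movie_id": distributed_embedding_config.FeatureConfig(
            name="movie_id",
            table=movie_table,
            output_shape=(256, 32),  # the trailing embedding dim is dropped
        ),
    }

    # Returns TF TPU feature configs plus a SparseCoreEmbeddingConfig with
    # table stacking enabled ("auto") and initialize_tables_on_host=False.
    tf_feature_configs, sparse_core_config = (
        config_conversion.translate_keras_rs_configuration(
            feature_configs, table_stacking="auto"
        )
    )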
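
Finally, `clone_tf_feature_configs` preserves table sharing: two features that reference the same `TableConfig` object still share a single (new) table after cloning. A small sketch, assuming a TF version whose `TableConfig` carries the `quantization_config` and `layout` attributes the cloning code reads:

    import tensorflow as tf

    from keras_rs.src.layers.embedding.tensorflow import config_conversion

    shared_table = tf.tpu.experimental.embedding.TableConfig(
        vocabulary_size=1000, dim=16, name="shared"
    )
    configs = {
        "a": tf.tpu.experimental.embedding.FeatureConfig(
            table=shared_table, name="a"
        ),
        "b": tf.tpu.experimental.embedding.FeatureConfig(
            table=shared_table, name="b"
        ),
    }

    cloned = config_conversion.clone_tf_feature_configs(configs)

    # Both cloned features point at one new table; the original is untouched.
    assert cloned["a"].table is cloned["b"].table
    assert cloned["a"].table is not shared_table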