keras-rs-nightly 0.2.2.dev202506160338__py3-none-any.whl → 0.2.2.dev202506180335__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of keras-rs-nightly might be problematic. Click here for more details.

@@ -174,12 +174,9 @@ class DistributedEmbedding(keras.layers.Layer):
174
174
  supported on all backends and accelerators:
175
175
 
176
176
  - `keras.optimizers.Adagrad`
177
- - `keras.optimizers.SGD`
178
-
179
- The following are additionally available when using the TensorFlow backend:
180
-
181
177
  - `keras.optimizers.Adam`
182
178
  - `keras.optimizers.Ftrl`
179
+ - `keras.optimizers.SGD`
183
180
 
184
181
  Also, not all parameters of the optimizers are supported (e.g. the
185
182
  `nesterov` option of `SGD`). An error is raised when an unsupported
@@ -0,0 +1,104 @@
1
+ """A Wrapper over orbax CheckpointManager for Keras3 Jax TPU Embeddings."""
2
+
3
+ from typing import Any
4
+
5
+ import keras
6
+ import orbax.checkpoint as ocp
7
+ from etils import epath
8
+
9
+
10
+ class JaxKeras3CheckpointManager(ocp.CheckpointManager):
11
+ """A wrapper over orbax CheckpointManager for Keras3 Jax TPU Embeddings."""
12
+
13
+ def __init__(
14
+ self,
15
+ model: keras.Model,
16
+ checkpoint_dir: epath.PathLike,
17
+ max_to_keep: int,
18
+ steps_per_epoch: int = 1,
19
+ **kwargs: Any,
20
+ ):
21
+ options = ocp.CheckpointManagerOptions(
22
+ max_to_keep=max_to_keep, enable_async_checkpointing=False, **kwargs
23
+ )
24
+ self._model = model
25
+ self._steps_per_epoch = steps_per_epoch
26
+ self._checkpoint_dir = checkpoint_dir
27
+ super().__init__(checkpoint_dir, options=options)
28
+
29
+ def _get_state(self) -> tuple[dict[str, Any], Any | None]:
30
+ """Gets the model state and metrics"""
31
+ model_state = self._model.get_state_tree()
32
+ state = {}
33
+ metrics = None
34
+ for k, v in model_state.items():
35
+ if k == "metrics_variables":
36
+ metrics = v
37
+ else:
38
+ state[k] = v
39
+ return state, metrics
40
+
41
+ def save_state(self, epoch: int) -> None:
42
+ """Saves the model to the checkpoint directory.
43
+
44
+ Args:
45
+ epoch: The epoch number at which the state is saved.
46
+ """
47
+ state, metrics_value = self._get_state()
48
+ self.save(
49
+ epoch * self._steps_per_epoch,
50
+ args=ocp.args.StandardSave(item=state),
51
+ metrics=metrics_value,
52
+ )
53
+
54
+ def restore_state(self, step: int | None = None) -> None:
55
+ """Restores the model from the checkpoint directory.
56
+
57
+ Args:
58
+ step: The step number to restore the state from. Default=None
59
+ restores the latest step.
60
+ """
61
+ if step is None:
62
+ step = self.latest_step()
63
+ # Restore the model state only, not metrics.
64
+ state, _ = self._get_state()
65
+ restored_state = self.restore(
66
+ step, args=ocp.args.StandardRestore(item=state)
67
+ )
68
+ self._model.set_state_tree(restored_state)
69
+
70
+
71
+ class JaxKeras3CheckpointCallback(keras.callbacks.Callback):
72
+ """A callback for checkpointing and restoring state using Orbax."""
73
+
74
+ def __init__(
75
+ self,
76
+ model: keras.Model,
77
+ checkpoint_dir: epath.PathLike,
78
+ max_to_keep: int,
79
+ steps_per_epoch: int = 1,
80
+ **kwargs: Any,
81
+ ):
82
+ if keras.backend.backend() != "jax":
83
+ raise ValueError(
84
+ "`JaxKeras3CheckpointCallback` is only supported on a "
85
+ "`jax` backend."
86
+ )
87
+ self._checkpoint_manager = JaxKeras3CheckpointManager(
88
+ model, checkpoint_dir, max_to_keep, steps_per_epoch, **kwargs
89
+ )
90
+
91
+ def on_train_begin(self, logs: dict[str, Any] | None = None) -> None:
92
+ if not self.model.built or not self.model.optimizer.built:
93
+ raise ValueError(
94
+ "To use `JaxKeras3CheckpointCallback`, your model and "
95
+ "optimizer must be built before you call `fit()`."
96
+ )
97
+ latest_epoch = self._checkpoint_manager.latest_step()
98
+ if latest_epoch is not None:
99
+ self._checkpoint_manager.restore_state(step=latest_epoch)
100
+
101
+ def on_epoch_end(
102
+ self, epoch: int, logs: dict[str, Any] | None = None
103
+ ) -> None:
104
+ self._checkpoint_manager.save_state(epoch)
@@ -229,18 +229,63 @@ def keras_to_jte_optimizer(
229
229
  # pylint: disable-next=protected-access
230
230
  learning_rate = keras_to_jte_learning_rate(optimizer._learning_rate)
231
231
 
232
- # SGD or Adagrad
232
+ # Unsupported keras optimizer general options.
233
+ if optimizer.clipnorm is not None:
234
+ raise ValueError("Unsupported optimizer option `clipnorm`.")
235
+ if optimizer.global_clipnorm is not None:
236
+ raise ValueError("Unsupported optimizer option `global_clipnorm`.")
237
+ if optimizer.use_ema:
238
+ raise ValueError("Unsupported optimizer option `use_ema`.")
239
+ if optimizer.loss_scale_factor is not None:
240
+ raise ValueError("Unsupported optimizer option `loss_scale_factor`.")
241
+
242
+ # Supported optimizers.
233
243
  if isinstance(optimizer, keras.optimizers.SGD):
244
+ if getattr(optimizer, "nesterov", False):
245
+ raise ValueError("Unsupported optimizer option `nesterov`.")
246
+ if getattr(optimizer, "momentum", 0.0) != 0.0:
247
+ raise ValueError("Unsupported optimizer option `momentum`.")
234
248
  return embedding_spec.SGDOptimizerSpec(learning_rate=learning_rate)
235
249
  elif isinstance(optimizer, keras.optimizers.Adagrad):
250
+ if getattr(optimizer, "epsilon", 1e-7) != 1e-7:
251
+ raise ValueError("Unsupported optimizer option `epsilon`.")
236
252
  return embedding_spec.AdagradOptimizerSpec(
237
253
  learning_rate=learning_rate,
238
254
  initial_accumulator_value=optimizer.initial_accumulator_value,
239
255
  )
256
+ elif isinstance(optimizer, keras.optimizers.Adam):
257
+ if getattr(optimizer, "amsgrad", False):
258
+ raise ValueError("Unsupported optimizer option `amsgrad`.")
240
259
 
241
- # Default to SGD for now, since other optimizers are still being created,
242
- # and we don't want to fail.
243
- return embedding_spec.SGDOptimizerSpec(learning_rate=learning_rate)
260
+ return embedding_spec.AdamOptimizerSpec(
261
+ learning_rate=learning_rate,
262
+ beta_1=optimizer.beta_1,
263
+ beta_2=optimizer.beta_2,
264
+ epsilon=optimizer.epsilon,
265
+ )
266
+ elif isinstance(optimizer, keras.optimizers.Ftrl):
267
+ if (
268
+ getattr(optimizer, "l2_shrinkage_regularization_strength", 0.0)
269
+ != 0.0
270
+ ):
271
+ raise ValueError(
272
+ "Unsupported optimizer option "
273
+ "`l2_shrinkage_regularization_strength`."
274
+ )
275
+
276
+ return embedding_spec.FTRLOptimizerSpec(
277
+ learning_rate=learning_rate,
278
+ learning_rate_power=optimizer.learning_rate_power,
279
+ l1_regularization_strength=optimizer.l1_regularization_strength,
280
+ l2_regularization_strength=optimizer.l2_regularization_strength,
281
+ beta=optimizer.beta,
282
+ initial_accumulator_value=optimizer.initial_accumulator_value,
283
+ )
284
+
285
+ raise ValueError(
286
+ f"Unsupported optimizer type {type(optimizer)}. Optimizer must be "
287
+ f"one of [Adagrad, Adam, Ftrl, SGD]."
288
+ )
244
289
 
245
290
 
246
291
  def jte_to_keras_optimizer(
@@ -262,8 +307,33 @@ def jte_to_keras_optimizer(
262
307
  learning_rate=learning_rate,
263
308
  initial_accumulator_value=optimizer.initial_accumulator_value,
264
309
  )
310
+ elif isinstance(optimizer, embedding_spec.AdamOptimizerSpec):
311
+ return keras.optimizers.Adam(
312
+ learning_rate=learning_rate,
313
+ beta_1=optimizer.beta_1,
314
+ beta_2=optimizer.beta_2,
315
+ epsilon=optimizer.epsilon,
316
+ )
317
+ elif isinstance(optimizer, embedding_spec.FTRLOptimizerSpec):
318
+ if getattr(optimizer, "initial_linear_value", 0.0) != 0.0:
319
+ raise ValueError(
320
+ "Unsupported optimizer option `initial_linear_value`."
321
+ )
322
+ if getattr(optimizer, "multiply_linear_by_learning_rate", False):
323
+ raise ValueError(
324
+ "Unsupported optimizer option "
325
+ "`multiply_linear_by_learning_rate`."
326
+ )
327
+ return keras.optimizers.Ftrl(
328
+ learning_rate=learning_rate,
329
+ learning_rate_power=optimizer.learning_rate_power,
330
+ initial_accumulator_value=optimizer.initial_accumulator_value,
331
+ l1_regularization_strength=optimizer.l1_regularization_strength,
332
+ l2_regularization_strength=optimizer.l2_regularization_strength,
333
+ beta=optimizer.beta,
334
+ )
265
335
 
266
- raise ValueError(f"Unknown optimizer spec {optimizer}")
336
+ raise ValueError(f"Unknown optimizer spec {type(optimizer)}.")
267
337
 
268
338
 
269
339
  def _keras_to_jte_table_config(
keras_rs/src/version.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from keras_rs.src.api_export import keras_rs_export
2
2
 
3
3
  # Unique source of truth for the version number.
4
- __version__ = "0.2.2.dev202506160338"
4
+ __version__ = "0.2.2.dev202506180335"
5
5
 
6
6
 
7
7
  @keras_rs_export("keras_rs.version")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: keras-rs-nightly
3
- Version: 0.2.2.dev202506160338
3
+ Version: 0.2.2.dev202506180335
4
4
  Summary: Multi-backend recommender systems with Keras 3.
5
5
  Author-email: Keras team <keras-users@googlegroups.com>
6
6
  License: Apache License 2.0
@@ -5,15 +5,16 @@ keras_rs/metrics/__init__.py,sha256=Qxpf6OFooIL9TIn2l3WgOea3HFRG0hq02glPAxtMZ9c,
5
5
  keras_rs/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  keras_rs/src/api_export.py,sha256=RsmG-DvO-cdFeAF9W6LRzms0kvtm-Yp9BAA_d-952zI,510
7
7
  keras_rs/src/types.py,sha256=1A-oLRdX1-f2DsVZBcNl8qNsaH8pM-gnleLT9FWZWBw,1189
8
- keras_rs/src/version.py,sha256=ua2Hsp6lXwvYzm-yqp78ielSnR4SRHWXYVwyaSH_nj8,224
8
+ keras_rs/src/version.py,sha256=LNFzWjtxjRhe0TX6c7WYi7sKOVa0qIKYwIc1tSfziS4,224
9
9
  keras_rs/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
10
  keras_rs/src/layers/embedding/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- keras_rs/src/layers/embedding/base_distributed_embedding.py,sha256=dUZ4eS6ktnbnw_Z5gbyZGpQqO44Oyi7DkpNCReL66No,44347
11
+ keras_rs/src/layers/embedding/base_distributed_embedding.py,sha256=11GicbB6m0wsHJQXISp6lcUyACVVYFLFerluUJUjDFA,44265
12
12
  keras_rs/src/layers/embedding/distributed_embedding.py,sha256=94jxUHoGK3Gs9yfV0KxFTuqPo7XFnhgCNlO2FEeiSgM,1072
13
13
  keras_rs/src/layers/embedding/distributed_embedding_config.py,sha256=AWPmZBir1shhqNP6U_jiQ9lsBhMXVikW4B5VnzLsvPg,5579
14
14
  keras_rs/src/layers/embedding/embed_reduce.py,sha256=c-MnEw1-KWs0jTf0JJ_ZBOY-9hRkiFyu989Dof3DnS8,12343
15
15
  keras_rs/src/layers/embedding/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
- keras_rs/src/layers/embedding/jax/config_conversion.py,sha256=kDgzab8AVYf4jd_8fsiycPA0oFnT83kSWx-TXhzy6sk,13590
16
+ keras_rs/src/layers/embedding/jax/checkpoint_utils.py,sha256=wZ4I5WZVNg5WnrD2j7nhAXgLzDc7xMrUEkSAOx5Sz5c,3495
17
+ keras_rs/src/layers/embedding/jax/config_conversion.py,sha256=Di1UzRwLgGHd7RuWYJMj2mCOr1u9MseFEWaYKnwD9Bs,16742
17
18
  keras_rs/src/layers/embedding/jax/distributed_embedding.py,sha256=s_V2h8smO6_Nd3lQfp6zqNi9XxXIn9wjnggSedRoE8E,35410
18
19
  keras_rs/src/layers/embedding/jax/embedding_lookup.py,sha256=HFkc0pGB9JngnCtbEJE2gDxC2K4gDdQ6GpnatSdnW6s,8205
19
20
  keras_rs/src/layers/embedding/jax/embedding_utils.py,sha256=EHrQjPLl94STLWf9g8Ew8nuwupXRq-a_QmvFlXV6G6A,20331
@@ -49,7 +50,7 @@ keras_rs/src/metrics/utils.py,sha256=fGTo8j0ykVE5Y3yQCS2orSFcHY20Uxt0NazyPsybUsw
49
50
  keras_rs/src/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
50
51
  keras_rs/src/utils/doc_string_utils.py,sha256=CmqomepmaYcvpACpXEXkrJb8DMnvIgmYK-lJ53lYarY,1675
51
52
  keras_rs/src/utils/keras_utils.py,sha256=dc-NFzs3a-qmRw0vBDiMslPLfrm9yymGduLWesXPhuY,2123
52
- keras_rs_nightly-0.2.2.dev202506160338.dist-info/METADATA,sha256=a9xMKY-hqmkOw59DRdW4M7l6iXVR4GgRdY5qACQsXvE,5273
53
- keras_rs_nightly-0.2.2.dev202506160338.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
54
- keras_rs_nightly-0.2.2.dev202506160338.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
55
- keras_rs_nightly-0.2.2.dev202506160338.dist-info/RECORD,,
53
+ keras_rs_nightly-0.2.2.dev202506180335.dist-info/METADATA,sha256=yRFTgK31VJDFNIHnptUz6hmDK7NPyXdhXptXrvFN9_M,5273
54
+ keras_rs_nightly-0.2.2.dev202506180335.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
55
+ keras_rs_nightly-0.2.2.dev202506180335.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
56
+ keras_rs_nightly-0.2.2.dev202506180335.dist-info/RECORD,,