keras-rs-nightly 0.0.1.dev2025021903__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. keras_rs/__init__.py +9 -28
  2. keras_rs/layers/__init__.py +37 -0
  3. keras_rs/losses/__init__.py +19 -0
  4. keras_rs/metrics/__init__.py +16 -0
  5. keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
  6. keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
  7. keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
  8. keras_rs/src/layers/embedding/embed_reduce.py +309 -0
  9. keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  10. keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
  11. keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
  12. keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
  13. keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
  14. keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
  15. keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  16. keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
  17. keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
  18. keras_rs/src/layers/feature_interaction/__init__.py +0 -0
  19. keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
  20. keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
  21. keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
  22. keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
  23. keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
  24. keras_rs/src/layers/retrieval/retrieval.py +127 -0
  25. keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
  26. keras_rs/src/losses/__init__.py +0 -0
  27. keras_rs/src/losses/list_mle_loss.py +212 -0
  28. keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
  29. keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
  30. keras_rs/src/losses/pairwise_loss.py +165 -0
  31. keras_rs/src/losses/pairwise_loss_utils.py +39 -0
  32. keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
  33. keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
  34. keras_rs/src/metrics/__init__.py +0 -0
  35. keras_rs/src/metrics/dcg.py +161 -0
  36. keras_rs/src/metrics/mean_average_precision.py +130 -0
  37. keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
  38. keras_rs/src/metrics/ndcg.py +197 -0
  39. keras_rs/src/metrics/precision_at_k.py +117 -0
  40. keras_rs/src/metrics/ranking_metric.py +260 -0
  41. keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
  42. keras_rs/src/metrics/recall_at_k.py +108 -0
  43. keras_rs/src/metrics/utils.py +70 -0
  44. keras_rs/src/types.py +43 -14
  45. keras_rs/src/utils/doc_string_utils.py +53 -0
  46. keras_rs/src/utils/keras_utils.py +52 -3
  47. keras_rs/src/utils/tpu_test_utils.py +120 -0
  48. keras_rs/src/version.py +1 -1
  49. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
  50. keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
  51. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
  52. keras_rs/api/__init__.py +0 -9
  53. keras_rs/api/layers/__init__.py +0 -11
  54. keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
  55. /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
  56. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
keras_rs/src/layers/embedding/distributed_embedding.py
@@ -0,0 +1,33 @@
+ import importlib.util
+ import platform
+ import sys
+
+ import keras
+
+ from keras_rs.src.api_export import keras_rs_export
+
+ # JAX distributed embedding is only available on linux_x86_64, and only if
+ # jax-tpu-embedding is installed.
+ jax_tpu_embedding_spec = importlib.util.find_spec("jax_tpu_embedding")
+ if (
+     keras.backend.backend() == "jax"
+     and sys.platform == "linux"
+     and platform.machine().lower() == "x86_64"
+     and jax_tpu_embedding_spec is not None
+ ):
+     from keras_rs.src.layers.embedding.jax.distributed_embedding import (
+         DistributedEmbedding as BackendDistributedEmbedding,
+     )
+ elif keras.backend.backend() == "tensorflow":
+     from keras_rs.src.layers.embedding.tensorflow.distributed_embedding import (
+         DistributedEmbedding as BackendDistributedEmbedding,
+     )
+ else:
+     from keras_rs.src.layers.embedding.base_distributed_embedding import (
+         DistributedEmbedding as BackendDistributedEmbedding,
+     )
+
+
+ @keras_rs_export("keras_rs.layers.DistributedEmbedding")
+ class DistributedEmbedding(BackendDistributedEmbedding):
+     pass
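The backend-specific implementation is resolved once, at import time, so the exported `keras_rs.layers.DistributedEmbedding` class is a thin subclass of whichever backend class was selected above. A minimal sketch of inspecting that resolution (the public import path comes from the `keras_rs_export` call above; everything else is illustrative):

import keras
import keras_rs

# `DistributedEmbedding` subclasses the backend implementation, so its MRO
# shows which backend module was picked when `keras_rs` was imported.
layer_cls = keras_rs.layers.DistributedEmbedding
backend_parent = layer_cls.__mro__[1]
print(keras.backend.backend(), "->", backend_parent.__module__)
# JAX on linux x86_64 with jax-tpu-embedding installed selects the
# `...embedding.jax` module, TensorFlow selects `...embedding.tensorflow`,
# and everything else falls back to `...embedding.base_distributed_embedding`.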
keras_rs/src/layers/embedding/distributed_embedding_config.py
@@ -0,0 +1,132 @@
+ """Configuration for TPU embedding layer."""
+
+ import dataclasses
+ from typing import Any
+
+ import keras
+
+ from keras_rs.src import types
+ from keras_rs.src.api_export import keras_rs_export
+
+
+ @keras_rs_export("keras_rs.layers.TableConfig")
+ @dataclasses.dataclass(order=True)
+ class TableConfig:
+     """Configuration for one embedding table.
+
+     Configures one table for use by one or more `keras_rs.layers.FeatureConfig`,
+     which in turn is used to configure a `keras_rs.layers.DistributedEmbedding`.
+
+     Attributes:
+         name: The name of the table. Must be defined.
+         vocabulary_size: Size of the table's vocabulary (number of rows).
+         embedding_dim: The embedding dimension (width) of the table.
+         initializer: The initializer for the embedding weights. If not
+             specified, defaults to `truncated_normal_initializer` with mean
+             `0.0` and standard deviation `1 / sqrt(embedding_dim)`.
+         optimizer: The optimizer for the embedding table. Only SGD, Adagrad,
+             Adam, and FTRL are supported. Note that not all of the optimizer's
+             parameters are supported. Defaults to Adam.
+         combiner: Specifies how to reduce if there are multiple entries in a
+             single row. `mean`, `sqrtn` and `sum` are supported. `mean` is the
+             default. `sqrtn` often achieves good accuracy, in particular with
+             bag-of-words columns.
+         placement: Where to place the embedding table. `"auto"`, which is the
+             default, means that the table is placed on SparseCore if available,
+             otherwise on the default device where the rest of the model is
+             placed. A value of `"sparsecore"` means the table will be placed on
+             the SparseCore chips and an error is raised if SparseCore is not
+             available. A value of `"default_device"` means the table will be
+             placed on the default device where the rest of the model is placed,
+             even if SparseCore is available. The default device for the rest of
+             the model is the TPU's TensorCore on TPUs, otherwise the GPU or CPU.
+         max_ids_per_partition: The max number of ids per partition for the
+             table. This is an input data dependent value and is required by the
+             compiler to appropriately allocate memory.
+         max_unique_ids_per_partition: The max number of unique ids per partition
+             for the table. This is an input data dependent value and is required
+             by the compiler to appropriately allocate memory.
+     """
+
+     name: str
+     vocabulary_size: int
+     embedding_dim: int
+     initializer: str | keras.initializers.Initializer = (
+         keras.initializers.VarianceScaling(mode="fan_out")
+     )
+     optimizer: str | keras.optimizers.Optimizer = "adam"
+     combiner: str = "mean"
+     placement: str = "auto"
+     max_ids_per_partition: int = 256
+     max_unique_ids_per_partition: int = 256
+
+     def get_config(self) -> dict[str, Any]:
+         return {
+             "name": self.name,
+             "vocabulary_size": self.vocabulary_size,
+             "embedding_dim": self.embedding_dim,
+             "initializer": keras.saving.serialize_keras_object(
+                 self.initializer
+             ),
+             "optimizer": keras.saving.serialize_keras_object(self.optimizer),
+             "combiner": self.combiner,
+             "placement": self.placement,
+             "max_ids_per_partition": self.max_ids_per_partition,
+             "max_unique_ids_per_partition": self.max_unique_ids_per_partition,
+         }
+
+     @classmethod
+     def from_config(cls, config: dict[str, Any]) -> "TableConfig":
+         config = config.copy()
+         config["initializer"] = keras.saving.deserialize_keras_object(
+             config["initializer"]
+         )
+         config["optimizer"] = keras.saving.deserialize_keras_object(
+             config["optimizer"]
+         )
+         return cls(**config)
+
+
+ @keras_rs_export("keras_rs.layers.FeatureConfig")
+ @dataclasses.dataclass(order=True)
+ class FeatureConfig:
+     """Configuration for one embedding feature.
+
+     Configures one feature for `keras_rs.layers.DistributedEmbedding`. Each
+     feature uses a table configured via `keras_rs.layers.TableConfig` and
+     multiple features can share the same table.
+
+     Attributes:
+         name: The name of the feature. Must be defined.
+         table: The table in which to look up this feature.
+         input_shape: The input shape of the feature. The feature fed into the
+             layer has to match the shape. Note that for ragged dimensions in the
+             input, the dimension provided here represents the maximum value;
+             anything larger will be truncated. Also note that the first
+             dimension represents the global batch size. For example, on TPU,
+             this represents the total number of samples that are dispatched to
+             all the TPUs connected to the current host.
+         output_shape: The output shape of the feature activation. What is
+             returned by the embedding layer has to match this shape.
+     """
+
+     name: str
+     table: TableConfig
+     input_shape: types.Shape
+     output_shape: types.Shape
+
+     def get_config(self) -> dict[str, Any]:
+         return {
+             "name": self.name,
+             "table": self.table.get_config(),
+             "input_shape": self.input_shape,
+             "output_shape": self.output_shape,
+         }
+
+     @classmethod
+     def from_config(cls, config: dict[str, Any]) -> "FeatureConfig":
+         config = config.copy()
+         # Note: the handling of shared tables during serialization is done in
+         # `DistributedEmbedding.from_config()`.
+         config["table"] = TableConfig.from_config(config["table"])
+         return cls(**config)
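A short sketch of how these dataclasses compose, based on the fields documented above. The field values are made up, and the final line assumes that `DistributedEmbedding` accepts a nested structure of `FeatureConfig` objects, which is not shown in this excerpt (`base_distributed_embedding.py` is not reproduced here):

import keras_rs

# One table shared by two features; fields follow the dataclasses above.
movie_table = keras_rs.layers.TableConfig(
    name="movie_table",
    vocabulary_size=100_000,
    embedding_dim=32,
    optimizer="adam",
    combiner="mean",
    placement="auto",
)

feature_configs = {
    # Batch of single movie ids: shape (global_batch_size, 1).
    "watched_movie": keras_rs.layers.FeatureConfig(
        name="watched_movie",
        table=movie_table,
        input_shape=(256, 1),
        output_shape=(256, 32),
    ),
    # History of up to 10 movie ids, reduced to one embedding per row.
    "movie_history": keras_rs.layers.FeatureConfig(
        name="movie_history",
        table=movie_table,
        input_shape=(256, 10),
        output_shape=(256, 32),
    ),
}

# Assumed usage: the layer takes the feature config structure and returns
# activations with the same structure as its inputs.
embedding_layer = keras_rs.layers.DistributedEmbedding(feature_configs)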
keras_rs/src/layers/embedding/embed_reduce.py
@@ -0,0 +1,309 @@
+ from typing import Any
+
+ import keras
+ from keras import ops
+
+ from keras_rs.src import types
+ from keras_rs.src.api_export import keras_rs_export
+ from keras_rs.src.utils.keras_utils import check_shapes_compatible
+
+ SUPPORTED_COMBINERS = ("mean", "sum", "sqrtn")
+
+
+ def _is_supported_sparse(x: types.Tensor) -> bool:
+     """Determines if the input is a supported sparse tensor.
+
+     NOTE: Currently only works for the TensorFlow and JAX backends.
+
+     Args:
+         x: Input tensor to check for sparsity.
+
+     Returns:
+         True if `x` is a supported sparse tensor.
+     """
+     if keras.backend.backend() == "tensorflow":
+         import tensorflow as tf
+
+         return isinstance(x, tf.SparseTensor)
+     elif keras.backend.backend() == "jax":
+         from jax.experimental import sparse as jax_sparse
+
+         return isinstance(x, jax_sparse.BCOO) or isinstance(x, jax_sparse.BCSR)
+
+     return False
+
+
+ def _sparse_ones_like(
+     x: types.Tensor, dtype: types.DType | None = None
+ ) -> types.Tensor:
+     """Creates a tensor of ones with the same sparsity as the input.
+
+     This differs from `keras.ops.ones_like`, which would create a dense
+     tensor of ones.
+
+     Args:
+         x: Input sparse tensor.
+         dtype: Optional dtype for the output tensor values.
+
+     Returns:
+         Sparse tensor of ones.
+
+     Raises:
+         ValueError: for unsupported sparse input type and backend.
+     """
+     dtype = dtype or x.dtype
+     if keras.backend.backend() == "tensorflow":
+         import tensorflow as tf
+
+         # Ensure shape is copied exactly for compatibility in graph mode.
+         x_shape = x.shape
+         y = tf.SparseTensor(
+             x.indices, tf.ones_like(x.values, dtype=dtype), x.dense_shape
+         )
+         y.set_shape(x_shape)
+         return y
+     elif keras.backend.backend() == "jax":
+         import jax.numpy as jnp
+         from jax.experimental import sparse as jax_sparse
+
+         if isinstance(x, jax_sparse.BCOO):
+             return jax_sparse.BCOO(
+                 (jnp.ones_like(x.data, dtype=dtype), x.indices),
+                 shape=x.shape,
+                 indices_sorted=x.indices_sorted,
+                 unique_indices=x.unique_indices,
+             )
+         elif isinstance(x, jax_sparse.BCSR):
+             return jax_sparse.BCSR(
+                 (jnp.ones_like(x.data, dtype=dtype), x.indices, x.indptr),
+                 shape=x.shape,
+                 indices_sorted=x.indices_sorted,
+                 unique_indices=x.unique_indices,
+             )
+
+     raise ValueError(
+         f"Unsupported sparse input type '{x.__class__.__name__}' for backend "
+         f"{keras.backend.backend()}."
+     )
+
+
+ @keras_rs_export("keras_rs.layers.EmbedReduce")
+ class EmbedReduce(keras.layers.Embedding):
+     """An embedding layer that reduces with a combiner.
+
+     This layer embeds inputs and then applies a reduction to combine a set of
+     embeddings into a single embedding. This is typically used to embed a
+     sequence of items as a single embedding.
+
+     If the inputs passed to `__call__` are 1D, no reduction is applied. If the
+     inputs are 2D, dimension 1 is reduced using the combiner so that the result
+     is of shape `(batch_size, output_dim)`. Inputs of rank 3 and higher are not
+     allowed. Weights can optionally be passed to the `__call__` method to
+     apply weights to different samples before reduction.
+
+     This layer supports sparse inputs and ragged inputs with backends that
+     support them. The output after reduction is dense. For ragged inputs, the
+     ragged dimension must be 1 as it is the dimension that is reduced.
+
+     Args:
+         input_dim: Integer. Size of the vocabulary, maximum integer index + 1.
+         output_dim: Integer. Dimension of the dense embedding.
+         embeddings_initializer: Initializer for the `embeddings` matrix (see
+             `keras.initializers`).
+         embeddings_regularizer: Regularizer function applied to the `embeddings`
+             matrix (see `keras.regularizers`).
+         embeddings_constraint: Constraint function applied to the `embeddings`
+             matrix (see `keras.constraints`).
+         mask_zero: Boolean, whether or not the input value 0 is a special
+             "padding" value that should be masked out. This is useful when using
+             recurrent layers which may take variable length input. If this is
+             `True`, then all subsequent layers in the model need to support
+             masking or an exception will be raised. If `mask_zero` is set to
+             `True`, as a consequence, index 0 cannot be used in the vocabulary
+             (`input_dim` should equal size of vocabulary + 1).
+         weights: Optional floating-point matrix of size
+             `(input_dim, output_dim)`. The initial embeddings values to use.
+         combiner: Specifies how to reduce if there are multiple entries in a
+             single row. Currently `mean`, `sqrtn` and `sum` are supported.
+             `mean` is the default. `sqrtn` often achieves good accuracy, in
+             particular with bag-of-words columns.
+         **kwargs: Additional keyword arguments passed to `Embedding`.
+     """
+
+     def __init__(
+         self,
+         input_dim: int,
+         output_dim: int,
+         embeddings_initializer: types.InitializerLike = "uniform",
+         embeddings_regularizer: types.RegularizerLike | None = None,
+         embeddings_constraint: types.ConstraintLike | None = None,
+         mask_zero: bool = False,
+         weights: types.Tensor = None,
+         combiner: str = "mean",
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(
+             input_dim,
+             output_dim,
+             embeddings_initializer=embeddings_initializer,
+             embeddings_regularizer=embeddings_regularizer,
+             embeddings_constraint=embeddings_constraint,
+             mask_zero=mask_zero,
+             weights=weights,
+             **kwargs,
+         )
+         if combiner not in SUPPORTED_COMBINERS:
+             raise ValueError(
+                 f"Invalid `combiner`: '{combiner}', "
+                 f"use one of {', '.join(SUPPORTED_COMBINERS)}."
+             )
+         self.combiner = combiner
+
+     def call(
+         self,
+         inputs: types.Tensor,
+         weights: types.Tensor | None = None,
+     ) -> types.Tensor:
+         """Apply embedding and reduction.
+
+         Args:
+             inputs: 1D tensor to embed or 2D tensor to embed and reduce.
+             weights: Optional tensor of weights to apply before reduction, which
+                 can be 1D or 2D and must match the first dimension of
+                 `inputs` (1D case) or the shape of `inputs` (2D case).
+
+         Returns:
+             A dense 2D tensor of shape `(batch_size, output_dim)`.
+         """
+         x = super().call(inputs)
+         unreduced_rank = len(x.shape)
+
+         # Check that weights has a compatible shape.
+         if weights is not None:
+             weights_rank = len(weights.shape)
+             if weights_rank > unreduced_rank or not check_shapes_compatible(
+                 x.shape[0:weights_rank], weights.shape
+             ):
+                 raise ValueError(
+                     f"The shape of `weights`: {weights.shape} is not compatible"
+                     f" with the shape of `inputs` after embedding: {x.shape}."
+                 )
+
+         dtype = (
+             x.dtype
+             if weights is None
+             else keras.backend.result_type(x.dtype, weights.dtype)
+         )
+
+         # When `weights` is `None`:
+         # - For ragged inputs, after embedding, we get a ragged result that has
+         #   a ragged dimension of 1, but when we do the "mean" or "sqrtn", we
+         #   need to divide by the number of items in each row. However, there is
+         #   no explicit cross backend API to get the row length. `ones_like`
+         #   gives us a ragged tensor that is ragged in the same way as the
+         #   inputs. When we do `ops.sum(weights, axis=-2)`, it gives us the
+         #   number of items per row.
+         # - For sparse inputs, after embedding, we get a dense tensor, not a
+         #   sparse tensor. What it does for missing values is use embedding 0.
+         #   These are bogus embeddings and should be ignored. `ones_like` gives
+         #   us a sparse tensor with the exact same missing values. Later, when
+         #   we do `x = ops.multiply(x, weights)`, the bogus values are masked
+         #   (note that `weights` has been densified beforehand). Additionally,
+         #   when we do `ops.sum(weights, axis=-2)`, it gives us the number of
+         #   items per row.
+         #
+         # When `unreduced_rank <= 2`, this means that the inputs were 1D and
+         # dense, there is only one embedding per row, so there is no real
+         # reduction going on.
+         # - For mean: result = weights * x / weights = x, we don't need `weights`
+         # - For sqrtn: result = weights * x / sqrt(square(weights)) = x, we don't
+         #   need `weights`.
+         # - For sum however: `result = weights * x`, so we do need `weights`.
+         # So for mean and sqrtn we don't need the weights, we use ones instead.
+         # This is to avoid divisions by zero and improve the precision.
+         if weights is None or (unreduced_rank <= 2 and self.combiner != "sum"):
+             # Discard the weights if there were some and create a mask for
+             # ragged and sparse tensors to mask the result correctly (sparse
+             # only) and then apply the reduction correctly (ragged and sparse).
+             if _is_supported_sparse(inputs):
+                 weights = _sparse_ones_like(inputs, dtype=dtype)
+             else:
+                 weights = ops.ones_like(inputs, dtype=dtype)
+
+         else:
+             weights = ops.cast(weights, dtype)
+
+         # When looking up using sparse indices, the result is dense but contains
+         # values that should be ignored as all missing values use index 0. We
+         # use `weights` as a mask, but it needs to be densified as
+         # `expand_dims` and broadcasting a sparse tensor does not produce the
+         # expected result.
+         weights = ops.convert_to_tensor(weights, sparse=False)
+
+         # Make weights and the unreduced embeddings have the same rank.
+         weights_rank = len(weights.shape)
+         if weights_rank < unreduced_rank:
+             weights = ops.expand_dims(
+                 weights, axis=tuple(range(weights_rank, unreduced_rank))
+             )
+
+         # Note that `x` and `weights` are:
+         # - ragged if `inputs` was ragged and `weights` was ragged or None
+         # - dense otherwise (even if `inputs` and `weights` were sparse).
+         x = ops.multiply(x, weights)
+
+         if unreduced_rank <= 2:
+             # No reduction is applied.
+             return x
+
+         # After this reduction, `x` is always dense as we reduce the ragged
+         # dimension in the ragged case.
+         x = ops.sum(x, axis=-2)
+
+         # Apply the right divisor for the combiner.
+         # Where we use `weights` in the divisor, we use
+         # `ops.sum(weights, axis=-2)` which always makes it dense as we reduce
+         # the ragged dimension in the ragged case.
+         if self.combiner == "mean":
+             return ops.divide_no_nan(x, ops.sum(weights, axis=-2))
+         elif self.combiner == "sum":
+             return x
+         elif self.combiner == "sqrtn":
+             return ops.divide_no_nan(
+                 x, ops.sqrt(ops.sum(ops.square(weights), axis=-2))
+             )
+
+     def compute_output_shape(
+         self,
+         input_shape: types.Shape,
+         weights_shape: types.Shape | None = None,
+     ) -> types.Shape:
+         del weights_shape
+
+         if len(input_shape) <= 1:
+             # No reduction
+             return (*input_shape, self.output_dim)
+         else:
+             # Reduce last dimension
+             return (*input_shape[0:-1], self.output_dim)
+
+     def compute_output_spec(
+         self,
+         inputs: keras.KerasTensor,
+         weights: keras.KerasTensor | None = None,
+     ) -> keras.KerasTensor:
+         del weights
+
+         output_shape = self.compute_output_shape(inputs.shape)
+         return keras.KerasTensor(output_shape, dtype=self.compute_dtype)
+
+     def get_config(self) -> dict[str, Any]:
+         config: dict[str, Any] = super().get_config()
+
+         config.update(
+             {
+                 "combiner": self.combiner,
+             }
+         )
+
+         return config
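A small usage sketch of `EmbedReduce` following the behavior documented above: 2D integer inputs are embedded and reduced over dimension 1, optional `weights` are applied before the reduction, and the result is dense. The vocabulary size, shapes, and values are illustrative only:

import numpy as np
import keras_rs

layer = keras_rs.layers.EmbedReduce(
    input_dim=1000,   # vocabulary size
    output_dim=8,     # embedding width
    combiner="mean",  # one of "mean", "sum", "sqrtn"
)

# A batch of 2 rows, each with 3 item ids: shape (2, 3).
ids = np.array([[1, 4, 7], [2, 2, 9]])
# Optional per-item weights with the same shape as `ids`.
weights = np.array([[1.0, 0.5, 0.5], [1.0, 1.0, 0.0]])

# Dimension 1 is embedded and reduced, giving a dense (2, 8) result.
pooled = layer(ids, weights=weights)
print(pooled.shape)  # (2, 8)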
keras_rs/src/layers/embedding/jax/checkpoint_utils.py
@@ -0,0 +1,104 @@
+ """A wrapper over orbax CheckpointManager for Keras3 Jax TPU Embeddings."""
+
+ from typing import Any
+
+ import keras
+ import orbax.checkpoint as ocp
+ from etils import epath
+
+
+ class JaxKeras3CheckpointManager(ocp.CheckpointManager):
+     """A wrapper over orbax CheckpointManager for Keras3 Jax TPU Embeddings."""
+
+     def __init__(
+         self,
+         model: keras.Model,
+         checkpoint_dir: epath.PathLike,
+         max_to_keep: int,
+         steps_per_epoch: int = 1,
+         **kwargs: Any,
+     ):
+         options = ocp.CheckpointManagerOptions(
+             max_to_keep=max_to_keep, enable_async_checkpointing=False, **kwargs
+         )
+         self._model = model
+         self._steps_per_epoch = steps_per_epoch
+         self._checkpoint_dir = checkpoint_dir
+         super().__init__(checkpoint_dir, options=options)
+
+     def _get_state(self) -> tuple[dict[str, Any], Any | None]:
+         """Gets the model state and metrics."""
+         model_state = self._model.get_state_tree()
+         state = {}
+         metrics = None
+         for k, v in model_state.items():
+             if k == "metrics_variables":
+                 metrics = v
+             else:
+                 state[k] = v
+         return state, metrics
+
+     def save_state(self, epoch: int) -> None:
+         """Saves the model to the checkpoint directory.
+
+         Args:
+             epoch: The epoch number at which the state is saved.
+         """
+         state, metrics_value = self._get_state()
+         self.save(
+             epoch * self._steps_per_epoch,
+             args=ocp.args.StandardSave(item=state),
+             metrics=metrics_value,
+         )
+
+     def restore_state(self, step: int | None = None) -> None:
+         """Restores the model from the checkpoint directory.
+
+         Args:
+             step: The step number to restore the state from. Default `None`
+                 restores the latest step.
+         """
+         if step is None:
+             step = self.latest_step()
+         # Restore the model state only, not metrics.
+         state, _ = self._get_state()
+         restored_state = self.restore(
+             step, args=ocp.args.StandardRestore(item=state)
+         )
+         self._model.set_state_tree(restored_state)
+
+
+ class JaxKeras3CheckpointCallback(keras.callbacks.Callback):
+     """A callback for checkpointing and restoring state using Orbax."""
+
+     def __init__(
+         self,
+         model: keras.Model,
+         checkpoint_dir: epath.PathLike,
+         max_to_keep: int,
+         steps_per_epoch: int = 1,
+         **kwargs: Any,
+     ):
+         if keras.backend.backend() != "jax":
+             raise ValueError(
+                 "`JaxKeras3CheckpointCallback` is only supported on a "
+                 "`jax` backend."
+             )
+         self._checkpoint_manager = JaxKeras3CheckpointManager(
+             model, checkpoint_dir, max_to_keep, steps_per_epoch, **kwargs
+         )
+
+     def on_train_begin(self, logs: dict[str, Any] | None = None) -> None:
+         if not self.model.built or not self.model.optimizer.built:
+             raise ValueError(
+                 "To use `JaxKeras3CheckpointCallback`, your model and "
+                 "optimizer must be built before you call `fit()`."
+             )
+         latest_epoch = self._checkpoint_manager.latest_step()
+         if latest_epoch is not None:
+             self._checkpoint_manager.restore_state(step=latest_epoch)
+
+     def on_epoch_end(
+         self, epoch: int, logs: dict[str, Any] | None = None
+     ) -> None:
+         self._checkpoint_manager.save_state(epoch)
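A minimal sketch of wiring the callback into training, assuming a JAX backend and a model and optimizer that are built before `fit()` is called (as `on_train_begin` requires). The model, data, and checkpoint directory are placeholders; the import path follows the file location in this diff:

import numpy as np
import keras
from keras_rs.src.layers.embedding.jax.checkpoint_utils import (
    JaxKeras3CheckpointCallback,
)

# Placeholder model; a real model would typically contain DistributedEmbedding.
model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")
model.build((None, 4))
model.optimizer.build(model.trainable_variables)

callback = JaxKeras3CheckpointCallback(
    model,
    checkpoint_dir="/tmp/keras_rs_ckpt",  # placeholder directory
    max_to_keep=3,
    steps_per_epoch=4,
)

x_train = np.random.rand(32, 4).astype("float32")
y_train = np.random.rand(32, 1).astype("float32")

# `on_train_begin` restores the latest checkpoint if one exists, and
# `on_epoch_end` saves the state at step `epoch * steps_per_epoch`.
model.fit(x_train, y_train, epochs=2, callbacks=[callback])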