keras-rs-nightly 0.0.1.dev2025021903__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_rs/__init__.py +9 -28
- keras_rs/layers/__init__.py +37 -0
- keras_rs/losses/__init__.py +19 -0
- keras_rs/metrics/__init__.py +16 -0
- keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
- keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
- keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
- keras_rs/src/layers/embedding/embed_reduce.py +309 -0
- keras_rs/src/layers/embedding/jax/__init__.py +0 -0
- keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
- keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
- keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
- keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
- keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
- keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
- keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
- keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
- keras_rs/src/layers/feature_interaction/__init__.py +0 -0
- keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
- keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
- keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
- keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
- keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
- keras_rs/src/layers/retrieval/retrieval.py +127 -0
- keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
- keras_rs/src/losses/__init__.py +0 -0
- keras_rs/src/losses/list_mle_loss.py +212 -0
- keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
- keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
- keras_rs/src/losses/pairwise_loss.py +165 -0
- keras_rs/src/losses/pairwise_loss_utils.py +39 -0
- keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
- keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
- keras_rs/src/metrics/__init__.py +0 -0
- keras_rs/src/metrics/dcg.py +161 -0
- keras_rs/src/metrics/mean_average_precision.py +130 -0
- keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
- keras_rs/src/metrics/ndcg.py +197 -0
- keras_rs/src/metrics/precision_at_k.py +117 -0
- keras_rs/src/metrics/ranking_metric.py +260 -0
- keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
- keras_rs/src/metrics/recall_at_k.py +108 -0
- keras_rs/src/metrics/utils.py +70 -0
- keras_rs/src/types.py +43 -14
- keras_rs/src/utils/doc_string_utils.py +53 -0
- keras_rs/src/utils/keras_utils.py +52 -3
- keras_rs/src/utils/tpu_test_utils.py +120 -0
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
- keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
- keras_rs/api/__init__.py +0 -9
- keras_rs/api/layers/__init__.py +0 -11
- keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
- /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
keras_rs/src/layers/embedding/distributed_embedding.py (new file):

```diff
@@ -0,0 +1,33 @@
+import importlib.util
+import platform
+import sys
+
+import keras
+
+from keras_rs.src.api_export import keras_rs_export
+
+# JAX distributed embedding is only available on linux_x86_64, and only if
+# jax-tpu-embedding is installed.
+jax_tpu_embedding_spec = importlib.util.find_spec("jax_tpu_embedding")
+if (
+    keras.backend.backend() == "jax"
+    and sys.platform == "linux"
+    and platform.machine().lower() == "x86_64"
+    and jax_tpu_embedding_spec is not None
+):
+    from keras_rs.src.layers.embedding.jax.distributed_embedding import (
+        DistributedEmbedding as BackendDistributedEmbedding,
+    )
+elif keras.backend.backend() == "tensorflow":
+    from keras_rs.src.layers.embedding.tensorflow.distributed_embedding import (
+        DistributedEmbedding as BackendDistributedEmbedding,
+    )
+else:
+    from keras_rs.src.layers.embedding.base_distributed_embedding import (
+        DistributedEmbedding as BackendDistributedEmbedding,
+    )
+
+
+@keras_rs_export("keras_rs.layers.DistributedEmbedding")
+class DistributedEmbedding(BackendDistributedEmbedding):
+    pass
```
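The hunk above only selects which backend class backs the public `keras_rs.layers.DistributedEmbedding` symbol. As a quick, non-authoritative way to see which implementation was picked up at import time, one can inspect the class hierarchy; the `__mro__` lookup below is purely illustrative and not a documented API:

```python
# Sketch: check which backend implementation backs DistributedEmbedding.
# Assumes keras-rs is installed and some Keras backend is configured.
import keras
import keras_rs

print("Keras backend:", keras.backend.backend())

# The public class directly subclasses the backend-specific class chosen above.
backend_cls = keras_rs.layers.DistributedEmbedding.__mro__[1]
print("Resolved implementation module:", backend_cls.__module__)
# Per the conditional imports in the diff:
#   TensorFlow backend                       -> ...tensorflow.distributed_embedding
#   JAX + jax-tpu-embedding on linux x86_64  -> ...jax.distributed_embedding
#   any other configuration                  -> ...base_distributed_embedding
```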
keras_rs/src/layers/embedding/distributed_embedding_config.py (new file):

```diff
@@ -0,0 +1,132 @@
+"""Configuration for TPU embedding layer."""
+
+import dataclasses
+from typing import Any
+
+import keras
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+
+
+@keras_rs_export("keras_rs.layers.TableConfig")
+@dataclasses.dataclass(order=True)
+class TableConfig:
+    """Configuration for one embedding table.
+
+    Configures one table for use by one or more `keras_rs.layers.FeatureConfig`,
+    which in turn is used to configure a `keras_rs.layers.DistributedEmbedding`.
+
+    Attributes:
+        name: The name of the table. Must be defined.
+        vocabulary_size: Size of the table's vocabulary (number of rows).
+        embedding_dim: The embedding dimension (width) of the table.
+        initializer: The initializer for the embedding weights. If not
+            specified, defaults to `truncated_normal_initializer` with mean
+            `0.0` and standard deviation `1 / sqrt(embedding_dim)`.
+        optimizer: The optimizer for the embedding table. Only SGD, Adagrad,
+            Adam, and FTRL are supported. Note that not all of the optimizer's
+            parameters are supported. Defaults to Adam.
+        combiner: Specifies how to reduce if there are multiple entries in a
+            single row. `mean`, `sqrtn` and `sum` are supported. `mean` is the
+            default. `sqrtn` often achieves good accuracy, in particular with
+            bag-of-words columns.
+        placement: Where to place the embedding table. `"auto"`, which is the
+            default, means that the table is placed on SparseCore if available,
+            otherwise on the default device where the rest of the model is
+            placed. A value of `"sparsecore"` means the table will be placed on
+            the SparseCore chips and an error is raised if SparseCore is not
+            available. A value of `"default_device"` means the table will be
+            placed on the default device where the rest of the model is placed,
+            even if SparseCore is available. The default device for the rest of
+            the model is the TPU's TensorCore on TPUs, otherwise the GPU or CPU.
+        max_ids_per_partition: The max number of ids per partition for the
+            table. This is an input data dependent value and is required by the
+            compiler to appropriately allocate memory.
+        max_unique_ids_per_partition: The max number of unique ids per partition
+            for the table. This is an input data dependent value and is required
+            by the compiler to appropriately allocate memory.
+    """
+
+    name: str
+    vocabulary_size: int
+    embedding_dim: int
+    initializer: str | keras.initializers.Initializer = (
+        keras.initializers.VarianceScaling(mode="fan_out")
+    )
+    optimizer: str | keras.optimizers.Optimizer = "adam"
+    combiner: str = "mean"
+    placement: str = "auto"
+    max_ids_per_partition: int = 256
+    max_unique_ids_per_partition: int = 256
+
+    def get_config(self) -> dict[str, Any]:
+        return {
+            "name": self.name,
+            "vocabulary_size": self.vocabulary_size,
+            "embedding_dim": self.embedding_dim,
+            "initializer": keras.saving.serialize_keras_object(
+                self.initializer
+            ),
+            "optimizer": keras.saving.serialize_keras_object(self.optimizer),
+            "combiner": self.combiner,
+            "placement": self.placement,
+            "max_ids_per_partition": self.max_ids_per_partition,
+            "max_unique_ids_per_partition": self.max_unique_ids_per_partition,
+        }
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "TableConfig":
+        config = config.copy()
+        config["initializer"] = keras.saving.deserialize_keras_object(
+            config["initializer"]
+        )
+        config["optimizer"] = keras.saving.deserialize_keras_object(
+            config["optimizer"]
+        )
+        return cls(**config)
+
+
+@keras_rs_export("keras_rs.layers.FeatureConfig")
+@dataclasses.dataclass(order=True)
+class FeatureConfig:
+    """Configuration for one embedding feature.
+
+    Configures one feature for `keras_rs.layers.DistributedEmbedding`. Each
+    feature uses a table configured via `keras_rs.layers.TableConfig` and
+    multiple features can share the same table.
+
+    Attributes:
+        name: The name of the feature. Must be defined.
+        table: The table in which to look up this feature.
+        input_shape: The input shape of the feature. The feature fed into the
+            layer has to match the shape. Note that for ragged dimensions in the
+            input, the dimension provided here presents the maximum value;
+            anything larger will be truncated. Also note that the first
+            dimension represents the global batch size. For example, on TPU,
+            this represents the total number of samples that are dispatched to
+            all the TPUs connected to the current host.
+        output_shape: The output shape of the feature activation. What is
+            returned by the embedding layer has to match this shape.
+    """
+
+    name: str
+    table: TableConfig
+    input_shape: types.Shape
+    output_shape: types.Shape
+
+    def get_config(self) -> dict[str, Any]:
+        return {
+            "name": self.name,
+            "table": self.table.get_config(),
+            "input_shape": self.input_shape,
+            "output_shape": self.output_shape,
+        }
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "FeatureConfig":
+        config = config.copy()
+        # Note: the handling of shared tables during serialization is done in
+        # `DistributedEmbedding.from_config()`.
+        config["table"] = TableConfig.from_config(config["table"])
+        return cls(**config)
```
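To make the relationship between the two dataclasses concrete, here is a small configuration sketch: one table shared by two features, then handed to `DistributedEmbedding`. The final constructor call is an assumption based on the class docstrings; the actual signature lives in `base_distributed_embedding.py`, which is not shown in this diff.

```python
# Sketch only: shows how TableConfig and FeatureConfig fit together.
# The DistributedEmbedding(feature_configs) call at the end is an assumption;
# see base_distributed_embedding.py for the actual constructor signature.
import keras_rs

movie_table = keras_rs.layers.TableConfig(
    name="movie_table",
    vocabulary_size=100_000,
    embedding_dim=32,
    combiner="mean",
    placement="auto",  # "sparsecore" or "default_device" are also accepted
)

# Two features sharing the same table; the first dimension is the global batch.
feature_configs = {
    "watched_movies": keras_rs.layers.FeatureConfig(
        name="watched_movies",
        table=movie_table,
        input_shape=(256, 16),   # batch of 256, up to 16 ids per sample
        output_shape=(256, 32),
    ),
    "candidate_movie": keras_rs.layers.FeatureConfig(
        name="candidate_movie",
        table=movie_table,
        input_shape=(256, 1),
        output_shape=(256, 32),
    ),
}

embedding = keras_rs.layers.DistributedEmbedding(feature_configs)  # assumed call
```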
keras_rs/src/layers/embedding/embed_reduce.py (new file):

```diff
@@ -0,0 +1,309 @@
+from typing import Any
+
+import keras
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+from keras_rs.src.utils.keras_utils import check_shapes_compatible
+
+SUPPORTED_COMBINERS = ("mean", "sum", "sqrtn")
+
+
+def _is_supported_sparse(x: types.Tensor) -> bool:
+    """Determines if the input is a supported sparse tensor.
+
+    NOTE: Currently only works for the TensorFlow and JAX backends.
+
+    Args:
+        x: Input tensor to check for sparsity.
+
+    Returns:
+        True if `x` is a supported sparse tensor.
+    """
+    if keras.backend.backend() == "tensorflow":
+        import tensorflow as tf
+
+        return isinstance(x, tf.SparseTensor)
+    elif keras.backend.backend() == "jax":
+        from jax.experimental import sparse as jax_sparse
+
+        return isinstance(x, jax_sparse.BCOO) or isinstance(x, jax_sparse.BCSR)
+
+    return False
+
+
+def _sparse_ones_like(
+    x: types.Tensor, dtype: types.DType | None = None
+) -> types.Tensor:
+    """Creates a tensor of ones with the same sparsity as the input.
+
+    This differs from `keras.ops.ones_like`, which would create a dense
+    tensor of ones.
+
+    Args:
+        x: Input sparse tensor.
+        dtype: Optional dtype for the output tensor values.
+
+    Returns:
+        Sparse tensor of ones.
+
+    Raises:
+        ValueError for unsupported sparse input type and backend.
+    """
+    dtype = dtype or x.dtype
+    if keras.backend.backend() == "tensorflow":
+        import tensorflow as tf
+
+        # Ensure shape is copied exactly for compatibility in graph mode.
+        x_shape = x.shape
+        y = tf.SparseTensor(
+            x.indices, tf.ones_like(x.values, dtype=dtype), x.dense_shape
+        )
+        y.set_shape(x_shape)
+        return y
+    elif keras.backend.backend() == "jax":
+        import jax.numpy as jnp
+        from jax.experimental import sparse as jax_sparse
+
+        if isinstance(x, jax_sparse.BCOO):
+            return jax_sparse.BCOO(
+                (jnp.ones_like(x.data, dtype=dtype), x.indices),
+                shape=x.shape,
+                indices_sorted=x.indices_sorted,
+                unique_indices=x.unique_indices,
+            )
+        elif isinstance(x, jax_sparse.BCSR):
+            return jax_sparse.BCSR(
+                (jnp.ones_like(x.data, dtype=dtype), x.indices, x.indptr),
+                shape=x.shape,
+                indices_sorted=x.indices_sorted,
+                unique_indices=x.unique_indices,
+            )
+
+    raise ValueError(
+        f"Unsupported sparse input type '{x.__class__.__name__}' for backend "
+        f"{keras.backend.backend()}."
+    )
+
+
+@keras_rs_export("keras_rs.layers.EmbedReduce")
+class EmbedReduce(keras.layers.Embedding):
+    """An embedding layer that reduces with a combiner.
+
+    This layer embeds inputs and then applies a reduction to combine a set of
+    embeddings into a single embedding. This is typically used to embed a
+    sequence of items as a single embedding.
+
+    If the inputs passed to `__call__` are 1D, no reduction is applied. If the
+    inputs are 2D, dimension 1 is reduced using the combiner so that the result
+    is of shape `(batch_size, output_dim`). Inputs of rank 3 and higher are not
+    allowed. Weights can optionally be passed to the `__call__` method to
+    apply weights to different samples before reduction.
+
+    This layer supports sparse inputs and ragged inputs with backends that
+    support them. The output after reduction is dense. For ragged inputs, the
+    ragged dimension must be 1 as it is the dimension that is reduced.
+
+    Args:
+        input_dim: Integer. Size of the vocabulary, maximum integer index + 1.
+        output_dim: Integer. Dimension of the dense embedding.
+        embeddings_initializer: Initializer for the `embeddings` matrix (see
+            `keras.initializers`).
+        embeddings_regularizer: Regularizer function applied to the `embeddings`
+            matrix (see `keras.regularizers`).
+        embeddings_constraint: Constraint function applied to the `embeddings`
+            matrix (see `keras.constraints`).
+        mask_zero: Boolean, whether or not the input value 0 is a special
+            "padding" value that should be masked out. This is useful when using
+            recurrent layers which may take variable length input. If this is
+            `True`, then all subsequent layers in the model need to support
+            masking or an exception will be raised. If `mask_zero` is set to
+            `True`, as a consequence, index 0 cannot be used in the vocabulary
+            (`input_dim` should equal size of vocabulary + 1).
+        weights: Optional floating-point matrix of size
+            `(input_dim, output_dim)`. The initial embeddings values to use.
+        combiner: Specifies how to reduce if there are multiple entries in a
+            single row. Currently `mean`, `sqrtn` and `sum` are supported.
+            `mean` is the default. `sqrtn` often achieves good accuracy, in
+            particular with bag-of-words columns.
+        **kwargs: Additional keyword arguments passed to `Embedding`.
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        output_dim: int,
+        embeddings_initializer: types.InitializerLike = "uniform",
+        embeddings_regularizer: types.RegularizerLike | None = None,
+        embeddings_constraint: types.ConstraintLike | None = None,
+        mask_zero: bool = False,
+        weights: types.Tensor = None,
+        combiner: str = "mean",
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            input_dim,
+            output_dim,
+            embeddings_initializer=embeddings_initializer,
+            embeddings_regularizer=embeddings_regularizer,
+            embeddings_constraint=embeddings_constraint,
+            mask_zero=mask_zero,
+            weights=weights,
+            **kwargs,
+        )
+        if combiner not in SUPPORTED_COMBINERS:
+            raise ValueError(
+                f"Invalid `combiner`: '{combiner}', "
+                f"use one of {', '.join(SUPPORTED_COMBINERS)}."
+            )
+        self.combiner = combiner
+
+    def call(
+        self,
+        inputs: types.Tensor,
+        weights: types.Tensor | None = None,
+    ) -> types.Tensor:
+        """Apply embedding and reduction.
+
+        Args:
+            inputs: 1D tensor to embed or 2D tensor to embed and reduce.
+            weights: Optional tensor of weights to apply before reduction, which
+                can be 1D or 2D and must match for the first dimension of
+                `inputs` (1D case) or match the shape of `inputs` (2D case).
+
+        Returns:
+            A dense 2D tensor of shape `(batch_size, output_dim)`.
+        """
+        x = super().call(inputs)
+        unreduced_rank = len(x.shape)
+
+        # Check that weights has a compatible shape.
+        if weights is not None:
+            weights_rank = len(weights.shape)
+            if weights_rank > unreduced_rank or not check_shapes_compatible(
+                x.shape[0:weights_rank], weights.shape
+            ):
+                raise ValueError(
+                    f"The shape of `weights`: {weights.shape} is not compatible"
+                    f" with the shape of `inputs` after embedding: {x.shape}."
+                )
+
+        dtype = (
+            x.dtype
+            if weights is None
+            else keras.backend.result_type(x.dtype, weights.dtype)
+        )
+
+        # When `weights` is `None`:
+        # - For ragged inputs, after embedding, we get a ragged result that has
+        #   a ragged dimension of 1, but when we do the "mean" or "sqrtn", we
+        #   need to divide by the number of items in each row. However, there is
+        #   no explicit cross backend API to get the row length. `ones_like`
+        #   gives us a ragged tensor that is ragged in the same way as the
+        #   inputs. When we do `ops.sum(weights, axis=-2)`, it gives us the
+        #   number of items per row.
+        # - For sparse inputs, after embedding, we get a dense tensor, not a
+        #   sparse tensor. What it does for missing values is use embedding 0.
+        #   These are bogus embedding and should be ignored. `ones_like` gives
+        #   us a sparse tensor with the exact same missing values. Later, when
+        #   we do `x = ops.multiply(x, weights)`, which masks the bogus values
+        #   (note that `weights` has been densified beforehand). Additionally,
+        #   when we do `ops.sum(weights, axis=-2)`, it gives us the number of
+        #   items per row.
+        #
+        # When `unreduced_rank <= 2`, this means that the inputs where 1D and
+        # dense, there is only one embedding per row, so there is no real
+        # reduction is going on.
+        # - For mean: result = weights * x / weights = x we don't need `weights`
+        # - For sqrtn: result = weights * x / sqrt(square(weights)) = x we don't
+        #   needs `weights`
+        # - For sum however: `result = weights * x` we do need `weights`.
+        # So for mean and sqrtn we don't need the weights, we use ones instead.
+        # This is to avoid divisions by zero and improve the precision.
+        if weights is None or (unreduced_rank <= 2 and self.combiner != "sum"):
+            # Discard the weights if there were some and create a mask for
+            # ragged and sparse tensors to mask the result correctly (sparse
+            # only) and the apply the reduction correctly (ragged and sparse).
+            if _is_supported_sparse(inputs):
+                weights = _sparse_ones_like(inputs, dtype=dtype)
+            else:
+                weights = ops.ones_like(inputs, dtype=dtype)
+
+        else:
+            weights = ops.cast(weights, dtype)
+
+        # When looking up using sparse indices, the result is dense but contains
+        # values that should be ignored as all missing values use index 0. We
+        # use `weights` as a mask, but it needs to be densified as
+        # `expand_dims` and broadcasting a sparse tensor does not produce the
+        # expected result.
+        weights = ops.convert_to_tensor(weights, sparse=False)
+
+        # Make weights and the unreduced embeddings have the same rank.
+        weights_rank = len(weights.shape)
+        if weights_rank < unreduced_rank:
+            weights = ops.expand_dims(
+                weights, axis=tuple(range(weights_rank, unreduced_rank))
+            )
+
+        # Note that `x` and `weights` are:
+        # - ragged if `inputs` was ragged and `weights` was ragged or None
+        # - dense otherwise (even if `inputs` and `weights` were sparse).
+        x = ops.multiply(x, weights)
+
+        if unreduced_rank <= 2:
+            # No reduction is applied.
+            return x
+
+        # After this reduction, `x` is always dense as we reduce the ragged
+        # dimension in the ragged case.
+        x = ops.sum(x, axis=-2)
+
+        # Apply the right divisor for the combiner.
+        # Where we use `weights` in the divisor, we use
+        # `ops.sum(weights, axis=-2)` which always makes it dense as we reduce
+        # the ragged dimension in the ragged case.
+        if self.combiner == "mean":
+            return ops.divide_no_nan(x, ops.sum(weights, axis=-2))
+        elif self.combiner == "sum":
+            return x
+        elif self.combiner == "sqrtn":
+            return ops.divide_no_nan(
+                x, ops.sqrt(ops.sum(ops.square(weights), axis=-2))
+            )
+
+    def compute_output_shape(
+        self,
+        input_shape: types.Shape,
+        weights_shape: types.Shape | None = None,
+    ) -> types.Shape:
+        del weights_shape
+
+        if len(input_shape) <= 1:
+            # No reduction
+            return (*input_shape, self.output_dim)
+        else:
+            # Reduce last dimension
+            return (*input_shape[0:-1], self.output_dim)
+
+    def compute_output_spec(
+        self,
+        inputs: keras.KerasTensor,
+        weights: keras.KerasTensor | None = None,
+    ) -> keras.KerasTensor:
+        del weights
+
+        output_shape = self.compute_output_shape(inputs.shape)
+        return keras.KerasTensor(output_shape, dtype=self.compute_dtype)
+
+    def get_config(self) -> dict[str, Any]:
+        config: dict[str, Any] = super().get_config()
+
+        config.update(
+            {
+                "combiner": self.combiner,
+            }
+        )
+
+        return config
```
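A short usage sketch of the new `EmbedReduce` layer, following the `__init__` and `call` signatures above: a 2D batch of ids is embedded and then reduced over dimension 1 with the configured combiner (values below are illustrative).

```python
# Usage sketch for EmbedReduce, following the call() contract shown above.
import numpy as np
import keras_rs

layer = keras_rs.layers.EmbedReduce(
    input_dim=1000,   # vocabulary size
    output_dim=8,     # embedding width
    combiner="mean",  # "sum" and "sqrtn" are also supported
)

# Batch of 2 rows with 3 ids each: embeddings are averaged per row.
ids = np.array([[3, 14, 159], [2, 71, 828]])
pooled = layer(ids)  # dense, shape (2, 8)

# Optional per-id weights, same shape as `ids`, applied before the reduction:
# with "mean", each row becomes sum(w_i * e_i) / sum(w_i).
weights = np.array([[1.0, 0.5, 0.5], [1.0, 1.0, 0.0]])
weighted = layer(ids, weights=weights)  # dense, shape (2, 8)
```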
keras_rs/src/layers/embedding/jax/__init__.py: file without changes.
keras_rs/src/layers/embedding/jax/checkpoint_utils.py (new file):

```diff
@@ -0,0 +1,104 @@
+"""A Wrapper over orbax CheckpointManager for Keras3 Jax TPU Embeddings."""
+
+from typing import Any
+
+import keras
+import orbax.checkpoint as ocp
+from etils import epath
+
+
+class JaxKeras3CheckpointManager(ocp.CheckpointManager):
+    """A wrapper over orbax CheckpointManager for Keras3 Jax TPU Embeddings."""
+
+    def __init__(
+        self,
+        model: keras.Model,
+        checkpoint_dir: epath.PathLike,
+        max_to_keep: int,
+        steps_per_epoch: int = 1,
+        **kwargs: Any,
+    ):
+        options = ocp.CheckpointManagerOptions(
+            max_to_keep=max_to_keep, enable_async_checkpointing=False, **kwargs
+        )
+        self._model = model
+        self._steps_per_epoch = steps_per_epoch
+        self._checkpoint_dir = checkpoint_dir
+        super().__init__(checkpoint_dir, options=options)
+
+    def _get_state(self) -> tuple[dict[str, Any], Any | None]:
+        """Gets the model state and metrics"""
+        model_state = self._model.get_state_tree()
+        state = {}
+        metrics = None
+        for k, v in model_state.items():
+            if k == "metrics_variables":
+                metrics = v
+            else:
+                state[k] = v
+        return state, metrics
+
+    def save_state(self, epoch: int) -> None:
+        """Saves the model to the checkpoint directory.
+
+        Args:
+            epoch: The epoch number at which the state is saved.
+        """
+        state, metrics_value = self._get_state()
+        self.save(
+            epoch * self._steps_per_epoch,
+            args=ocp.args.StandardSave(item=state),
+            metrics=metrics_value,
+        )
+
+    def restore_state(self, step: int | None = None) -> None:
+        """Restores the model from the checkpoint directory.
+
+        Args:
+            step: The step .number to restore the state from. Default=None
+                restores the latest step.
+        """
+        if step is None:
+            step = self.latest_step()
+        # Restore the model state only, not metrics.
+        state, _ = self._get_state()
+        restored_state = self.restore(
+            step, args=ocp.args.StandardRestore(item=state)
+        )
+        self._model.set_state_tree(restored_state)
+
+
+class JaxKeras3CheckpointCallback(keras.callbacks.Callback):
+    """A callback for checkpointing and restoring state using Orbax."""
+
+    def __init__(
+        self,
+        model: keras.Model,
+        checkpoint_dir: epath.PathLike,
+        max_to_keep: int,
+        steps_per_epoch: int = 1,
+        **kwargs: Any,
+    ):
+        if keras.backend.backend() != "jax":
+            raise ValueError(
+                "`JaxKeras3CheckpointCallback` is only supported on a "
+                "`jax` backend."
+            )
+        self._checkpoint_manager = JaxKeras3CheckpointManager(
+            model, checkpoint_dir, max_to_keep, steps_per_epoch, **kwargs
+        )
+
+    def on_train_begin(self, logs: dict[str, Any] | None = None) -> None:
+        if not self.model.built or not self.model.optimizer.built:
+            raise ValueError(
+                "To use `JaxKeras3CheckpointCallback`, your model and "
+                "optimizer must be built before you call `fit()`."
+            )
+        latest_epoch = self._checkpoint_manager.latest_step()
+        if latest_epoch is not None:
+            self._checkpoint_manager.restore_state(step=latest_epoch)
+
+    def on_epoch_end(
+        self, epoch: int, logs: dict[str, Any] | None = None
+    ) -> None:
+        self._checkpoint_manager.save_state(epoch)
```