keras-rs-nightly 0.0.1.dev2025021903__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl
This diff compares the contents of publicly released package versions as published to their respective registries. It is provided for informational purposes only.
- keras_rs/__init__.py +9 -28
- keras_rs/layers/__init__.py +37 -0
- keras_rs/losses/__init__.py +19 -0
- keras_rs/metrics/__init__.py +16 -0
- keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
- keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
- keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
- keras_rs/src/layers/embedding/embed_reduce.py +309 -0
- keras_rs/src/layers/embedding/jax/__init__.py +0 -0
- keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
- keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
- keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
- keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
- keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
- keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
- keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
- keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
- keras_rs/src/layers/feature_interaction/__init__.py +0 -0
- keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
- keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
- keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
- keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
- keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
- keras_rs/src/layers/retrieval/retrieval.py +127 -0
- keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
- keras_rs/src/losses/__init__.py +0 -0
- keras_rs/src/losses/list_mle_loss.py +212 -0
- keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
- keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
- keras_rs/src/losses/pairwise_loss.py +165 -0
- keras_rs/src/losses/pairwise_loss_utils.py +39 -0
- keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
- keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
- keras_rs/src/metrics/__init__.py +0 -0
- keras_rs/src/metrics/dcg.py +161 -0
- keras_rs/src/metrics/mean_average_precision.py +130 -0
- keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
- keras_rs/src/metrics/ndcg.py +197 -0
- keras_rs/src/metrics/precision_at_k.py +117 -0
- keras_rs/src/metrics/ranking_metric.py +260 -0
- keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
- keras_rs/src/metrics/recall_at_k.py +108 -0
- keras_rs/src/metrics/utils.py +70 -0
- keras_rs/src/types.py +43 -14
- keras_rs/src/utils/doc_string_utils.py +53 -0
- keras_rs/src/utils/keras_utils.py +52 -3
- keras_rs/src/utils/tpu_test_utils.py +120 -0
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
- keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
- keras_rs/api/__init__.py +0 -9
- keras_rs/api/layers/__init__.py +0 -11
- keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
- /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
keras_rs/src/layers/embedding/jax/config_conversion.py
@@ -0,0 +1,468 @@
+"""Conversion utilities for Keras DistributedEmbeddingConfig to JAX."""
+
+import inspect
+import math
+import random as python_random
+from typing import Any, Callable
+
+import jax
+import jax.numpy as jnp
+import keras
+import numpy as np
+from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
+
+from keras_rs.src import types
+from keras_rs.src.layers.embedding import distributed_embedding_config as config
+from keras_rs.src.types import Nested
+
+
+class WrappedKerasInitializer(jax.nn.initializers.Initializer):
+    """Wraps a Keras initializer for use in JAX."""
+
+    def __init__(self, initializer: keras.initializers.Initializer):
+        if isinstance(initializer, str):
+            initializer = keras.initializers.get(initializer)
+        self.initializer = initializer
+
+    def key(self) -> jax.Array | None:
+        """Extracts a key from the underlying Keras initializer."""
+        # All built-in keras initializers have a `seed` attribute.
+        # Extract this and turn it into a key for use with JAX.
+        if hasattr(self.initializer, "seed"):
+            output: jax.Array = keras.src.backend.jax.random.jax_draw_seed(
+                self.initializer.seed
+            )
+            return output
+        return None
+
+    def __call__(
+        self,
+        key: Any,
+        shape: Any,
+        dtype: Any = jnp.float_,
+        out_sharding: Any = None,
+    ) -> jax.Array:
+        del out_sharding
+        # Force use of provided key. The JAX backend for random initializers
+        # forwards the `seed` attribute to the underlying JAX random functions.
+        if key is not None and hasattr(self.initializer, "seed"):
+            old_seed = self.initializer.seed
+            self.initializer.seed = key
+            out: jax.Array = self.initializer(shape, dtype)
+            self.initializer.seed = old_seed
+            return out
+
+        output: jax.Array = self.initializer(shape, dtype)
+        return output
+
+
+# pylint: disable-next=g-classes-have-attributes
+class WrappedJaxInitializer(keras.initializers.Initializer):
+    """Wraps a JAX initializer for use in Keras.
+
+    Attributes:
+        initializer: The wrapped JAX initializer.
+
+    Args:
+        initializer: The JAX initializer to wrap.
+        seed: Optional Keras seed for use with random JAX initializers.
+    """
+
+    def __init__(
+        self,
+        initializer: jax.nn.initializers.Initializer,
+        seed: int | keras.random.SeedGenerator | None = None,
+    ):
+        self.initializer = initializer
+        if seed is None:
+            # Consistency with keras.random.make_default_seed().
+            seed = python_random.randint(1, int(1e9))
+        self.seed = seed
+
+    def key(self) -> jax.Array:
+        """Converts the internal seed to a JAX random key."""
+        seed = self.seed
+        if isinstance(seed, int):
+            return jax.random.key(self.seed)
+        elif isinstance(seed, keras.random.SeedGenerator):
+            return jax.random.key(seed.next())
+        elif isinstance(seed, jax.Array):
+            return seed
+        else:
+            raise ValueError(f"Unknown seed {seed} of type {type(seed)}.")
+
+    def __call__(
+        self,
+        shape: types.Shape,
+        dtype: types.DType | None = None,
+        **kwargs: Any,
+    ) -> jax.Array:
+        del kwargs  # Unused.
+        return self.initializer(self.key(), shape, dtype)
+
+
+def keras_to_jax_initializer(
+    initializer: str | keras.initializers.Initializer,
+) -> jax.nn.initializers.Initializer:
+    """Converts a Keras initializer to a JAX initializer.
+
+    Args:
+        initializer: Keras initializer to convert.
+
+    Returns:
+        A JAX-compatible equivalent initializer.
+    """
+    if isinstance(initializer, WrappedJaxInitializer):
+        return initializer.initializer
+    return WrappedKerasInitializer(initializer)
+
+
+def jax_to_keras_initializer(
+    initializer: jax.nn.initializers.Initializer,
+) -> keras.initializers.Initializer:
+    """Converts a JAX initializer to a Keras initializer.
+
+    Args:
+        initializer: JAX initializer to convert.
+
+    Returns:
+        An equivalent Keras initializer.
+    """
+    if isinstance(initializer, WrappedKerasInitializer):
+        return initializer.initializer
+    return WrappedJaxInitializer(initializer)
+
+
+def keras_to_jte_learning_rate(
+    learning_rate: keras.Variable | float | Callable[..., float],
+) -> float | Callable[..., float]:
+    """Converts a Keras learning rate to a JAX TPU Embedding learning rate.
+
+    Args:
+        learning_rate: Any Keras-compatible learning-rate type. If a Callable,
+            must either take no parameters, or the step size as a single argument.
+
+    Returns:
+        A JAX TPU Embedding learning rate.
+
+    Raises:
+        ValueError if the learning rate is not supported.
+    """
+
+    # Supported keras optimizer general options.
+    if isinstance(learning_rate, keras.Variable):
+        # Extract the first (and only) element of the variable.
+        learning_rate = np.array(learning_rate.value, dtype=float)
+        assert learning_rate.size == 1
+        lr_float: float = learning_rate.item(0)
+        return lr_float
+    elif callable(learning_rate):
+        # Callable learning rate functions are expected to take a singular step
+        # count argument, or no arguments.
+        args = inspect.getfullargspec(learning_rate).args
+        # If not a function, then it's an object instance with `self` as the
+        # first argument.
+        num_args = (
+            len(args) if inspect.isfunction(learning_rate) else len(args) - 1
+        )
+        if num_args <= 1:
+            return learning_rate
+    elif isinstance(learning_rate, float):
+        return learning_rate
+
+    raise ValueError(
+        f"Unsupported learning rate: {learning_rate} of type"
+        f" {type(learning_rate)}."
+    )
+
+
+def jte_to_keras_learning_rate(
+    learning_rate: float | Callable[..., float],
+) -> float | Callable[..., float]:
+    """Converts a JAX TPU Embedding learning rate to a Keras learning rate.
+
+    Args:
+        learning_rate: The learning rate value or function. If a Callable, must
+            either take no parameters, or the step size as a single argument.
+
+    Returns:
+        A Keras learning rate.
+
+    Raises:
+        ValueError if the learning rate is not supported.
+    """
+    if callable(learning_rate):
+        # Callable learning rate functions are expected to take a singular step
+        # count argument, or no arguments.
+        args = inspect.getfullargspec(learning_rate).args
+        # If not a function, then it's an object instance, with `self` as the
+        # first argument.
+        num_args = (
+            len(args) if inspect.isfunction(learning_rate) else len(args) - 1
+        )
+        if num_args <= 1:
+            return learning_rate
+    elif isinstance(learning_rate, float):
+        return learning_rate
+
+    raise ValueError(f"Unknown learning rate {learning_rate}")
+
+
+def keras_to_jte_optimizer(
+    optimizer: keras.optimizers.Optimizer | str,
+) -> embedding_spec.OptimizerSpec:
+    """Converts a Keras optimizer to a JAX TPU Embedding optimizer.
+
+    Args:
+        optimizer: Any Keras-compatible optimizer.
+
+    Returns:
+        A JAX TPU Embedding optimizer.
+    """
+    if isinstance(optimizer, str):
+        optimizer = keras.optimizers.get(optimizer)
+
+    # We need to extract the actual internal learning_rate function.
+    # Unfortunately, the optimizer.learning_rate attribute tries to be smart,
+    # and evaluates the learning rate at the current iteration step, which is
+    # not what we want.
+    # pylint: disable-next=protected-access
+    learning_rate = keras_to_jte_learning_rate(optimizer._learning_rate)
+
+    # Unsupported keras optimizer general options.
+    if optimizer.clipnorm is not None:
+        raise ValueError("Unsupported optimizer option `clipnorm`.")
+    if optimizer.global_clipnorm is not None:
+        raise ValueError("Unsupported optimizer option `global_clipnorm`.")
+    if optimizer.use_ema:
+        raise ValueError("Unsupported optimizer option `use_ema`.")
+    if optimizer.loss_scale_factor is not None:
+        raise ValueError("Unsupported optimizer option `loss_scale_factor`.")
+
+    # Supported optimizers.
+    if isinstance(optimizer, keras.optimizers.SGD):
+        if getattr(optimizer, "nesterov", False):
+            raise ValueError("Unsupported optimizer option `nesterov`.")
+        if getattr(optimizer, "momentum", 0.0) != 0.0:
+            raise ValueError("Unsupported optimizer option `momentum`.")
+        return embedding_spec.SGDOptimizerSpec(learning_rate=learning_rate)
+    elif isinstance(optimizer, keras.optimizers.Adagrad):
+        if getattr(optimizer, "epsilon", 1e-7) != 1e-7:
+            raise ValueError("Unsupported optimizer option `epsilon`.")
+        return embedding_spec.AdagradOptimizerSpec(
+            learning_rate=learning_rate,
+            initial_accumulator_value=optimizer.initial_accumulator_value,
+        )
+    elif isinstance(optimizer, keras.optimizers.Adam):
+        if getattr(optimizer, "amsgrad", False):
+            raise ValueError("Unsupported optimizer option `amsgrad`.")
+
+        return embedding_spec.AdamOptimizerSpec(
+            learning_rate=learning_rate,
+            beta_1=optimizer.beta_1,
+            beta_2=optimizer.beta_2,
+            epsilon=optimizer.epsilon,
+        )
+    elif isinstance(optimizer, keras.optimizers.Ftrl):
+        if (
+            getattr(optimizer, "l2_shrinkage_regularization_strength", 0.0)
+            != 0.0
+        ):
+            raise ValueError(
+                "Unsupported optimizer option "
+                "`l2_shrinkage_regularization_strength`."
+            )
+
+        return embedding_spec.FTRLOptimizerSpec(
+            learning_rate=learning_rate,
+            learning_rate_power=optimizer.learning_rate_power,
+            l1_regularization_strength=optimizer.l1_regularization_strength,
+            l2_regularization_strength=optimizer.l2_regularization_strength,
+            beta=optimizer.beta,
+            initial_accumulator_value=optimizer.initial_accumulator_value,
+        )
+
+    raise ValueError(
+        f"Unsupported optimizer type {type(optimizer)}. Optimizer must be "
+        "one of [Adagrad, Adam, Ftrl, SGD]."
+    )
+
+
+def jte_to_keras_optimizer(
+    optimizer: embedding_spec.OptimizerSpec,
+) -> keras.optimizers.Optimizer:
+    """Converts a JAX TPU Embedding optimizer to a Keras optimizer.
+
+    Args:
+        optimizer: The JAX TPU Embedding optimizer.
+
+    Returns:
+        A corresponding Keras optimizer.
+    """
+    learning_rate = jte_to_keras_learning_rate(optimizer.learning_rate)
+    if isinstance(optimizer, embedding_spec.SGDOptimizerSpec):
+        return keras.optimizers.SGD(learning_rate=learning_rate)
+    elif isinstance(optimizer, embedding_spec.AdagradOptimizerSpec):
+        return keras.optimizers.Adagrad(
+            learning_rate=learning_rate,
+            initial_accumulator_value=optimizer.initial_accumulator_value,
+        )
+    elif isinstance(optimizer, embedding_spec.AdamOptimizerSpec):
+        return keras.optimizers.Adam(
+            learning_rate=learning_rate,
+            beta_1=optimizer.beta_1,
+            beta_2=optimizer.beta_2,
+            epsilon=optimizer.epsilon,
+        )
+    elif isinstance(optimizer, embedding_spec.FTRLOptimizerSpec):
+        if getattr(optimizer, "initial_linear_value", 0.0) != 0.0:
+            raise ValueError(
+                "Unsupported optimizer option `initial_linear_value`."
+            )
+        if getattr(optimizer, "multiply_linear_by_learning_rate", False):
+            raise ValueError(
+                "Unsupported optimizer option "
+                "`multiply_linear_by_learning_rate`."
+            )
+        return keras.optimizers.Ftrl(
+            learning_rate=learning_rate,
+            learning_rate_power=optimizer.learning_rate_power,
+            initial_accumulator_value=optimizer.initial_accumulator_value,
+            l1_regularization_strength=optimizer.l1_regularization_strength,
+            l2_regularization_strength=optimizer.l2_regularization_strength,
+            beta=optimizer.beta,
+        )
+
+    raise ValueError(f"Unknown optimizer spec {type(optimizer)}.")
+
+
+def _keras_to_jte_table_config(
+    table_config: config.TableConfig,
+) -> embedding_spec.TableSpec:
+    # Initializer could be none. Default to truncated normal.
+    initializer = table_config.initializer
+    if initializer is None:
+        initializer = keras.initializers.TruncatedNormal(
+            mean=0.0, stddev=1.0 / math.sqrt(float(table_config.embedding_dim))
+        )
+    return embedding_spec.TableSpec(
+        name=table_config.name,
+        vocabulary_size=table_config.vocabulary_size,
+        embedding_dim=table_config.embedding_dim,
+        initializer=keras_to_jax_initializer(initializer),
+        optimizer=keras_to_jte_optimizer(table_config.optimizer),
+        combiner=table_config.combiner,
+        max_ids_per_partition=table_config.max_ids_per_partition,
+        max_unique_ids_per_partition=table_config.max_unique_ids_per_partition,
+    )
+
+
+def keras_to_jte_table_configs(
+    table_configs: Nested[config.TableConfig],
+) -> Nested[embedding_spec.TableSpec]:
+    """Converts Keras RS `TableConfig`s to JAX TPU Embedding `TableSpec`s."""
+    return keras.tree.map_structure(
+        _keras_to_jte_table_config,
+        table_configs,
+    )
+
+
+def _jte_to_keras_table_config(
+    table_spec: embedding_spec.TableSpec,
+) -> config.TableConfig:
+    return config.TableConfig(
+        name=table_spec.name,
+        vocabulary_size=table_spec.vocabulary_size,
+        embedding_dim=table_spec.embedding_dim,
+        initializer=jax_to_keras_initializer(table_spec.initializer),
+        optimizer=jte_to_keras_optimizer(table_spec.optimizer),
+        combiner=table_spec.combiner,
+        max_ids_per_partition=table_spec.max_ids_per_partition,
+        max_unique_ids_per_partition=table_spec.max_unique_ids_per_partition,
+    )
+
+
+def jte_to_keras_table_configs(
+    table_specs: Nested[embedding_spec.TableSpec],
+) -> Nested[config.TableConfig]:
+    """Converts JAX TPU Embedding `TableSpec`s to Keras RS `TableConfig`s."""
+    output: Nested[config.TableConfig] = keras.tree.map_structure(
+        _jte_to_keras_table_config,
+        table_specs,
+    )
+    return output
+
+
+def _keras_to_jte_feature_config(
+    feature_config: config.FeatureConfig,
+    table_spec_map: dict[str, embedding_spec.TableSpec],
+) -> embedding_spec.FeatureSpec:
+    table_spec = table_spec_map.get(feature_config.table.name, None)
+    if table_spec is None:
+        table_spec = _keras_to_jte_table_config(feature_config.table)
+        table_spec_map[feature_config.table.name] = table_spec
+
+    return embedding_spec.FeatureSpec(
+        name=feature_config.name,
+        table_spec=table_spec,
+        input_shape=feature_config.input_shape,
+        output_shape=feature_config.output_shape,
+    )
+
+
+def keras_to_jte_feature_configs(
+    feature_configs: Nested[config.FeatureConfig],
+) -> Nested[embedding_spec.FeatureSpec]:
+    """Converts Keras RS `FeatureConfig`s to JAX TPU Embedding `FeatureSpec`s.
+
+    Args:
+        feature_configs: Keras RS feature configurations.
+
+    Returns:
+        JAX TPU Embedding feature specifications.
+    """
+    table_spec_map: dict[str, embedding_spec.TableSpec] = {}
+    return keras.tree.map_structure(
+        lambda feature_config: _keras_to_jte_feature_config(
+            feature_config, table_spec_map
+        ),
+        feature_configs,
+    )
+
+
+def _jte_to_keras_feature_config(
+    feature_spec: embedding_spec.FeatureSpec,
+    table_config_map: dict[str, config.TableConfig],
+) -> config.FeatureConfig:
+    table_config = table_config_map.get(feature_spec.table_spec.name, None)
+    if table_config is None:
+        table_config = _jte_to_keras_table_config(feature_spec.table_spec)
+        table_config_map[feature_spec.table_spec.name] = table_config
+
+    return config.FeatureConfig(
+        name=feature_spec.name,
+        table=table_config,
+        input_shape=feature_spec.input_shape,
+        output_shape=feature_spec.output_shape,
+    )
+
+
+def jte_to_keras_feature_configs(
+    feature_specs: Nested[embedding_spec.FeatureSpec],
+) -> Nested[config.FeatureConfig]:
+    """Converts JAX TPU Embedding `FeatureSpec`s to Keras RS `FeatureConfig`s.
+
+    Args:
+        feature_specs: JAX TPU Embedding feature specifications.
+
+    Returns:
+        Keras RS feature configurations.
+    """
+    table_config_map: dict[str, config.TableConfig] = {}
+    output: Nested[config.FeatureConfig] = keras.tree.map_structure(
+        lambda feature_spec: _jte_to_keras_feature_config(
+            feature_spec, table_config_map
+        ),
+        feature_specs,
+    )
+    return output