keras-rs-nightly 0.0.1.dev2025021903__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_rs/__init__.py +9 -28
- keras_rs/layers/__init__.py +37 -0
- keras_rs/losses/__init__.py +19 -0
- keras_rs/metrics/__init__.py +16 -0
- keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
- keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
- keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
- keras_rs/src/layers/embedding/embed_reduce.py +309 -0
- keras_rs/src/layers/embedding/jax/__init__.py +0 -0
- keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
- keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
- keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
- keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
- keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
- keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
- keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
- keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
- keras_rs/src/layers/feature_interaction/__init__.py +0 -0
- keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
- keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
- keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
- keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
- keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
- keras_rs/src/layers/retrieval/retrieval.py +127 -0
- keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
- keras_rs/src/losses/__init__.py +0 -0
- keras_rs/src/losses/list_mle_loss.py +212 -0
- keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
- keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
- keras_rs/src/losses/pairwise_loss.py +165 -0
- keras_rs/src/losses/pairwise_loss_utils.py +39 -0
- keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
- keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
- keras_rs/src/metrics/__init__.py +0 -0
- keras_rs/src/metrics/dcg.py +161 -0
- keras_rs/src/metrics/mean_average_precision.py +130 -0
- keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
- keras_rs/src/metrics/ndcg.py +197 -0
- keras_rs/src/metrics/precision_at_k.py +117 -0
- keras_rs/src/metrics/ranking_metric.py +260 -0
- keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
- keras_rs/src/metrics/recall_at_k.py +108 -0
- keras_rs/src/metrics/utils.py +70 -0
- keras_rs/src/types.py +43 -14
- keras_rs/src/utils/doc_string_utils.py +53 -0
- keras_rs/src/utils/keras_utils.py +52 -3
- keras_rs/src/utils/tpu_test_utils.py +120 -0
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
- keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
- keras_rs/api/__init__.py +0 -9
- keras_rs/api/layers/__init__.py +0 -11
- keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
- /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,363 @@
|
|
|
1
|
+
import collections
|
|
2
|
+
from typing import Any, Sequence, TypeAlias
|
|
3
|
+
|
|
4
|
+
import keras
|
|
5
|
+
import tensorflow as tf
|
|
6
|
+
|
|
7
|
+
from keras_rs.src import types
|
|
8
|
+
from keras_rs.src.layers.embedding import distributed_embedding_config
|
|
9
|
+
|
|
10
|
+
FeatureConfig = distributed_embedding_config.FeatureConfig
TableConfig = distributed_embedding_config.TableConfig

# Stand-in type for `tf.tpu.experimental.embedding._Optimizer`, which
# TensorFlow does not export publicly.
TfTpuOptimizer: TypeAlias = Any


# Translation rule for one Keras optimizer class:
# - tpu_optimizer_class: the `tf.tpu.experimental.embedding` optimizer class
#   to instantiate.
# - supported_kwargs: names of Keras optimizer attributes that are forwarded,
#   under the same name, to the TF TPU optimizer constructor.
# - unsupported_kwargs: Keras optimizer attributes with no TF TPU equivalent,
#   each mapped to the only value it is allowed to have.
OptimizerMapping = collections.namedtuple(
    "OptimizerMapping",
    "tpu_optimizer_class supported_kwargs unsupported_kwargs",
)


# Supported Keras optimizer classes and how to translate each of them.
OPTIMIZER_MAPPINGS = {
    keras.optimizers.Adagrad: OptimizerMapping(
        tf.tpu.experimental.embedding.Adagrad,
        ["initial_accumulator_value"],
        {"epsilon": 1e-07},
    ),
    keras.optimizers.Adam: OptimizerMapping(
        tf.tpu.experimental.embedding.Adam,
        ["beta_1", "beta_2", "epsilon"],
        {"amsgrad": False},
    ),
    keras.optimizers.Ftrl: OptimizerMapping(
        tf.tpu.experimental.embedding.FTRL,
        [
            "learning_rate_power",
            "initial_accumulator_value",
            "l1_regularization_strength",
            "l2_regularization_strength",
            "beta",
        ],
        {"l2_shrinkage_regularization_strength": 0.0},
    ),
    keras.optimizers.SGD: OptimizerMapping(
        tf.tpu.experimental.embedding.SGD,
        [],
        {"momentum": 0.0, "nesterov": False},
    ),
}
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# KerasRS to TensorFlow
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def keras_to_tf_tpu_configuration(
    feature_configs: types.Nested[FeatureConfig],
    table_stacking: str | Sequence[str] | Sequence[Sequence[str]] | None,
    num_replicas_in_sync: int,
) -> tuple[
    types.Nested[tf.tpu.experimental.embedding.FeatureConfig],
    tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig,
]:
    """Translates a Keras RS configuration to a TensorFlow TPU configuration.

    Args:
        feature_configs: The nested Keras RS feature configs.
        table_stacking: The Keras RS table stacking. The TensorFlow backend
            only supports `"auto"` (table stacking enabled) and `None` (table
            stacking disabled).
        num_replicas_in_sync: The number of replicas in sync from the strategy.

    Returns:
        A tuple containing the TensorFlow TPU feature configs, in the same
        nested structure as `feature_configs`, and the TensorFlow TPU sparse
        core embedding config.

    Raises:
        ValueError: If `table_stacking` is neither `"auto"` nor `None`, or if
            one of the feature configs cannot be translated.
    """
    # Validate the cheap argument first so an unsupported table stacking is
    # reported before any feature or table config gets converted.
    if table_stacking is None:
        disable_table_stacking = True
    elif table_stacking == "auto":
        disable_table_stacking = False
    else:
        raise ValueError(
            f"Unsupported table stacking for Tensorflow {table_stacking}, must "
            "be 'auto' or None."
        )

    # Shared by all features so that features referencing the same Keras RS
    # table config object are translated to a single TF TPU table config.
    tables: dict[int, tf.tpu.experimental.embedding.TableConfig] = {}
    tf_feature_configs = keras.tree.map_structure(
        lambda f: keras_to_tf_tpu_feature_config(
            f, tables, num_replicas_in_sync
        ),
        feature_configs,
    )

    # TODO: the following options are not translated yet:
    # `max_ids_per_chip_per_sample`, `max_ids_per_table`,
    # `max_unique_ids_per_table`.

    # `initialize_tables_on_host` is set to False. Otherwise, if the
    # `TPUEmbedding` layer is built within Keras' `compute_output_spec` (meaning
    # within `call`), the tables are created within a `FuncGraph` and the
    # resulting tables are destroyed at the end of it.
    # TODO: find an alternative that allows initializing tables on the host.
    sparse_core_embedding_config = (
        tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig(
            disable_table_stacking=disable_table_stacking,
            initialize_tables_on_host=False,
        )
    )

    return tf_feature_configs, sparse_core_embedding_config
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def keras_to_tf_tpu_feature_config(
    feature_config: FeatureConfig,
    tables: dict[int, tf.tpu.experimental.embedding.TableConfig],
    num_replicas_in_sync: int,
) -> tf.tpu.experimental.embedding.FeatureConfig:
    """Translates a Keras RS feature config to a TensorFlow TPU feature config.

    This creates the table config and adds it to the `tables` mapping if it
    doesn't exist yet. All arguments are validated before `tables` is
    modified, so a call that raises leaves `tables` unchanged.

    Args:
        feature_config: The Keras RS feature config to translate.
        tables: A mapping of KerasRS table config ids (`id(table_config)`) to
            TF TPU table configs. Updated in place.
        num_replicas_in_sync: The number of replicas in sync from the strategy.

    Returns:
        The TensorFlow TPU feature config.

    Raises:
        ValueError: If `num_replicas_in_sync` is not positive, if
            `feature_config.output_shape` has fewer than 2 dimensions, or if
            its batch size is not a multiple of `num_replicas_in_sync`.
    """
    if num_replicas_in_sync <= 0:
        raise ValueError(
            "`num_replicas_in_sync` must be positive, "
            f"but got {num_replicas_in_sync}."
        )

    if len(feature_config.output_shape) < 2:
        raise ValueError(
            f"Invalid `output_shape` {feature_config.output_shape} in "
            f"`FeatureConfig` {feature_config}. It must have at least 2 "
            "dimensions: a batch dimension and an embedding dimension."
        )

    # Exclude last dimension (the embedding dimension), TensorFlow's
    # TPUEmbedding doesn't want it.
    output_shape = list(feature_config.output_shape[0:-1])

    # An unknown (`None`) batch size stays unknown; a known one must split
    # evenly across replicas.
    batch_size = output_shape[0]
    per_replica_batch_size: int | None = None
    if batch_size is not None:
        if batch_size % num_replicas_in_sync != 0:
            raise ValueError(
                f"Invalid `output_shape` {feature_config.output_shape} in "
                f"`FeatureConfig` {feature_config}. Batch size {batch_size} is "
                f"not a multiple of the number of TPUs {num_replicas_in_sync}."
            )
        per_replica_batch_size = batch_size // num_replicas_in_sync

    # TensorFlow's TPUEmbedding wants the per replica batch size.
    output_shape = [per_replica_batch_size] + output_shape[1:]

    # Only touch `tables` once validation has passed. Features that share the
    # same Keras RS table object reuse the same TF TPU table config.
    table_key = id(feature_config.table)
    table = tables.get(table_key)
    if table is None:
        table = keras_to_tf_tpu_table_config(feature_config.table)
        tables[table_key] = table

    # TODO: `max_sequence_length` is not translated yet.
    return tf.tpu.experimental.embedding.FeatureConfig(
        name=feature_config.name,
        table=table,
        output_shape=output_shape,
    )
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def keras_to_tf_tpu_table_config(
    table_config: TableConfig,
) -> tf.tpu.experimental.embedding.TableConfig:
    """Translates a Keras RS table config to a TensorFlow TPU table config.

    A string initializer is resolved with `keras.initializers.get`; any other
    initializer is passed through unchanged. The optimizer is translated with
    `to_tf_tpu_optimizer`.

    Args:
        table_config: The Keras RS table config to translate.

    Returns:
        The TensorFlow TPU table config.
    """
    resolved_initializer = (
        keras.initializers.get(table_config.initializer)
        if isinstance(table_config.initializer, str)
        else table_config.initializer
    )
    tpu_optimizer = to_tf_tpu_optimizer(table_config.optimizer)

    return tf.tpu.experimental.embedding.TableConfig(
        vocabulary_size=table_config.vocabulary_size,
        dim=table_config.embedding_dim,
        initializer=resolved_initializer,
        optimizer=tpu_optimizer,
        combiner=table_config.combiner,
        name=table_config.name,
    )
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def keras_to_tf_tpu_optimizer(
    optimizer: keras.optimizers.Optimizer,
) -> TfTpuOptimizer:
    """Translates a Keras optimizer to a TensorFlow TPU `_Optimizer`.

    The optimizer class is looked up in `OPTIMIZER_MAPPINGS` with
    `isinstance`, so subclasses of the supported Keras optimizers are
    accepted as well.

    Args:
        optimizer: The Keras optimizer to translate.

    Returns:
        The TensorFlow TPU `_Optimizer`.

    Raises:
        ValueError: If the learning rate type, the optimizer class or one of
            the optimizer options is not supported.
    """
    tpu_optimizer_kwargs: dict[str, Any] = {}

    # Supported keras optimizer general options.
    learning_rate = optimizer._learning_rate  # pylint: disable=protected-access
    if isinstance(
        learning_rate, keras.optimizers.schedules.LearningRateSchedule
    ):
        # Passed as a callable so the current scheduled value is used.
        # Note: learning rate requires incrementing iterations in optimizer.
        tpu_optimizer_kwargs["learning_rate"] = lambda: optimizer.learning_rate
    elif callable(learning_rate):
        tpu_optimizer_kwargs["learning_rate"] = learning_rate
    else:
        # Constant learning rate: read its Python value from the config.
        learning_rate = optimizer.get_config()["learning_rate"]
        if not isinstance(learning_rate, float):
            raise ValueError(
                f"Unsupported learning rate: {learning_rate} of type"
                f" {type(learning_rate)}."
            )
        tpu_optimizer_kwargs["learning_rate"] = learning_rate

    if optimizer.weight_decay is not None:
        tpu_optimizer_kwargs["weight_decay_factor"] = optimizer.weight_decay
        # Keras scales its decoupled weight decay by the current learning
        # rate, whereas the TF TPU optimizers apply `weight_decay_factor`
        # unscaled unless asked otherwise. Request the scaling so the decay
        # matches the Keras optimizer's behavior.
        tpu_optimizer_kwargs[
            "multiply_weight_decay_factor_by_learning_rate"
        ] = True
    if optimizer.clipvalue is not None:
        tpu_optimizer_kwargs["clipvalue"] = optimizer.clipvalue
    if optimizer.gradient_accumulation_steps is not None:
        # NOTE(review): Keras `gradient_accumulation_steps` (multi-step
        # accumulation) and TF TPU `use_gradient_accumulation` look like
        # different features; confirm this mapping is intended.
        tpu_optimizer_kwargs["use_gradient_accumulation"] = True

    # Unsupported keras optimizer general options.
    if optimizer.clipnorm is not None:
        raise ValueError("Unsupported optimizer option `Optimizer.clipnorm`.")
    if optimizer.global_clipnorm is not None:
        raise ValueError(
            "Unsupported optimizer option `Optimizer.global_clipnorm`."
        )
    if optimizer.use_ema:
        raise ValueError("Unsupported optimizer option `Optimizer.use_ema`.")
    if optimizer.loss_scale_factor is not None:
        raise ValueError(
            "Unsupported optimizer option `Optimizer.loss_scale_factor`."
        )

    # Handle subclasses of the main optimizer class: first match wins.
    optimizer_mapping = next(
        (
            mapping
            for optimizer_class, mapping in OPTIMIZER_MAPPINGS.items()
            if isinstance(optimizer, optimizer_class)
        ),
        None,
    )
    if optimizer_mapping is None:
        raise ValueError(
            f"Unsupported optimizer type {type(optimizer)}. Optimizer must be "
            f"one of {list(OPTIMIZER_MAPPINGS.keys())}."
        )

    # Options with a TF TPU equivalent of the same name are copied over.
    for argname in optimizer_mapping.supported_kwargs:
        tpu_optimizer_kwargs[argname] = getattr(optimizer, argname)

    # Options without a TF TPU equivalent must keep their "disabled" value.
    for argname, disabled_value in optimizer_mapping.unsupported_kwargs.items():
        value = getattr(optimizer, argname)
        if disabled_value is None:
            is_enabled = value is not None
        else:
            is_enabled = value != disabled_value
        if is_enabled:
            raise ValueError(f"Unsupported optimizer option {argname}.")

    return optimizer_mapping.tpu_optimizer_class(**tpu_optimizer_kwargs)
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def to_tf_tpu_optimizer(
    optimizer: str | keras.optimizers.Optimizer | TfTpuOptimizer | None,
) -> TfTpuOptimizer:
    """Translates a Keras optimizer into a TensorFlow TPU `_Optimizer`.

    Args:
        optimizer: The optimizer to translate. One of:
            - `None`, which is returned as is;
            - a `tf.tpu.experimental.embedding` `SGD`, `Adagrad`, `Adam` or
              `FTRL` instance, which is returned as is;
            - an optimizer name, `"sgd"`, `"adagrad"`, `"adam"` or `"ftrl"`
              (case-insensitive), which creates the corresponding TF TPU
              optimizer with its default arguments;
            - a Keras optimizer instance, which is translated with
              `keras_to_tf_tpu_optimizer`.

    Returns:
        The equivalent TensorFlow TPU `_Optimizer`, or `None` if `optimizer`
        is `None`.

    Raises:
        ValueError: If the optimizer or one of its argument is not supported.
    """
    tpu_embedding = tf.tpu.experimental.embedding

    if optimizer is None:
        return None

    if isinstance(
        optimizer,
        (
            tpu_embedding.SGD,
            tpu_embedding.Adagrad,
            tpu_embedding.Adam,
            tpu_embedding.FTRL,
        ),
    ):
        return optimizer

    if isinstance(optimizer, str):
        optimizer_classes_by_name = {
            "sgd": tpu_embedding.SGD,
            "adagrad": tpu_embedding.Adagrad,
            "adam": tpu_embedding.Adam,
            "ftrl": tpu_embedding.FTRL,
        }
        # Case-insensitive, like Keras' own resolution of optimizer names.
        optimizer_class = optimizer_classes_by_name.get(optimizer.lower())
        if optimizer_class is None:
            raise ValueError(
                f"Unknown optimizer name '{optimizer}'. Please use one of "
                "'sgd', 'adagrad', 'adam', or 'ftrl'"
            )
        return optimizer_class()

    if isinstance(optimizer, keras.optimizers.Optimizer):
        return keras_to_tf_tpu_optimizer(optimizer)

    raise ValueError(
        f"Unknown optimizer type {type(optimizer)}. Please pass an "
        "optimizer name as a string, a subclass of keras optimizer or an "
        "instance of one of the optimizer parameter classes in "
        "`tf.tpu.experimental.embedding`."
    )
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
# TensorFlow to TensorFlow
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def clone_tf_tpu_feature_configs(
    feature_configs: types.Nested[tf.tpu.experimental.embedding.FeatureConfig],
) -> types.Nested[tf.tpu.experimental.embedding.FeatureConfig]:
    """Clones and resolves TensorFlow TPU feature configs.

    Every feature config is rebuilt from its attributes. Its table config is
    rebuilt as well, once per distinct table (using the table as a dict key),
    so features that shared a table keep sharing the cloned one. The table
    optimizer is resolved through `to_tf_tpu_optimizer`.

    Args:
        feature_configs: The TensorFlow TPU feature configs to clone and resolve.

    Returns:
        The cloned and resolved TensorFlow TPU feature configs, in the same
        nested structure as `feature_configs`.
    """
    cloned_tables = {}

    def clone_feature(
        feature: tf.tpu.experimental.embedding.FeatureConfig,
    ) -> tf.tpu.experimental.embedding.FeatureConfig:
        source_table = feature.table
        cloned_table = cloned_tables.get(source_table)
        if cloned_table is None:
            # First time this table is seen: build its clone and memoize it.
            cloned_table = tf.tpu.experimental.embedding.TableConfig(
                vocabulary_size=source_table.vocabulary_size,
                dim=source_table.dim,
                initializer=source_table.initializer,
                optimizer=to_tf_tpu_optimizer(source_table.optimizer),
                combiner=source_table.combiner,
                name=source_table.name,
                quantization_config=source_table.quantization_config,
                layout=source_table.layout,
            )
            cloned_tables[source_table] = cloned_table

        return tf.tpu.experimental.embedding.FeatureConfig(
            table=cloned_table,
            max_sequence_length=feature.max_sequence_length,
            validate_weights_and_indices=feature.validate_weights_and_indices,
            output_shape=feature.output_shape,
            name=feature.name,
        )

    return keras.tree.map_structure(clone_feature, feature_configs)
|