keras-rs-nightly 0.0.1.dev2025043003__py3-none-any.whl → 0.2.2.dev202506100336__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of keras-rs-nightly might be problematic. Click here for more details.
- keras_rs/layers/__init__.py +12 -0
- keras_rs/src/layers/embedding/__init__.py +0 -0
- keras_rs/src/layers/embedding/base_distributed_embedding.py +1124 -0
- keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
- keras_rs/src/layers/embedding/distributed_embedding_config.py +129 -0
- keras_rs/src/layers/embedding/embed_reduce.py +309 -0
- keras_rs/src/layers/embedding/jax/__init__.py +0 -0
- keras_rs/src/layers/embedding/jax/config_conversion.py +398 -0
- keras_rs/src/layers/embedding/jax/distributed_embedding.py +892 -0
- keras_rs/src/layers/embedding/jax/embedding_lookup.py +255 -0
- keras_rs/src/layers/embedding/jax/embedding_utils.py +596 -0
- keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
- keras_rs/src/layers/embedding/tensorflow/config_conversion.py +323 -0
- keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +424 -0
- keras_rs/src/layers/feature_interaction/dot_interaction.py +2 -2
- keras_rs/src/layers/feature_interaction/feature_cross.py +14 -16
- keras_rs/src/layers/retrieval/brute_force_retrieval.py +5 -5
- keras_rs/src/layers/retrieval/retrieval.py +4 -4
- keras_rs/src/losses/pairwise_loss.py +2 -2
- keras_rs/src/losses/pairwise_mean_squared_error.py +1 -3
- keras_rs/src/metrics/dcg.py +2 -2
- keras_rs/src/metrics/mean_average_precision.py +1 -1
- keras_rs/src/metrics/mean_reciprocal_rank.py +4 -4
- keras_rs/src/metrics/ndcg.py +2 -2
- keras_rs/src/metrics/precision_at_k.py +3 -3
- keras_rs/src/metrics/ranking_metric.py +11 -5
- keras_rs/src/metrics/ranking_metrics_utils.py +10 -10
- keras_rs/src/metrics/recall_at_k.py +2 -2
- keras_rs/src/metrics/utils.py +2 -4
- keras_rs/src/types.py +43 -14
- keras_rs/src/utils/keras_utils.py +26 -6
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.0.1.dev2025043003.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/METADATA +6 -3
- keras_rs_nightly-0.2.2.dev202506100336.dist-info/RECORD +55 -0
- {keras_rs_nightly-0.0.1.dev2025043003.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/WHEEL +1 -1
- keras_rs_nightly-0.0.1.dev2025043003.dist-info/RECORD +0 -42
- {keras_rs_nightly-0.0.1.dev2025043003.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,424 @@
|
|
|
1
|
+
from typing import Any, Callable, Sequence, TypeAlias
|
|
2
|
+
|
|
3
|
+
import keras
|
|
4
|
+
import tensorflow as tf
|
|
5
|
+
|
|
6
|
+
from keras_rs.src import types
|
|
7
|
+
from keras_rs.src.layers.embedding import base_distributed_embedding
|
|
8
|
+
from keras_rs.src.layers.embedding import distributed_embedding_config
|
|
9
|
+
from keras_rs.src.layers.embedding.tensorflow import config_conversion
|
|
10
|
+
from keras_rs.src.utils import keras_utils
|
|
11
|
+
|
|
12
|
+
# Short local aliases for the Keras RS configuration classes.
FeatureConfig = distributed_embedding_config.FeatureConfig
TableConfig = distributed_embedding_config.TableConfig

# Stand-in type for `tf.tpu.experimental.embedding._Optimizer`, which
# TensorFlow does not expose publicly.
TfTpuOptimizer: TypeAlias = Any

# Name given to the placeholder trainable weight that anchors the custom
# gradient ("gradient trap") used by the TPU embedding lookups.
GRADIENT_TRAP_DUMMY_NAME = "_gradient_trap_dummy"

# Embedding-feature levels that a TPU strategy can report through
# `strategy.extended.tpu_hardware_feature.embedding_feature`.
_EmbeddingFeature = tf.tpu.experimental.HardwareFeature.EmbeddingFeature
EMBEDDING_FEATURE_V1 = _EmbeddingFeature.V1
EMBEDDING_FEATURE_V2 = _EmbeddingFeature.V2
UNSUPPORTED = _EmbeddingFeature.UNSUPPORTED
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
    """TensorFlow implementation of the TPU embedding layer."""

    def __init__(
        self,
        feature_configs: types.Nested[
            FeatureConfig | tf.tpu.experimental.embedding.FeatureConfig
        ],
        *,
        table_stacking: (
            str | Sequence[str] | Sequence[Sequence[str]]
        ) = "auto",
        **kwargs: Any,
    ) -> None:
        """Creates the layer.

        Args:
            feature_configs: Nested structure of Keras RS `FeatureConfig`
                or `tf.tpu.experimental.embedding.FeatureConfig` instances.
            table_stacking: Table stacking policy, forwarded to the base
                class. `_sparsecore_init` rejects any value other than
                `"auto"` when TensorFlow feature configs are used.
            **kwargs: Forwarded to the base class, except for the
                TensorFlow-only keys `optimizer`,
                `pipeline_execution_with_tensor_core` and
                `sparse_core_embedding_config`, which are consumed here.
        """
        # Intercept arguments that are supported only on TensorFlow before
        # the base class sees `kwargs`.
        self._optimizer = kwargs.pop("optimizer", None)
        self._pipeline_execution_with_tensor_core = kwargs.pop(
            "pipeline_execution_with_tensor_core", False
        )
        self._sparse_core_embedding_config = kwargs.pop(
            "sparse_core_embedding_config", None
        )

        # Mark as True by default for `_verify_input_shapes`. This will be
        # updated in `_sparsecore_init` if applicable.
        self._using_keras_rs_configuration = True

        super().__init__(
            feature_configs, table_stacking=table_stacking, **kwargs
        )

    def _is_tpu_strategy(self, strategy: tf.distribute.Strategy) -> bool:
        """Returns whether `strategy` is a (possibly experimental) TPUStrategy."""
        return isinstance(
            strategy,
            (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy),
        )

    def _has_sparsecore(self) -> bool:
        """Returns whether the current strategy exposes a V1/V2 embedding feature."""
        strategy = tf.distribute.get_strategy()
        if self._is_tpu_strategy(strategy):
            tpu_embedding_feature = (
                strategy.extended.tpu_hardware_feature.embedding_feature
            )
            return tpu_embedding_feature in (
                EMBEDDING_FEATURE_V2,
                EMBEDDING_FEATURE_V1,
            )
        return False

    @keras_utils.no_automatic_dependency_tracking
    def _sparsecore_init(
        self,
        feature_configs: dict[
            str,
            FeatureConfig | tf.tpu.experimental.embedding.FeatureConfig,
        ],
        table_stacking: str | Sequence[str] | Sequence[Sequence[str]],
    ) -> None:
        """Builds the TensorFlow TPU embedding object for SparseCore features.

        Translates Keras RS configs (or clones TF configs), translates the
        optimizer, and instantiates `TPUEmbedding` (V1 hardware) or
        `TPUEmbeddingV2` (V2 hardware). Also creates the dummy trainable
        weight used by the gradient trap.

        Args:
            feature_configs: Flat mapping of feature path to feature config.
                All values are expected to be of the same kind; only the
                first one is inspected to decide which kind.
            table_stacking: Table stacking policy.

        Raises:
            ValueError: If not running under a TPU strategy, if the TPU does
                not support embeddings, or if the TF-only arguments are
                combined with an incompatible configuration.
        """
        self._table_stacking = table_stacking

        strategy = tf.distribute.get_strategy()
        if not self._is_tpu_strategy(strategy):
            raise ValueError(
                "Placement to sparsecore was requested, however, we are not "
                "running under a TPU strategy."
            )

        tpu_embedding_feature = (
            strategy.extended.tpu_hardware_feature.embedding_feature
        )

        self._using_keras_rs_configuration = isinstance(
            next(iter(feature_configs.values())), FeatureConfig
        )
        if self._using_keras_rs_configuration:
            if self._sparse_core_embedding_config is not None:
                raise ValueError(
                    "The `sparse_core_embedding_config` argument is only "
                    "supported when using "
                    "`tf.tpu.experimental.embedding.FeatureConfig` instances "
                    "for the configuration."
                )
            self._tpu_feature_configs, self._sparse_core_embedding_config = (
                config_conversion.translate_keras_rs_configuration(
                    feature_configs, table_stacking
                )
            )
            if tpu_embedding_feature == EMBEDDING_FEATURE_V1:
                # Remove auto-generated SparseCoreEmbeddingConfig, which is not
                # used.
                self._sparse_core_embedding_config = None
        else:
            if table_stacking != "auto":
                raise ValueError(
                    "The `table_stacking` argument is not supported when using "
                    "`tf.tpu.experimental.embedding.FeatureConfig` for the "
                    "configuration. You can use the `disable_table_stacking` "
                    "attribute of "
                    "`tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig` "
                    "to disable table stacking."
                )
            if (
                tpu_embedding_feature == EMBEDDING_FEATURE_V1
                and self._sparse_core_embedding_config is not None
            ):
                raise ValueError(
                    "The `sparse_core_embedding_config` argument is not "
                    "supported with this TPU generation."
                )
            self._tpu_feature_configs = (
                config_conversion.clone_tf_feature_configs(feature_configs)
            )

        self._tpu_optimizer = config_conversion.translate_optimizer(
            self._optimizer
        )

        if tpu_embedding_feature == EMBEDDING_FEATURE_V1:
            self._tpu_embedding = tf.tpu.experimental.embedding.TPUEmbedding(
                self._tpu_feature_configs,
                self._tpu_optimizer,
                self._pipeline_execution_with_tensor_core,
            )
            # Counter used to tag enqueue/dequeue/apply_gradients ops of each
            # layer call; see `_tpu_embedding_lookup_v1`.
            self._v1_call_id = 0
        elif tpu_embedding_feature == EMBEDDING_FEATURE_V2:
            self._tpu_embedding = tf.tpu.experimental.embedding.TPUEmbeddingV2(
                self._tpu_feature_configs,
                self._tpu_optimizer,
                self._pipeline_execution_with_tensor_core,
                self._sparse_core_embedding_config,
            )
        elif tpu_embedding_feature == UNSUPPORTED:
            raise ValueError(
                "Placement to sparsecore was requested, however, this TPU does "
                "not support it."
            )
        else:
            # Any other value is an embedding feature this layer does not
            # know about.
            raise ValueError(
                f"Unsupported TPU embedding feature: {tpu_embedding_feature}."
            )

        # We need at least one trainable variable for the gradient trap to work.
        # Note that the Python attribute name "_gradient_trap_dummy" should
        # match the name of the variable GRADIENT_TRAP_DUMMY_NAME.
        self._gradient_trap_dummy = self.add_weight(
            name=GRADIENT_TRAP_DUMMY_NAME,
            shape=(1,),
            initializer=tf.zeros_initializer(),
            trainable=True,
            dtype=tf.float32,
        )

    def compute_output_shape(
        self, input_shapes: types.Nested[types.Shape]
    ) -> types.Nested[types.Shape]:
        """Computes per-feature activation shapes.

        Keras RS configurations are delegated to the base class. For
        `tf.tpu.experimental.embedding.FeatureConfig`, the TensorFlow rules
        are mirrored: the feature's "output shape" excludes the embedding
        dimension, and the activation shape is that plus `table.dim`.
        """
        if self._using_keras_rs_configuration:
            return super().compute_output_shape(input_shapes)

        def _compute_output_shape(
            feature_config: tf.tpu.experimental.embedding.FeatureConfig,
            input_shape: types.Shape,
        ) -> types.Shape:
            """Returns the activation shape of one TF-configured feature."""
            if len(input_shape) < 1:
                raise ValueError(
                    f"Received input shape {input_shape}. Rank must be 1 or "
                    "above."
                )
            max_sequence_length: int = feature_config.max_sequence_length
            embed_dim = feature_config.table.dim
            if (
                feature_config.output_shape is not None
                and feature_config.output_shape.rank is not None
            ):
                # `FeatureConfig.output_shape` describes the feature's shape
                # without the embedding dimension, so append it to obtain the
                # activation shape, consistent with the branches below.
                return tuple(feature_config.output_shape.as_list()) + (
                    embed_dim,
                )
            elif (
                len(input_shape) == 2
                and input_shape[-1] != 1
                and max_sequence_length > 0
            ):
                # Update the input shape with the max sequence length. Only
                # update when:
                # 1. Input feature is 2D ragged or sparse tensor.
                # 2. Output shape is not set and max sequence length is set.
                return tuple(input_shape[:-1]) + (
                    max_sequence_length,
                    embed_dim,
                )
            elif len(input_shape) == 1:
                return tuple(input_shape) + (embed_dim,)
            else:
                return tuple(input_shape[:-1]) + (embed_dim,)

        output_shapes: types.Nested[types.Shape] = (
            keras.tree.map_structure_up_to(
                self._feature_configs,
                _compute_output_shape,
                self._feature_configs,
                input_shapes,
            )
        )
        return output_shapes

    def _sparsecore_build(self, input_shapes: dict[str, types.Shape]) -> None:
        """Builds the underlying TPU embedding object.

        `TPUEmbedding` (V1) needs per-replica input shapes and its `build` is
        converted with autograph; `TPUEmbeddingV2` builds without shapes.
        """
        if isinstance(
            self._tpu_embedding, tf.tpu.experimental.embedding.TPUEmbedding
        ):
            tf_input_shapes = keras.tree.map_shape_structure(
                tf.TensorShape, input_shapes
            )
            tpu_embedding_build = tf.autograph.to_graph(
                self._tpu_embedding.build, recursive=False
            )
            tpu_embedding_build(
                self._tpu_embedding, per_replica_input_shapes=tf_input_shapes
            )
        elif isinstance(
            self._tpu_embedding, tf.tpu.experimental.embedding.TPUEmbeddingV2
        ):
            self._tpu_embedding.build()

    def _sparsecore_call(
        self,
        inputs: dict[str, types.Tensor],
        weights: dict[str, types.Tensor] | None = None,
        training: bool = False,
    ) -> dict[str, types.Tensor]:
        """Looks up SparseCore-placed features.

        Args:
            inputs: Flat mapping of feature path to input ids.
            weights: Optional per-feature weights, same structure as `inputs`.
            training: Unused; see `_tpu_embedding_lookup_v1` for how training
                mode is handled on V1.

        Returns:
            Flat mapping of feature path to activations.

        Raises:
            RuntimeError: If not called under a TPU strategy.
            ValueError: If no TPU embedding object was configured.
        """
        del training  # Unused.
        strategy = tf.distribute.get_strategy()
        if not self._is_tpu_strategy(strategy):
            raise RuntimeError(
                "DistributedEmbedding needs to be called under a TPUStrategy "
                "for features placed on the embedding feature but is being "
                f"called under strategy {strategy}. Please use `strategy.run` "
                "when calling this layer."
            )
        if isinstance(
            self._tpu_embedding, tf.tpu.experimental.embedding.TPUEmbedding
        ):
            return self._tpu_embedding_lookup_v1(
                self._tpu_embedding, inputs, weights
            )
        elif isinstance(
            self._tpu_embedding, tf.tpu.experimental.embedding.TPUEmbeddingV2
        ):
            return self._tpu_embedding_lookup_v2(
                self._tpu_embedding, inputs, weights
            )
        else:
            raise ValueError(
                "DistributedEmbedding is receiving features to lookup on the "
                "TPU embedding feature but no such feature was configured."
            )

    def _sparsecore_get_embedding_tables(self) -> dict[str, types.Tensor]:
        """Reassembles each embedding table from its sharded variable.

        Returns:
            Mapping of table name to a `(vocabulary_size, dim)` array.
        """
        tables: dict[str, types.Tensor] = {}
        strategy = tf.distribute.get_strategy()
        # 4 is the number of sparsecores per chip.
        # NOTE(review): hard-coded; confirm it holds for all target TPUs.
        num_shards = strategy.num_replicas_in_sync * 4

        def populate_table(
            feature_config: tf.tpu.experimental.embedding.FeatureConfig,
        ) -> None:
            """Adds the table of `feature_config` to `tables` once."""
            table_name = feature_config.table.name
            # Several features may share a table; reassemble it only once.
            if table_name in tables:
                return

            embedding_dim = feature_config.table.dim
            table = self._tpu_embedding.embedding_tables[table_name]

            # The table variable is split across shards, so slice each shard,
            # concatenate, then split/concat/reshape to restore row order.
            # NOTE(review): this presumably means shard columns are padded
            # beyond `dim` and rows are interleaved ("mod" sharded) across
            # `num_shards`. That is inferred from the ops below; confirm
            # against the TPUEmbeddingV2 layout.
            table_shards = [
                shard.numpy()[:, :embedding_dim] for shard in table.values
            ]
            full_table = keras.ops.concatenate(table_shards, axis=0)
            full_table = keras.ops.concatenate(
                keras.ops.split(full_table, num_shards, axis=0), axis=1
            )
            full_table = keras.ops.reshape(full_table, [-1, embedding_dim])
            # Drop padding rows beyond the configured vocabulary.
            tables[table_name] = full_table[
                : feature_config.table.vocabulary_size, :
            ]

        # Called only for its side effect of filling `tables`.
        keras.tree.map_structure(populate_table, self._tpu_feature_configs)
        return tables

    def _verify_input_shapes(
        self, input_shapes: types.Nested[types.Shape]
    ) -> None:
        """Verifies input shapes; only possible for Keras RS configurations."""
        if self._using_keras_rs_configuration:
            return super()._verify_input_shapes(input_shapes)
        # `tf.tpu.experimental.embedding.FeatureConfig` does not provide any
        # information about the input shape, so there is nothing to verify.

    def _tpu_embedding_lookup_v1(
        self,
        tpu_embedding: tf.tpu.experimental.embedding.TPUEmbedding,
        inputs: dict[str, types.Tensor],
        weights: dict[str, types.Tensor] | None = None,
    ) -> dict[str, types.Tensor]:
        """Enqueues, dequeues and wires gradients for the V1 TPU embedding."""
        # Each call to this function increments the _v1_call_id by 1, this
        # allows us to tag each of the main embedding ops with this call id so
        # that we know during graph rewriting passes which ops correspond to the
        # same layer call.
        self._v1_call_id += 1
        name = str(self._v1_call_id)

        # Set training to true, even during eval. When name is set, this will
        # trigger a pass that updates the training based on if there is a send
        # gradients with the same name.
        tpu_embedding.enqueue(inputs, weights, training=True, name=name)

        @tf.custom_gradient  # type: ignore
        def gradient_trap(
            dummy: types.Tensor,
        ) -> tuple[
            list[types.Tensor], Callable[[tuple[types.Tensor]], types.Tensor]
        ]:
            """Register a gradient function for activation."""
            activations = tpu_embedding.dequeue(name=name)

            def grad(*grad_wrt_activations: types.Tensor) -> types.Tensor:
                """Gradient function."""
                # Since the output were flattened, the gradients are also
                # flattened. Pack them back into the correct nested structure.
                gradients = tf.nest.pack_sequence_as(
                    self._placement_to_path_to_feature_config["sparsecore"],
                    grad_wrt_activations,
                )
                tpu_embedding.apply_gradients(gradients, name=name)

                # This is the gradient for the input variable.
                return tf.zeros_like(dummy)

            # Custom gradient functions don't like nested structures of tensors,
            # so we flatten them here.
            return tf.nest.flatten(activations), grad

        activations_with_trap = gradient_trap(self._gradient_trap_dummy.value)
        result: dict[str, types.Tensor] = tf.nest.pack_sequence_as(
            self._placement_to_path_to_feature_config["sparsecore"],
            activations_with_trap,
        )
        return result

    def _tpu_embedding_lookup_v2(
        self,
        tpu_embedding: tf.tpu.experimental.embedding.TPUEmbeddingV2,
        inputs: dict[str, types.Tensor],
        weights: dict[str, types.Tensor] | None = None,
    ) -> dict[str, types.Tensor]:
        """Looks up activations and wires gradients for the V2 TPU embedding."""

        @tf.custom_gradient  # type: ignore
        def gradient_trap(
            dummy: types.Tensor,
        ) -> tuple[
            list[types.Tensor], Callable[[tuple[types.Tensor]], types.Tensor]
        ]:
            """Register a gradient function for activation."""
            activations, preserved_result = tpu_embedding(inputs, weights)

            def grad(*grad_wrt_activations: types.Tensor) -> types.Tensor:
                """Gradient function."""
                # Since the output were flattened, the gradients are also
                # flattened. Pack them back into the correct nested structure.
                gradients = tf.nest.pack_sequence_as(
                    self._placement_to_path_to_feature_config["sparsecore"],
                    grad_wrt_activations,
                )
                tpu_embedding.apply_gradients(
                    gradients, preserved_outputs=preserved_result
                )
                # This is the gradient for the input variable.
                return tf.zeros_like(dummy)

            # Custom gradient functions don't like nested structures of tensors,
            # so we flatten them here.
            return tf.nest.flatten(activations), grad

        activations_with_trap = gradient_trap(self._gradient_trap_dummy)
        result: dict[str, types.Tensor] = tf.nest.pack_sequence_as(
            self._placement_to_path_to_feature_config["sparsecore"],
            activations_with_trap,
        )
        return result

    def _trackable_children(
        self, save_type: str = "checkpoint", **kwargs: Any
    ) -> dict[str, Any]:
        """Returns trackable children, excluding the gradient-trap dummy."""
        # Remove dummy variable, we don't want it in checkpoints.
        children: dict[str, Any] = super()._trackable_children(
            save_type, **kwargs
        )
        children.pop(GRADIENT_TRAP_DUMMY_NAME, None)
        return children


DistributedEmbedding.__doc__ = (
    base_distributed_embedding.DistributedEmbedding.__doc__
)
|
|
@@ -205,8 +205,8 @@ class DotInteraction(keras.layers.Layer):
|
|
|
205
205
|
return activations
|
|
206
206
|
|
|
207
207
|
def compute_output_shape(
|
|
208
|
-
self, input_shape: list[types.
|
|
209
|
-
) -> types.
|
|
208
|
+
self, input_shape: list[types.Shape]
|
|
209
|
+
) -> types.Shape:
|
|
210
210
|
num_features = len(input_shape)
|
|
211
211
|
batch_size = input_shape[0][0]
|
|
212
212
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any
|
|
1
|
+
from typing import Any
|
|
2
2
|
|
|
3
3
|
import keras
|
|
4
4
|
from keras import ops
|
|
@@ -92,20 +92,18 @@ class FeatureCross(keras.layers.Layer):
|
|
|
92
92
|
|
|
93
93
|
def __init__(
|
|
94
94
|
self,
|
|
95
|
-
projection_dim:
|
|
96
|
-
diag_scale:
|
|
95
|
+
projection_dim: int | None = None,
|
|
96
|
+
diag_scale: float | None = 0.0,
|
|
97
97
|
use_bias: bool = True,
|
|
98
|
-
pre_activation:
|
|
99
|
-
kernel_initializer:
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
bias_initializer:
|
|
103
|
-
kernel_regularizer:
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
bias_regularizer:
|
|
107
|
-
Text, None, keras.regularizers.Regularizer
|
|
108
|
-
] = None,
|
|
98
|
+
pre_activation: str | keras.layers.Activation | None = None,
|
|
99
|
+
kernel_initializer: (
|
|
100
|
+
str | keras.initializers.Initializer
|
|
101
|
+
) = "glorot_uniform",
|
|
102
|
+
bias_initializer: str | keras.initializers.Initializer = "zeros",
|
|
103
|
+
kernel_regularizer: (
|
|
104
|
+
str | None | keras.regularizers.Regularizer
|
|
105
|
+
) = None,
|
|
106
|
+
bias_regularizer: (str | None | keras.regularizers.Regularizer) = None,
|
|
109
107
|
**kwargs: Any,
|
|
110
108
|
) -> None:
|
|
111
109
|
super().__init__(**kwargs)
|
|
@@ -129,7 +127,7 @@ class FeatureCross(keras.layers.Layer):
|
|
|
129
127
|
f"`diag_scale={self.diag_scale}`"
|
|
130
128
|
)
|
|
131
129
|
|
|
132
|
-
def build(self, input_shape: types.
|
|
130
|
+
def build(self, input_shape: types.Shape) -> None:
|
|
133
131
|
last_dim = input_shape[-1]
|
|
134
132
|
|
|
135
133
|
if self.projection_dim is not None:
|
|
@@ -155,7 +153,7 @@ class FeatureCross(keras.layers.Layer):
|
|
|
155
153
|
self.built = True
|
|
156
154
|
|
|
157
155
|
def call(
|
|
158
|
-
self, x0: types.Tensor, x:
|
|
156
|
+
self, x0: types.Tensor, x: types.Tensor | None = None
|
|
159
157
|
) -> types.Tensor:
|
|
160
158
|
"""Forward pass of the cross layer.
|
|
161
159
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any
|
|
1
|
+
from typing import Any
|
|
2
2
|
|
|
3
3
|
import keras
|
|
4
4
|
|
|
@@ -55,8 +55,8 @@ class BruteForceRetrieval(Retrieval):
|
|
|
55
55
|
|
|
56
56
|
def __init__(
|
|
57
57
|
self,
|
|
58
|
-
candidate_embeddings:
|
|
59
|
-
candidate_ids:
|
|
58
|
+
candidate_embeddings: types.Tensor | None = None,
|
|
59
|
+
candidate_ids: types.Tensor | None = None,
|
|
60
60
|
k: int = 10,
|
|
61
61
|
return_scores: bool = True,
|
|
62
62
|
**kwargs: Any,
|
|
@@ -81,7 +81,7 @@ class BruteForceRetrieval(Retrieval):
|
|
|
81
81
|
def update_candidates(
|
|
82
82
|
self,
|
|
83
83
|
candidate_embeddings: types.Tensor,
|
|
84
|
-
candidate_ids:
|
|
84
|
+
candidate_ids: types.Tensor | None = None,
|
|
85
85
|
) -> None:
|
|
86
86
|
"""Update the set of candidates and optionally their candidate IDs.
|
|
87
87
|
|
|
@@ -125,7 +125,7 @@ class BruteForceRetrieval(Retrieval):
|
|
|
125
125
|
|
|
126
126
|
def call(
|
|
127
127
|
self, inputs: types.Tensor
|
|
128
|
-
) ->
|
|
128
|
+
) -> types.Tensor | tuple[types.Tensor, types.Tensor]:
|
|
129
129
|
"""Returns the top candidates for the query passed as input.
|
|
130
130
|
|
|
131
131
|
Args:
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import abc
|
|
2
|
-
from typing import Any
|
|
2
|
+
from typing import Any
|
|
3
3
|
|
|
4
4
|
import keras
|
|
5
5
|
|
|
@@ -35,7 +35,7 @@ class Retrieval(keras.layers.Layer, abc.ABC):
|
|
|
35
35
|
def _validate_candidate_embeddings_and_ids(
|
|
36
36
|
self,
|
|
37
37
|
candidate_embeddings: types.Tensor,
|
|
38
|
-
candidate_ids:
|
|
38
|
+
candidate_ids: types.Tensor | None = None,
|
|
39
39
|
) -> None:
|
|
40
40
|
"""Validates inputs to `update_candidates()`."""
|
|
41
41
|
|
|
@@ -71,7 +71,7 @@ class Retrieval(keras.layers.Layer, abc.ABC):
|
|
|
71
71
|
def update_candidates(
|
|
72
72
|
self,
|
|
73
73
|
candidate_embeddings: types.Tensor,
|
|
74
|
-
candidate_ids:
|
|
74
|
+
candidate_ids: types.Tensor | None = None,
|
|
75
75
|
) -> None:
|
|
76
76
|
"""Update the set of candidates and optionally their candidate IDs.
|
|
77
77
|
|
|
@@ -85,7 +85,7 @@ class Retrieval(keras.layers.Layer, abc.ABC):
|
|
|
85
85
|
@abc.abstractmethod
|
|
86
86
|
def call(
|
|
87
87
|
self, inputs: types.Tensor
|
|
88
|
-
) ->
|
|
88
|
+
) -> types.Tensor | tuple[types.Tensor, types.Tensor]:
|
|
89
89
|
"""Returns the top candidates for the query passed as input.
|
|
90
90
|
|
|
91
91
|
Args:
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import abc
|
|
2
|
-
from typing import Any
|
|
2
|
+
from typing import Any
|
|
3
3
|
|
|
4
4
|
import keras
|
|
5
5
|
from keras import ops
|
|
@@ -43,7 +43,7 @@ class PairwiseLoss(keras.losses.Loss, abc.ABC):
|
|
|
43
43
|
self,
|
|
44
44
|
labels: types.Tensor,
|
|
45
45
|
logits: types.Tensor,
|
|
46
|
-
mask:
|
|
46
|
+
mask: types.Tensor | None = None,
|
|
47
47
|
) -> tuple[types.Tensor, types.Tensor]:
|
|
48
48
|
# Mask all values less than 0 (since less than 0 implies invalid
|
|
49
49
|
# labels).
|
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
from typing import Optional
|
|
2
|
-
|
|
3
1
|
from keras import ops
|
|
4
2
|
|
|
5
3
|
from keras_rs.src import types
|
|
@@ -20,7 +18,7 @@ class PairwiseMeanSquaredError(PairwiseLoss):
|
|
|
20
18
|
self,
|
|
21
19
|
labels: types.Tensor,
|
|
22
20
|
logits: types.Tensor,
|
|
23
|
-
mask:
|
|
21
|
+
mask: types.Tensor | None = None,
|
|
24
22
|
) -> tuple[types.Tensor, types.Tensor]:
|
|
25
23
|
# Override `PairwiseLoss.compute_unreduced_loss` since pairwise weights
|
|
26
24
|
# for MSE are computed differently.
|
keras_rs/src/metrics/dcg.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any, Callable
|
|
1
|
+
from typing import Any, Callable
|
|
2
2
|
|
|
3
3
|
from keras import ops
|
|
4
4
|
from keras.saving import deserialize_keras_object
|
|
@@ -25,7 +25,7 @@ from keras_rs.src.utils.doc_string_utils import format_docstring
|
|
|
25
25
|
class DCG(RankingMetric):
|
|
26
26
|
def __init__(
|
|
27
27
|
self,
|
|
28
|
-
k:
|
|
28
|
+
k: int | None = None,
|
|
29
29
|
gain_fn: Callable[[types.Tensor], types.Tensor] = default_gain_fn,
|
|
30
30
|
rank_discount_fn: Callable[
|
|
31
31
|
[types.Tensor], types.Tensor
|
|
@@ -25,7 +25,7 @@ class MeanAveragePrecision(RankingMetric):
|
|
|
25
25
|
) -> types.Tensor:
|
|
26
26
|
relevance = ops.cast(
|
|
27
27
|
ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
|
|
28
|
-
dtype=
|
|
28
|
+
dtype=y_pred.dtype,
|
|
29
29
|
)
|
|
30
30
|
sorted_relevance, sorted_weights = sort_by_scores(
|
|
31
31
|
tensors_to_sort=[relevance, sample_weight],
|
|
@@ -44,13 +44,13 @@ class MeanReciprocalRank(RankingMetric):
|
|
|
44
44
|
ops.greater_equal(
|
|
45
45
|
sorted_y_true, ops.cast(1, dtype=sorted_y_true.dtype)
|
|
46
46
|
),
|
|
47
|
-
dtype=
|
|
47
|
+
dtype=y_pred.dtype,
|
|
48
48
|
)
|
|
49
49
|
|
|
50
50
|
# `reciprocal_rank = [1, 0.5, 0.33]`
|
|
51
51
|
reciprocal_rank = ops.divide(
|
|
52
|
-
ops.cast(1, dtype=
|
|
53
|
-
ops.arange(1, list_length + 1, dtype=
|
|
52
|
+
ops.cast(1, dtype=y_pred.dtype),
|
|
53
|
+
ops.arange(1, list_length + 1, dtype=y_pred.dtype),
|
|
54
54
|
)
|
|
55
55
|
|
|
56
56
|
# `mrr` should be of shape `(batch_size, 1)`.
|
|
@@ -64,7 +64,7 @@ class MeanReciprocalRank(RankingMetric):
|
|
|
64
64
|
# Get weights.
|
|
65
65
|
overall_relevance = ops.cast(
|
|
66
66
|
ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
|
|
67
|
-
dtype=
|
|
67
|
+
dtype=y_pred.dtype,
|
|
68
68
|
)
|
|
69
69
|
per_list_weights = get_list_weights(
|
|
70
70
|
weights=sample_weight, relevance=overall_relevance
|
keras_rs/src/metrics/ndcg.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any, Callable
|
|
1
|
+
from typing import Any, Callable
|
|
2
2
|
|
|
3
3
|
from keras import ops
|
|
4
4
|
from keras.saving import deserialize_keras_object
|
|
@@ -25,7 +25,7 @@ from keras_rs.src.utils.doc_string_utils import format_docstring
|
|
|
25
25
|
class NDCG(RankingMetric):
|
|
26
26
|
def __init__(
|
|
27
27
|
self,
|
|
28
|
-
k:
|
|
28
|
+
k: int | None = None,
|
|
29
29
|
gain_fn: Callable[[types.Tensor], types.Tensor] = default_gain_fn,
|
|
30
30
|
rank_discount_fn: Callable[
|
|
31
31
|
[types.Tensor], types.Tensor
|
|
@@ -40,7 +40,7 @@ class PrecisionAtK(RankingMetric):
|
|
|
40
40
|
ops.greater_equal(
|
|
41
41
|
sorted_y_true, ops.cast(1, dtype=sorted_y_true.dtype)
|
|
42
42
|
),
|
|
43
|
-
dtype=
|
|
43
|
+
dtype=y_pred.dtype,
|
|
44
44
|
)
|
|
45
45
|
list_length = ops.shape(sorted_y_true)[1]
|
|
46
46
|
# TODO: We do not do this for MRR, and the other metrics. Do we need to
|
|
@@ -52,13 +52,13 @@ class PrecisionAtK(RankingMetric):
|
|
|
52
52
|
|
|
53
53
|
per_list_precision = ops.divide_no_nan(
|
|
54
54
|
ops.sum(relevance, axis=1, keepdims=True),
|
|
55
|
-
ops.cast(valid_list_length, dtype=
|
|
55
|
+
ops.cast(valid_list_length, dtype=y_pred.dtype),
|
|
56
56
|
)
|
|
57
57
|
|
|
58
58
|
# Get weights.
|
|
59
59
|
overall_relevance = ops.cast(
|
|
60
60
|
ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
|
|
61
|
-
dtype=
|
|
61
|
+
dtype=y_pred.dtype,
|
|
62
62
|
)
|
|
63
63
|
per_list_weights = get_list_weights(
|
|
64
64
|
weights=sample_weight, relevance=overall_relevance
|