keras-rs-nightly 0.0.1.dev2025021903__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_rs/__init__.py +9 -28
- keras_rs/layers/__init__.py +37 -0
- keras_rs/losses/__init__.py +19 -0
- keras_rs/metrics/__init__.py +16 -0
- keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
- keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
- keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
- keras_rs/src/layers/embedding/embed_reduce.py +309 -0
- keras_rs/src/layers/embedding/jax/__init__.py +0 -0
- keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
- keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
- keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
- keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
- keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
- keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
- keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
- keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
- keras_rs/src/layers/feature_interaction/__init__.py +0 -0
- keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
- keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
- keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
- keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
- keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
- keras_rs/src/layers/retrieval/retrieval.py +127 -0
- keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
- keras_rs/src/losses/__init__.py +0 -0
- keras_rs/src/losses/list_mle_loss.py +212 -0
- keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
- keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
- keras_rs/src/losses/pairwise_loss.py +165 -0
- keras_rs/src/losses/pairwise_loss_utils.py +39 -0
- keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
- keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
- keras_rs/src/metrics/__init__.py +0 -0
- keras_rs/src/metrics/dcg.py +161 -0
- keras_rs/src/metrics/mean_average_precision.py +130 -0
- keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
- keras_rs/src/metrics/ndcg.py +197 -0
- keras_rs/src/metrics/precision_at_k.py +117 -0
- keras_rs/src/metrics/ranking_metric.py +260 -0
- keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
- keras_rs/src/metrics/recall_at_k.py +108 -0
- keras_rs/src/metrics/utils.py +70 -0
- keras_rs/src/types.py +43 -14
- keras_rs/src/utils/doc_string_utils.py +53 -0
- keras_rs/src/utils/keras_utils.py +52 -3
- keras_rs/src/utils/tpu_test_utils.py +120 -0
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
- keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
- keras_rs/api/__init__.py +0 -9
- keras_rs/api/layers/__init__.py +0 -11
- keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
- /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,436 @@
|
|
|
1
|
+
from typing import Any, Callable, Sequence, TypeAlias
|
|
2
|
+
|
|
3
|
+
import keras
|
|
4
|
+
import tensorflow as tf
|
|
5
|
+
|
|
6
|
+
from keras_rs.src import types
|
|
7
|
+
from keras_rs.src.layers.embedding import base_distributed_embedding
|
|
8
|
+
from keras_rs.src.layers.embedding import distributed_embedding_config
|
|
9
|
+
from keras_rs.src.layers.embedding.tensorflow import config_conversion
|
|
10
|
+
from keras_rs.src.utils import keras_utils
|
|
11
|
+
|
|
12
|
+
# Backend-agnostic configuration classes, re-exported under short names.
FeatureConfig = distributed_embedding_config.FeatureConfig
TableConfig = distributed_embedding_config.TableConfig

# Stand-in for `tf.tpu.experimental.embedding._Optimizer`, which TensorFlow
# does not expose publicly.
TfTpuOptimizer: TypeAlias = Any


# Name of the dummy weight that anchors the gradient trap; it is removed from
# the trackable children (and therefore from checkpoints) by
# `DistributedEmbedding._trackable_children`.
GRADIENT_TRAP_DUMMY_NAME = "_gradient_trap_dummy"

# Values compared against
# `strategy.extended.tpu_hardware_feature.embedding_feature`.
EMBEDDING_FEATURE_V1, EMBEDDING_FEATURE_V2, UNSUPPORTED = (
    tf.tpu.experimental.HardwareFeature.EmbeddingFeature.V1,
    tf.tpu.experimental.HardwareFeature.EmbeddingFeature.V2,
    tf.tpu.experimental.HardwareFeature.EmbeddingFeature.UNSUPPORTED,
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
    """TensorFlow implementation of the TPU embedding layer."""

    def __init__(
        self,
        feature_configs: types.Nested[
            FeatureConfig | tf.tpu.experimental.embedding.FeatureConfig
        ],
        *,
        table_stacking: (
            str | Sequence[str] | Sequence[Sequence[str]]
        ) = "auto",
        **kwargs: Any,
    ) -> None:
        """Initializes the layer.

        Args:
            feature_configs: Nested structure of Keras RS `FeatureConfig`s or
                `tf.tpu.experimental.embedding.FeatureConfig`s.
            table_stacking: Table stacking specification, forwarded to the
                base class.
            **kwargs: Forwarded to the base class, except for the
                TensorFlow-only arguments `optimizer`,
                `pipeline_execution_with_tensor_core` and
                `sparse_core_embedding_config`, which are consumed here.
        """
        # Intercept arguments that are supported only on TensorFlow so that
        # they are not forwarded to the base class.
        self._optimizer = kwargs.pop("optimizer", None)
        self._pipeline_execution_with_tensor_core = kwargs.pop(
            "pipeline_execution_with_tensor_core", False
        )
        self._sparse_core_embedding_config = kwargs.pop(
            "sparse_core_embedding_config", None
        )

        # Mark as True by default for `_verify_input_shapes`. This will be
        # updated in `_sparsecore_init` if applicable.
        self._using_keras_rs_configuration = True

        super().__init__(
            feature_configs, table_stacking=table_stacking, **kwargs
        )

    def _is_tpu_strategy(self, strategy: tf.distribute.Strategy) -> bool:
        """Returns whether `strategy` is any flavor of `TPUStrategy`."""
        return isinstance(
            strategy,
            (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy),
        )

    def _has_sparsecore(self) -> bool:
        """Returns whether the current strategy's TPU supports embeddings.

        Only a TPU strategy whose hardware reports the V1 or V2 embedding
        feature qualifies; any other strategy yields `False`.
        """
        strategy = tf.distribute.get_strategy()
        if not self._is_tpu_strategy(strategy):
            return False
        tpu_embedding_feature = (
            strategy.extended.tpu_hardware_feature.embedding_feature
        )
        return tpu_embedding_feature in (
            EMBEDDING_FEATURE_V2,
            EMBEDDING_FEATURE_V1,
        )

    @keras_utils.no_automatic_dependency_tracking
    def _sparsecore_init(
        self,
        feature_configs: dict[
            str,
            FeatureConfig | tf.tpu.experimental.embedding.FeatureConfig,
        ],
        table_stacking: str | Sequence[str] | Sequence[Sequence[str]],
    ) -> None:
        """Creates the TPU embedding object for SparseCore-placed features.

        Keras RS configurations are converted to TensorFlow TPU ones;
        TensorFlow TPU configurations are cloned. The optimizer is converted,
        then a `TPUEmbedding` (V1 hardware) or `TPUEmbeddingV2` (V2 hardware)
        is created, followed by the dummy weight used by the gradient trap.

        Args:
            feature_configs: Flat dict of feature configurations. Only the
                first value is inspected to decide whether these are Keras RS
                or TensorFlow TPU configurations.
            table_stacking: Table stacking specification. Only `"auto"` is
                accepted with TensorFlow TPU configurations.

        Raises:
            ValueError: If not running under a TPU strategy, if the TPU does
                not support embeddings (or reports an unknown embedding
                feature), or if `sparse_core_embedding_config` /
                `table_stacking` are incompatible with the configuration kind
                or TPU generation.
        """
        self._table_stacking = table_stacking

        strategy = tf.distribute.get_strategy()
        if not self._is_tpu_strategy(strategy):
            raise ValueError(
                "Placement to sparsecore was requested, however, we are not "
                "running under a TPU strategy."
            )

        tpu_embedding_feature = (
            strategy.extended.tpu_hardware_feature.embedding_feature
        )

        self._using_keras_rs_configuration = isinstance(
            next(iter(feature_configs.values())), FeatureConfig
        )
        if self._using_keras_rs_configuration:
            if self._sparse_core_embedding_config is not None:
                raise ValueError(
                    "The `sparse_core_embedding_config` argument is only "
                    "supported when using "
                    "`tf.tpu.experimental.embedding.FeatureConfig` instances "
                    "for the configuration."
                )
            self._tpu_feature_configs, self._sparse_core_embedding_config = (
                config_conversion.keras_to_tf_tpu_configuration(
                    feature_configs,
                    table_stacking,
                    strategy.num_replicas_in_sync,
                )
            )
            if tpu_embedding_feature == EMBEDDING_FEATURE_V1:
                # Remove auto-generated SparseCoreEmbeddingConfig, which is not
                # used.
                self._sparse_core_embedding_config = None
        else:
            if table_stacking != "auto":
                raise ValueError(
                    "The `table_stacking` argument is not supported when using "
                    "`tf.tpu.experimental.embedding.FeatureConfig` for the "
                    "configuration. You can use the `disable_table_stacking` "
                    "attribute of "
                    "`tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig` "
                    "to disable table stacking."
                )
            if (
                tpu_embedding_feature == EMBEDDING_FEATURE_V1
                and self._sparse_core_embedding_config is not None
            ):
                raise ValueError(
                    "The `sparse_core_embedding_config` argument is not "
                    "supported with this TPU generation."
                )
            self._tpu_feature_configs = (
                config_conversion.clone_tf_tpu_feature_configs(feature_configs)
            )

        self._tpu_optimizer = config_conversion.to_tf_tpu_optimizer(
            self._optimizer
        )

        if tpu_embedding_feature == EMBEDDING_FEATURE_V1:
            self._tpu_embedding = tf.tpu.experimental.embedding.TPUEmbedding(
                self._tpu_feature_configs,
                self._tpu_optimizer,
                self._pipeline_execution_with_tensor_core,
            )
            # Counter used to tag the ops of each V1 lookup call.
            self._v1_call_id = 0
        elif tpu_embedding_feature == EMBEDDING_FEATURE_V2:
            self._tpu_embedding = tf.tpu.experimental.embedding.TPUEmbeddingV2(
                self._tpu_feature_configs,
                self._tpu_optimizer,
                self._pipeline_execution_with_tensor_core,
                self._sparse_core_embedding_config,
            )
        elif tpu_embedding_feature == UNSUPPORTED:
            raise ValueError(
                "Placement to sparsecore was requested, however, this TPU does "
                "not support it."
            )
        else:
            # Any other embedding feature value is unknown to this layer.
            raise ValueError(
                f"Unsupported TPU embedding feature: {tpu_embedding_feature}."
            )

        # We need at least one trainable variable for the gradient trap to work.
        # Note that the Python attribute name "_gradient_trap_dummy" should
        # match the name of the variable GRADIENT_TRAP_DUMMY_NAME, because
        # `_trackable_children` removes the child with that name.
        self._gradient_trap_dummy = self.add_weight(
            name=GRADIENT_TRAP_DUMMY_NAME,
            shape=(1,),
            initializer=tf.zeros_initializer(),
            trainable=True,
            dtype=tf.float32,
        )

    def compute_output_shape(
        self, input_shapes: types.Nested[types.Shape]
    ) -> types.Nested[types.Shape]:
        """Computes the output shapes for the given input shapes.

        Keras RS configurations defer to the base class. For
        `tf.tpu.experimental.embedding.FeatureConfig`s, each feature's shape
        is derived from its `output_shape`, `max_sequence_length` and table
        `dim`.
        """
        if self._using_keras_rs_configuration:
            return super().compute_output_shape(input_shapes)

        def _compute_output_shape(
            feature_config: tf.tpu.experimental.embedding.FeatureConfig,
            input_shape: types.Shape,
        ) -> types.Shape:
            """Computes the output shape of a single feature."""
            if len(input_shape) < 1:
                raise ValueError(
                    f"Received input shape {input_shape}. Rank must be 1 or "
                    "above."
                )
            max_sequence_length: int = feature_config.max_sequence_length
            embed_dim = feature_config.table.dim
            if (
                feature_config.output_shape is not None
                and feature_config.output_shape.rank is not None
            ):
                # An explicitly configured output shape takes precedence.
                return tuple(feature_config.output_shape.as_list())
            elif (
                len(input_shape) == 2
                and input_shape[-1] != 1
                and max_sequence_length > 0
            ):
                # Update the input shape with the max sequence length. Only
                # update when:
                # 1. Input feature is 2D ragged or sparse tensor.
                # 2. Output shape is not set and max sequence length is set.
                return tuple(input_shape[:-1]) + (
                    max_sequence_length,
                    embed_dim,
                )
            elif len(input_shape) == 1:
                # Rank 1: append the embedding dimension.
                return tuple(input_shape) + (embed_dim,)
            else:
                # Otherwise the last input dimension is replaced by the
                # embedding dimension.
                return tuple(input_shape[:-1]) + (embed_dim,)

        output_shapes: types.Nested[types.Shape] = (
            keras.tree.map_structure_up_to(
                self._feature_configs,
                _compute_output_shape,
                self._feature_configs,
                input_shapes,
            )
        )
        return output_shapes

    def _sparsecore_build(self, input_shapes: dict[str, types.Shape]) -> None:
        """Builds the underlying TPU embedding object.

        A `TPUEmbedding` (V1) is built from the per-replica input shapes; a
        `TPUEmbeddingV2` is built without them.
        """
        if isinstance(
            self._tpu_embedding, tf.tpu.experimental.embedding.TPUEmbedding
        ):
            tf_input_shapes = keras.tree.map_shape_structure(
                tf.TensorShape, input_shapes
            )
            # NOTE(review): `build` is converted with autograph
            # (non-recursively) before being called; the reason is not
            # visible in this file.
            tpu_embedding_build = tf.autograph.to_graph(
                self._tpu_embedding.build, recursive=False
            )
            tpu_embedding_build(
                self._tpu_embedding, per_replica_input_shapes=tf_input_shapes
            )
        elif isinstance(
            self._tpu_embedding, tf.tpu.experimental.embedding.TPUEmbeddingV2
        ):
            self._tpu_embedding.build()

    def _sparsecore_call(
        self,
        inputs: dict[str, types.Tensor],
        weights: dict[str, types.Tensor] | None = None,
        training: bool = False,
    ) -> dict[str, types.Tensor]:
        """Looks up the SparseCore-placed features.

        Dispatches to the V1 or V2 lookup depending on the type of the TPU
        embedding object created in `_sparsecore_init`.

        Raises:
            RuntimeError: If not called under a TPU strategy.
            ValueError: If the TPU embedding object has an unexpected type.
        """
        del training  # Unused.
        strategy = tf.distribute.get_strategy()
        if not self._is_tpu_strategy(strategy):
            raise RuntimeError(
                "DistributedEmbedding needs to be called under a TPUStrategy "
                "for features placed on the embedding feature but is being "
                f"called under strategy {strategy}. Please use `strategy.run` "
                "when calling this layer."
            )
        if isinstance(
            self._tpu_embedding, tf.tpu.experimental.embedding.TPUEmbedding
        ):
            return self._tpu_embedding_lookup_v1(
                self._tpu_embedding, inputs, weights
            )
        elif isinstance(
            self._tpu_embedding, tf.tpu.experimental.embedding.TPUEmbeddingV2
        ):
            return self._tpu_embedding_lookup_v2(
                self._tpu_embedding, inputs, weights
            )
        else:
            raise ValueError(
                "DistributedEmbedding is receiving features to lookup on the "
                "TPU embedding feature but no such feature was configured."
            )

    def _sparsecore_get_embedding_tables(self) -> dict[str, types.Tensor]:
        """Returns the full embedding tables, keyed by table name.

        Each table is reassembled from the per-replica values of the TPU
        embedding variable. `.numpy()` is called on every shard, so this
        presumably needs to run eagerly.

        Raises:
            RuntimeError: If not called under a TPU strategy.
        """
        tables: dict[str, types.Tensor] = {}
        strategy = tf.distribute.get_strategy()
        if not self._is_tpu_strategy(strategy):
            raise RuntimeError(
                "`DistributedEmbedding.get_embedding_tables` needs to be "
                "called under the TPUStrategy that DistributedEmbedding was "
                f"created with, but is being called under strategy {strategy}. "
                "Please use `with strategy.scope()` when calling "
                "`get_embedding_tables`."
            )

        tpu_hardware = strategy.extended.tpu_hardware_feature
        num_sc_per_device = tpu_hardware.num_embedding_devices_per_chip
        num_shards = strategy.num_replicas_in_sync * num_sc_per_device

        def populate_table(
            feature_config: tf.tpu.experimental.embedding.FeatureConfig,
        ) -> None:
            """Adds the un-sharded table of `feature_config` to `tables`."""
            table_name = feature_config.table.name
            if table_name in tables:
                # Tables shared by several features are only assembled once.
                return

            embedding_dim = feature_config.table.dim
            table = self._tpu_embedding.embedding_tables[table_name]

            # The table is split into `num_shards` SparseCore shards. Keep the
            # first `embedding_dim` columns of each replica value (presumably
            # dropping padding), concatenate them, lay the `num_shards` pieces
            # side by side and reshape. After the reshape, row `r` comes from
            # shard `r % num_shards`. Rows beyond `vocabulary_size` are then
            # dropped.
            table_shards = [
                shard.numpy()[:, :embedding_dim] for shard in table.values
            ]
            full_table = keras.ops.concatenate(table_shards, axis=0)
            full_table = keras.ops.concatenate(
                keras.ops.split(full_table, num_shards, axis=0), axis=1
            )
            full_table = keras.ops.reshape(full_table, [-1, embedding_dim])
            tables[table_name] = full_table[
                : feature_config.table.vocabulary_size, :
            ]

        keras.tree.map_structure(populate_table, self._tpu_feature_configs)
        return tables

    def _verify_input_shapes(
        self, input_shapes: types.Nested[types.Shape]
    ) -> None:
        """Verifies input shapes; only applies to Keras RS configurations."""
        if self._using_keras_rs_configuration:
            return super()._verify_input_shapes(input_shapes)
        # `tf.tpu.experimental.embedding.FeatureConfig` does not provide any
        # information about the input shape, so there is nothing to verify.

    def _tpu_embedding_lookup_v1(
        self,
        tpu_embedding: tf.tpu.experimental.embedding.TPUEmbedding,
        inputs: dict[str, types.Tensor],
        weights: dict[str, types.Tensor] | None = None,
    ) -> dict[str, types.Tensor]:
        """Looks up activations with the V1 enqueue/dequeue API.

        The dequeued activations are produced inside a `tf.custom_gradient`
        "gradient trap" that takes the dummy weight as input. Its gradient
        function passes the incoming activation gradients to
        `tpu_embedding.apply_gradients` and returns zeros for the dummy.
        """
        # Each call to this function increments `_v1_call_id` by 1. This
        # allows us to tag each of the main embedding ops with this call id so
        # that we know during graph rewriting passes which ops correspond to the
        # same layer call.
        self._v1_call_id += 1
        name = str(self._v1_call_id)

        # Set training to true, even during eval. When name is set, this will
        # trigger a pass that updates the training based on if there is a send
        # gradients with the same name.
        tpu_embedding.enqueue(inputs, weights, training=True, name=name)

        @tf.custom_gradient  # type: ignore
        def gradient_trap(
            dummy: types.Tensor,
        ) -> tuple[
            list[types.Tensor], Callable[[tuple[types.Tensor]], types.Tensor]
        ]:
            """Register a gradient function for activation."""
            activations = tpu_embedding.dequeue(name=name)

            def grad(*grad_wrt_activations: types.Tensor) -> types.Tensor:
                """Gradient function."""
                # Since the output were flattened, the gradients are also
                # flattened. Pack them back into the correct nested structure.
                gradients = tf.nest.pack_sequence_as(
                    self._placement_to_path_to_feature_config["sparsecore"],
                    grad_wrt_activations,
                )
                tpu_embedding.apply_gradients(gradients, name=name)

                # This is the gradient for the input variable.
                return tf.zeros_like(dummy)

            # Custom gradient functions don't like nested structures of tensors,
            # so we flatten them here.
            return tf.nest.flatten(activations), grad

        activations_with_trap = gradient_trap(self._gradient_trap_dummy.value)
        result: dict[str, types.Tensor] = tf.nest.pack_sequence_as(
            self._placement_to_path_to_feature_config["sparsecore"],
            activations_with_trap,
        )
        return result

    def _tpu_embedding_lookup_v2(
        self,
        tpu_embedding: tf.tpu.experimental.embedding.TPUEmbeddingV2,
        inputs: dict[str, types.Tensor],
        weights: dict[str, types.Tensor] | None = None,
    ) -> dict[str, types.Tensor]:
        """Looks up activations by calling a `TPUEmbeddingV2` directly.

        As in the V1 path, the lookup happens inside a `tf.custom_gradient`
        "gradient trap" on the dummy weight. The gradient function forwards
        the activation gradients, together with the preserved outputs of the
        forward call, to `tpu_embedding.apply_gradients`.
        """

        @tf.custom_gradient  # type: ignore
        def gradient_trap(
            dummy: types.Tensor,
        ) -> tuple[
            list[types.Tensor], Callable[[tuple[types.Tensor]], types.Tensor]
        ]:
            """Register a gradient function for activation."""
            activations, preserved_result = tpu_embedding(inputs, weights)

            def grad(*grad_wrt_activations: types.Tensor) -> types.Tensor:
                """Gradient function."""
                # Since the output were flattened, the gradients are also
                # flattened. Pack them back into the correct nested structure.
                gradients = tf.nest.pack_sequence_as(
                    self._placement_to_path_to_feature_config["sparsecore"],
                    grad_wrt_activations,
                )
                tpu_embedding.apply_gradients(
                    gradients, preserved_outputs=preserved_result
                )
                # This is the gradient for the input variable.
                return tf.zeros_like(dummy)

            # Custom gradient functions don't like nested structures of tensors,
            # so we flatten them here.
            return tf.nest.flatten(activations), grad

        activations_with_trap = gradient_trap(self._gradient_trap_dummy)
        result: dict[str, types.Tensor] = tf.nest.pack_sequence_as(
            self._placement_to_path_to_feature_config["sparsecore"],
            activations_with_trap,
        )
        return result

    def _trackable_children(
        self, save_type: str = "checkpoint", **kwargs: Any
    ) -> dict[str, Any]:
        """Returns the trackable children without the gradient trap dummy."""
        # Remove dummy variable, we don't want it in checkpoints.
        children: dict[str, Any] = super()._trackable_children(
            save_type, **kwargs
        )
        children.pop(GRADIENT_TRAP_DUMMY_NAME, None)
        return children
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
# Reuse the public docstring of the backend-agnostic base class so that this
# TensorFlow-specific subclass documents the same API.
_BASE_LAYER = base_distributed_embedding.DistributedEmbedding
DistributedEmbedding.__doc__ = _BASE_LAYER.__doc__
del _BASE_LAYER
|
|
File without changes
|
|
@@ -5,6 +5,7 @@ from keras import ops
|
|
|
5
5
|
|
|
6
6
|
from keras_rs.src import types
|
|
7
7
|
from keras_rs.src.api_export import keras_rs_export
|
|
8
|
+
from keras_rs.src.utils.keras_utils import check_shapes_compatible
|
|
8
9
|
|
|
9
10
|
|
|
10
11
|
@keras_rs_export("keras_rs.layers.DotInteraction")
|
|
@@ -27,6 +28,54 @@ class DotInteraction(keras.layers.Layer):
|
|
|
27
28
|
entries will be zeros. Otherwise, the output will be only the lower
|
|
28
29
|
triangular part of the interaction matrix. The latter saves space
|
|
29
30
|
but is much slower.
|
|
31
|
+
**kwargs: Args to pass to the base class.
|
|
32
|
+
|
|
33
|
+
Example:
|
|
34
|
+
|
|
35
|
+
```python
|
|
36
|
+
# 1. Simple forward pass
|
|
37
|
+
batch_size = 2
|
|
38
|
+
embedding_dim = 32
|
|
39
|
+
feature1 = np.random.randn(batch_size, embedding_dim)
|
|
40
|
+
feature2 = np.random.randn(batch_size, embedding_dim)
|
|
41
|
+
feature3 = np.random.randn(batch_size, embedding_dim)
|
|
42
|
+
feature_interactions = keras_rs.layers.DotInteraction()(
|
|
43
|
+
[feature1, feature2, feature3]
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
# 2. After embedding layer in a model
|
|
47
|
+
vocabulary_size = 32
|
|
48
|
+
embedding_dim = 6
|
|
49
|
+
|
|
50
|
+
# Create a simple model containing the layer.
|
|
51
|
+
feature_input_1 = keras.Input(shape=(), name='indices_1', dtype="int32")
|
|
52
|
+
feature_input_2 = keras.Input(shape=(), name='indices_2', dtype="int32")
|
|
53
|
+
feature_input_3 = keras.Input(shape=(), name='indices_3', dtype="int32")
|
|
54
|
+
x1 = keras.layers.Embedding(
|
|
55
|
+
input_dim=vocabulary_size,
|
|
56
|
+
output_dim=embedding_dim
|
|
57
|
+
)(feature_input_1)
|
|
58
|
+
x2 = keras.layers.Embedding(
|
|
59
|
+
input_dim=vocabulary_size,
|
|
60
|
+
output_dim=embedding_dim
|
|
61
|
+
)(feature_input_2)
|
|
62
|
+
x3 = keras.layers.Embedding(
|
|
63
|
+
input_dim=vocabulary_size,
|
|
64
|
+
output_dim=embedding_dim
|
|
65
|
+
)(feature_input_3)
|
|
66
|
+
feature_interactions = keras_rs.layers.DotInteraction()([x1, x2, x3])
|
|
67
|
+
output = keras.layers.Dense(units=10)(feature_interactions)
|
|
68
|
+
model = keras.Model(
|
|
69
|
+
[feature_input_1, feature_input_2, feature_input_3], output
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
# Call the model on the inputs.
|
|
73
|
+
batch_size = 2
|
|
74
|
+
f1 = np.random.randint(0, vocabulary_size, size=(batch_size,))
|
|
75
|
+
f2 = np.random.randint(0, vocabulary_size, size=(batch_size,))
|
|
76
|
+
f3 = np.random.randint(0, vocabulary_size, size=(batch_size,))
|
|
77
|
+
outputs = model([f1, f2, f3])
|
|
78
|
+
```
|
|
30
79
|
|
|
31
80
|
References:
|
|
32
81
|
- [M. Naumov et al.](https://arxiv.org/abs/1906.00091)
|
|
@@ -44,6 +93,44 @@ class DotInteraction(keras.layers.Layer):
|
|
|
44
93
|
self.self_interaction = self_interaction
|
|
45
94
|
self.skip_gather = skip_gather
|
|
46
95
|
|
|
96
|
+
def _generate_tril_mask(
    self, pairwise_interaction_matrix: types.Tensor
) -> types.Tensor:
    """Builds a boolean lower-triangular mask shaped like the input.

    The main diagonal is included only when `self.self_interaction` is
    `True`; otherwise the mask starts one diagonal below it.
    """
    diagonal_offset = 0 if self.self_interaction else -1

    # Typecast the offset from Python int to tensor, because `ops.tril` uses
    # `tf.cond` (which requires tensors).
    # TODO (abheesht): Remove typecast once fix is merged in core Keras.
    if keras.config.backend() == "tensorflow":
        diagonal_offset = ops.array(diagonal_offset)

    all_true = ops.ones_like(pairwise_interaction_matrix, dtype=bool)
    return ops.tril(all_true, k=diagonal_offset)
|
|
117
|
+
|
|
118
|
+
def _get_lower_triangular_indices(self, num_features: int) -> list[int]:
    """Returns row-major flat indices of the lower-triangular entries.

    For a `num_features x num_features` matrix flattened row by row, row `i`
    contributes columns `0 .. i - 1`, plus column `i` (the main diagonal)
    when `self.self_interaction` is `True`.
    """
    diagonal_extra = 1 if self.self_interaction else 0
    return [
        row * num_features + col
        for row in range(num_features)
        for col in range(row + diagonal_extra)
    ]
|
|
133
|
+
|
|
47
134
|
def call(self, inputs: list[types.Tensor]) -> types.Tensor:
|
|
48
135
|
"""Forward pass of the dot interaction layer.
|
|
49
136
|
|
|
@@ -64,23 +151,25 @@ class DotInteraction(keras.layers.Layer):
|
|
|
64
151
|
# Check if all feature tensors have the same shape and are of rank 2.
|
|
65
152
|
shape = ops.shape(inputs[0])
|
|
66
153
|
for idx, tensor in enumerate(inputs):
|
|
67
|
-
|
|
154
|
+
other_shape = ops.shape(tensor)
|
|
155
|
+
|
|
156
|
+
if len(shape) != 2:
|
|
157
|
+
raise ValueError(
|
|
158
|
+
"All feature tensors inside `inputs` should have rank 2. "
|
|
159
|
+
f"Received rank {len(shape)} at index {idx}."
|
|
160
|
+
)
|
|
161
|
+
|
|
162
|
+
if not check_shapes_compatible(shape, other_shape):
|
|
68
163
|
raise ValueError(
|
|
69
164
|
"All feature tensors in `inputs` should have the same "
|
|
70
165
|
f"shape. Found at least one conflict: shape = {shape} at "
|
|
71
|
-
f"index 0 and shape = {ops.shape(tensor)} at index {idx}"
|
|
166
|
+
f"index 0 and shape = {ops.shape(tensor)} at index {idx}."
|
|
72
167
|
)
|
|
73
168
|
|
|
74
|
-
if len(shape) != 2:
|
|
75
|
-
raise ValueError(
|
|
76
|
-
"All feature tensors inside `inputs` should have rank 2. "
|
|
77
|
-
f"Received rank {len(shape)}."
|
|
78
|
-
)
|
|
79
|
-
|
|
80
169
|
# `(batch_size, num_features, feature_dim)`
|
|
81
170
|
features = ops.stack(inputs, axis=1)
|
|
82
171
|
|
|
83
|
-
batch_size,
|
|
172
|
+
batch_size, num_features, _ = ops.shape(features)
|
|
84
173
|
|
|
85
174
|
# Compute the dot product to get feature interactions. The shape here is
|
|
86
175
|
# `(batch_size, num_features, num_features)`.
|
|
@@ -88,34 +177,36 @@ class DotInteraction(keras.layers.Layer):
|
|
|
88
177
|
features, ops.transpose(features, axes=(0, 2, 1))
|
|
89
178
|
)
|
|
90
179
|
|
|
91
|
-
# if `self.self_interaction` is `True`, keep the main diagonal.
|
|
92
|
-
k = -1
|
|
93
|
-
if self.self_interaction:
|
|
94
|
-
k = 0
|
|
95
|
-
|
|
96
|
-
tril_mask = ops.tril(
|
|
97
|
-
ops.ones_like(pairwise_interaction_matrix, dtype=bool),
|
|
98
|
-
k=k,
|
|
99
|
-
)
|
|
100
|
-
|
|
101
180
|
# Set the upper triangle entries to 0, if `self.skip_gather` is True.
|
|
102
181
|
# Else, "pick" only the lower triangle entries.
|
|
103
182
|
if self.skip_gather:
|
|
183
|
+
tril_mask = self._generate_tril_mask(pairwise_interaction_matrix)
|
|
184
|
+
|
|
104
185
|
activations = ops.multiply(
|
|
105
186
|
pairwise_interaction_matrix,
|
|
106
187
|
ops.cast(tril_mask, dtype=pairwise_interaction_matrix.dtype),
|
|
107
188
|
)
|
|
189
|
+
# Rank-2 tensor.
|
|
190
|
+
activations = ops.reshape(
|
|
191
|
+
activations, (batch_size, num_features * num_features)
|
|
192
|
+
)
|
|
108
193
|
else:
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
194
|
+
flattened_indices = self._get_lower_triangular_indices(num_features)
|
|
195
|
+
pairwise_interaction_matrix_flattened = ops.reshape(
|
|
196
|
+
pairwise_interaction_matrix,
|
|
197
|
+
(batch_size, num_features * num_features),
|
|
198
|
+
)
|
|
199
|
+
activations = ops.take(
|
|
200
|
+
pairwise_interaction_matrix_flattened,
|
|
201
|
+
flattened_indices,
|
|
202
|
+
axis=-1,
|
|
203
|
+
)
|
|
113
204
|
|
|
114
205
|
return activations
|
|
115
206
|
|
|
116
207
|
def compute_output_shape(
|
|
117
|
-
self, input_shape: list[types.
|
|
118
|
-
) -> types.
|
|
208
|
+
self, input_shape: list[types.Shape]
|
|
209
|
+
) -> types.Shape:
|
|
119
210
|
num_features = len(input_shape)
|
|
120
211
|
batch_size = input_shape[0][0]
|
|
121
212
|
|