keras-rs-nightly 0.0.1.dev2025021903__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. keras_rs/__init__.py +9 -28
  2. keras_rs/layers/__init__.py +37 -0
  3. keras_rs/losses/__init__.py +19 -0
  4. keras_rs/metrics/__init__.py +16 -0
  5. keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
  6. keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
  7. keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
  8. keras_rs/src/layers/embedding/embed_reduce.py +309 -0
  9. keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  10. keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
  11. keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
  12. keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
  13. keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
  14. keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
  15. keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  16. keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
  17. keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
  18. keras_rs/src/layers/feature_interaction/__init__.py +0 -0
  19. keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
  20. keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
  21. keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
  22. keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
  23. keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
  24. keras_rs/src/layers/retrieval/retrieval.py +127 -0
  25. keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
  26. keras_rs/src/losses/__init__.py +0 -0
  27. keras_rs/src/losses/list_mle_loss.py +212 -0
  28. keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
  29. keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
  30. keras_rs/src/losses/pairwise_loss.py +165 -0
  31. keras_rs/src/losses/pairwise_loss_utils.py +39 -0
  32. keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
  33. keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
  34. keras_rs/src/metrics/__init__.py +0 -0
  35. keras_rs/src/metrics/dcg.py +161 -0
  36. keras_rs/src/metrics/mean_average_precision.py +130 -0
  37. keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
  38. keras_rs/src/metrics/ndcg.py +197 -0
  39. keras_rs/src/metrics/precision_at_k.py +117 -0
  40. keras_rs/src/metrics/ranking_metric.py +260 -0
  41. keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
  42. keras_rs/src/metrics/recall_at_k.py +108 -0
  43. keras_rs/src/metrics/utils.py +70 -0
  44. keras_rs/src/types.py +43 -14
  45. keras_rs/src/utils/doc_string_utils.py +53 -0
  46. keras_rs/src/utils/keras_utils.py +52 -3
  47. keras_rs/src/utils/tpu_test_utils.py +120 -0
  48. keras_rs/src/version.py +1 -1
  49. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
  50. keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
  51. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
  52. keras_rs/api/__init__.py +0 -9
  53. keras_rs/api/layers/__init__.py +0 -11
  54. keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
  55. /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
  56. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
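The bulk of this release is the new distributed embedding stack (entries 5–17) together with the ranking losses (26–33) and metrics (34–43). For orientation only, here is a minimal sketch of how the new configuration classes fit together, assuming the public exports mirror the module names listed above (`TableConfig`, `FeatureConfig`, `DistributedEmbedding`); the exact argument names are illustrative and not confirmed by this diff.

```python
# Hedged sketch of the new embedding API surface; argument names are
# assumptions based on the module layout in the file list above.
import keras_rs

movie_table = keras_rs.layers.TableConfig(
    name="movie",
    vocabulary_size=100_000,
    embedding_dim=32,
)
feature_configs = {
    "movie_id": keras_rs.layers.FeatureConfig(
        name="movie_id",
        table=movie_table,
        input_shape=(256, 1),    # (batch_size, ids per example)
        output_shape=(256, 32),  # (batch_size, embedding_dim)
    ),
}

# Looks up (and reduces) embeddings for every configured feature.
embedding_layer = keras_rs.layers.DistributedEmbedding(feature_configs)
```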
@@ -0,0 +1,436 @@
+ from typing import Any, Callable, Sequence, TypeAlias
+
+ import keras
+ import tensorflow as tf
+
+ from keras_rs.src import types
+ from keras_rs.src.layers.embedding import base_distributed_embedding
+ from keras_rs.src.layers.embedding import distributed_embedding_config
+ from keras_rs.src.layers.embedding.tensorflow import config_conversion
+ from keras_rs.src.utils import keras_utils
+
+ FeatureConfig = distributed_embedding_config.FeatureConfig
+ TableConfig = distributed_embedding_config.TableConfig
+
+ # Placeholder of tf.tpu.experimental.embedding._Optimizer which is not exposed.
+ TfTpuOptimizer: TypeAlias = Any
+
+
+ GRADIENT_TRAP_DUMMY_NAME = "_gradient_trap_dummy"
+
+ EMBEDDING_FEATURE_V1 = tf.tpu.experimental.HardwareFeature.EmbeddingFeature.V1
+ EMBEDDING_FEATURE_V2 = tf.tpu.experimental.HardwareFeature.EmbeddingFeature.V2
+ UNSUPPORTED = tf.tpu.experimental.HardwareFeature.EmbeddingFeature.UNSUPPORTED
+
+
+ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
+     """TensorFlow implementation of the TPU embedding layer."""
+
+     def __init__(
+         self,
+         feature_configs: types.Nested[
+             FeatureConfig | tf.tpu.experimental.embedding.FeatureConfig
+         ],
+         *,
+         table_stacking: (
+             str | Sequence[str] | Sequence[Sequence[str]]
+         ) = "auto",
+         **kwargs: Any,
+     ) -> None:
+         # Intercept arguments that are supported only on TensorFlow.
+         self._optimizer = kwargs.pop("optimizer", None)
+         self._pipeline_execution_with_tensor_core = kwargs.pop(
+             "pipeline_execution_with_tensor_core", False
+         )
+         self._sparse_core_embedding_config = kwargs.pop(
+             "sparse_core_embedding_config", None
+         )
+
+         # Mark as True by default for `_verify_input_shapes`. This will be
+         # updated in `_sparsecore_init` if applicable.
+         self._using_keras_rs_configuration = True
+
+         super().__init__(
+             feature_configs, table_stacking=table_stacking, **kwargs
+         )
+
+     def _is_tpu_strategy(self, strategy: tf.distribute.Strategy) -> bool:
+         return isinstance(
+             strategy,
+             (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy),
+         )
+
+     def _has_sparsecore(self) -> bool:
+         strategy = tf.distribute.get_strategy()
+         if self._is_tpu_strategy(strategy):
+             tpu_embedding_feature = (
+                 strategy.extended.tpu_hardware_feature.embedding_feature
+             )
+             return tpu_embedding_feature in (
+                 EMBEDDING_FEATURE_V2,
+                 EMBEDDING_FEATURE_V1,
+             )
+         return False
+
+     @keras_utils.no_automatic_dependency_tracking
+     def _sparsecore_init(
+         self,
+         feature_configs: dict[
+             str,
+             FeatureConfig | tf.tpu.experimental.embedding.FeatureConfig,
+         ],
+         table_stacking: str | Sequence[str] | Sequence[Sequence[str]],
+     ) -> None:
+         self._table_stacking = table_stacking
+
+         strategy = tf.distribute.get_strategy()
+         if not self._is_tpu_strategy(strategy):
+             raise ValueError(
+                 "Placement to sparsecore was requested, however, we are not "
+                 "running under a TPU strategy."
+             )
+
+         tpu_embedding_feature = (
+             strategy.extended.tpu_hardware_feature.embedding_feature
+         )
+
+         self._using_keras_rs_configuration = isinstance(
+             next(iter(feature_configs.values())), FeatureConfig
+         )
+         if self._using_keras_rs_configuration:
+             if self._sparse_core_embedding_config is not None:
+                 raise ValueError(
+                     "The `sparse_core_embedding_config` argument is only "
+                     "supported when using "
+                     "`tf.tpu.experimental.embedding.FeatureConfig` instances "
+                     "for the configuration."
+                 )
+             self._tpu_feature_configs, self._sparse_core_embedding_config = (
+                 config_conversion.keras_to_tf_tpu_configuration(
+                     feature_configs,
+                     table_stacking,
+                     strategy.num_replicas_in_sync,
+                 )
+             )
+             if tpu_embedding_feature == EMBEDDING_FEATURE_V1:
+                 # Remove auto-generated SparseCoreEmbeddingConfig, which is not
+                 # used.
+                 self._sparse_core_embedding_config = None
+         else:
+             if table_stacking != "auto":
+                 raise ValueError(
+                     "The `table_stacking` argument is not supported when using "
+                     "`tf.tpu.experimental.embedding.FeatureConfig` for the "
+                     "configuration. You can use the `disable_table_stacking` "
+                     "attribute of "
+                     "`tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig` "
+                     "to disable table stacking."
+                 )
+             if (
+                 tpu_embedding_feature == EMBEDDING_FEATURE_V1
+                 and self._sparse_core_embedding_config is not None
+             ):
+                 raise ValueError(
+                     "The `sparse_core_embedding_config` argument is not "
+                     "supported with this TPU generation."
+                 )
+             self._tpu_feature_configs = (
+                 config_conversion.clone_tf_tpu_feature_configs(feature_configs)
+             )
+
+         self._tpu_optimizer = config_conversion.to_tf_tpu_optimizer(
+             self._optimizer
+         )
+
+         if tpu_embedding_feature == EMBEDDING_FEATURE_V1:
+             self._tpu_embedding = tf.tpu.experimental.embedding.TPUEmbedding(
+                 self._tpu_feature_configs,
+                 self._tpu_optimizer,
+                 self._pipeline_execution_with_tensor_core,
+             )
+             self._v1_call_id = 0
+         elif tpu_embedding_feature == EMBEDDING_FEATURE_V2:
+             self._tpu_embedding = tf.tpu.experimental.embedding.TPUEmbeddingV2(
+                 self._tpu_feature_configs,
+                 self._tpu_optimizer,
+                 self._pipeline_execution_with_tensor_core,
+                 self._sparse_core_embedding_config,
+             )
+         elif tpu_embedding_feature == UNSUPPORTED:
+             raise ValueError(
+                 "Placement to sparsecore was requested, however, this TPU does "
+                 "not support it."
+             )
+         elif tpu_embedding_feature != UNSUPPORTED:
+             raise ValueError(
+                 f"Unsupported TPU embedding feature: {tpu_embedding_feature}."
+             )
+
+         # We need at least one trainable variable for the gradient trap to work.
+         # Note that the Python attribute name "_gradient_trap_dummy" should
+         # match the name of the variable GRADIENT_TRAP_DUMMY_NAME.
+         self._gradient_trap_dummy = self.add_weight(
+             name=GRADIENT_TRAP_DUMMY_NAME,
+             shape=(1,),
+             initializer=tf.zeros_initializer(),
+             trainable=True,
+             dtype=tf.float32,
+         )
+
+     def compute_output_shape(
+         self, input_shapes: types.Nested[types.Shape]
+     ) -> types.Nested[types.Shape]:
+         if self._using_keras_rs_configuration:
+             return super().compute_output_shape(input_shapes)
+
+         def _compute_output_shape(
+             feature_config: tf.tpu.experimental.embedding.FeatureConfig,
+             input_shape: types.Shape,
+         ) -> types.Shape:
+             if len(input_shape) < 1:
+                 raise ValueError(
+                     f"Received input shape {input_shape}. Rank must be 1 or "
+                     "above."
+                 )
+             max_sequence_length: int = feature_config.max_sequence_length
+             embed_dim = feature_config.table.dim
+             if (
+                 feature_config.output_shape is not None
+                 and feature_config.output_shape.rank is not None
+             ):
+                 return tuple(feature_config.output_shape.as_list())
+             elif (
+                 len(input_shape) == 2
+                 and input_shape[-1] != 1
+                 and max_sequence_length > 0
+             ):
+                 # Update the input shape with the max sequence length. Only
+                 # update when:
+                 # 1. Input feature is 2D ragged or sparse tensor.
+                 # 2. Output shape is not set and max sequence length is set.
+                 return tuple(input_shape[:-1]) + (
+                     max_sequence_length,
+                     embed_dim,
+                 )
+             elif len(input_shape) == 1:
+                 return tuple(input_shape) + (embed_dim,)
+             else:
+                 return tuple(input_shape[:-1]) + (embed_dim,)
+
+         output_shapes: types.Nested[types.Shape] = (
+             keras.tree.map_structure_up_to(
+                 self._feature_configs,
+                 _compute_output_shape,
+                 self._feature_configs,
+                 input_shapes,
+             )
+         )
+         return output_shapes
+
+     def _sparsecore_build(self, input_shapes: dict[str, types.Shape]) -> None:
+         if isinstance(
+             self._tpu_embedding, tf.tpu.experimental.embedding.TPUEmbedding
+         ):
+             tf_input_shapes = keras.tree.map_shape_structure(
+                 tf.TensorShape, input_shapes
+             )
+             tpu_embedding_build = tf.autograph.to_graph(
+                 self._tpu_embedding.build, recursive=False
+             )
+             tpu_embedding_build(
+                 self._tpu_embedding, per_replica_input_shapes=tf_input_shapes
+             )
+         elif isinstance(
+             self._tpu_embedding, tf.tpu.experimental.embedding.TPUEmbeddingV2
+         ):
+             self._tpu_embedding.build()
+
+     def _sparsecore_call(
+         self,
+         inputs: dict[str, types.Tensor],
+         weights: dict[str, types.Tensor] | None = None,
+         training: bool = False,
+     ) -> dict[str, types.Tensor]:
+         del training  # Unused.
+         strategy = tf.distribute.get_strategy()
+         if not self._is_tpu_strategy(strategy):
+             raise RuntimeError(
+                 "DistributedEmbedding needs to be called under a TPUStrategy "
+                 "for features placed on the embedding feature but is being "
+                 f"called under strategy {strategy}. Please use `strategy.run` "
+                 "when calling this layer."
+             )
+         if isinstance(
+             self._tpu_embedding, tf.tpu.experimental.embedding.TPUEmbedding
+         ):
+             return self._tpu_embedding_lookup_v1(
+                 self._tpu_embedding, inputs, weights
+             )
+         elif isinstance(
+             self._tpu_embedding, tf.tpu.experimental.embedding.TPUEmbeddingV2
+         ):
+             return self._tpu_embedding_lookup_v2(
+                 self._tpu_embedding, inputs, weights
+             )
+         else:
+             raise ValueError(
+                 "DistributedEmbedding is receiving features to lookup on the "
+                 "TPU embedding feature but no such feature was configured."
+             )
+
+     def _sparsecore_get_embedding_tables(self) -> dict[str, types.Tensor]:
+         tables: dict[str, types.Tensor] = {}
+         strategy = tf.distribute.get_strategy()
+         if not self._is_tpu_strategy(strategy):
+             raise RuntimeError(
+                 "`DistributedEmbedding.get_embedding_tables` needs to be "
+                 "called under the TPUStrategy that DistributedEmbedding was "
+                 f"created with, but is being called under strategy {strategy}. "
+                 "Please use `with strategy.scope()` when calling "
+                 "`get_embedding_tables`."
+             )
+
+         tpu_hardware = strategy.extended.tpu_hardware_feature
+         num_sc_per_device = tpu_hardware.num_embedding_devices_per_chip
+         num_shards = strategy.num_replicas_in_sync * num_sc_per_device
+
+         def populate_table(
+             feature_config: tf.tpu.experimental.embedding.FeatureConfig,
+         ) -> None:
+             table_name = feature_config.table.name
+             if table_name in tables:
+                 return
+
+             embedding_dim = feature_config.table.dim
+             table = self._tpu_embedding.embedding_tables[table_name]
+
+             # This table has num_sparse_cores mod shards, so we need to slice,
+             # reconcat and reshape.
+             table_shards = [
+                 shard.numpy()[:, :embedding_dim] for shard in table.values
+             ]
+             full_table = keras.ops.concatenate(table_shards, axis=0)
+             full_table = keras.ops.concatenate(
+                 keras.ops.split(full_table, num_shards, axis=0), axis=1
+             )
+             full_table = keras.ops.reshape(full_table, [-1, embedding_dim])
+             tables[table_name] = full_table[
+                 : feature_config.table.vocabulary_size, :
+             ]
+
+         keras.tree.map_structure(populate_table, self._tpu_feature_configs)
+         return tables
+
+     def _verify_input_shapes(
+         self, input_shapes: types.Nested[types.Shape]
+     ) -> None:
+         if self._using_keras_rs_configuration:
+             return super()._verify_input_shapes(input_shapes)
+         # `tf.tpu.experimental.embedding.FeatureConfig` does not provide any
+         # information about the input shape, so there is nothing to verify.
+
+     def _tpu_embedding_lookup_v1(
+         self,
+         tpu_embedding: tf.tpu.experimental.embedding.TPUEmbedding,
+         inputs: dict[str, types.Tensor],
+         weights: dict[str, types.Tensor] | None = None,
+     ) -> dict[str, types.Tensor]:
+         # Each call to this function increments the _v1_call_id by 1, this
+         # allows us to tag each of the main embedding ops with this call id so
+         # that we know during graph rewriting passes which ops correspond to the
+         # same layer call.
+         self._v1_call_id += 1
+         name = str(self._v1_call_id)
+
+         # Set training to true, even during eval. When name is set, this will
+         # trigger a pass that updates the training based on if there is a send
+         # gradients with the same name.
+         tpu_embedding.enqueue(inputs, weights, training=True, name=name)
+
+         @tf.custom_gradient  # type: ignore
+         def gradient_trap(
+             dummy: types.Tensor,
+         ) -> tuple[
+             list[types.Tensor], Callable[[tuple[types.Tensor]], types.Tensor]
+         ]:
+             """Register a gradient function for activation."""
+             activations = tpu_embedding.dequeue(name=name)
+
+             def grad(*grad_wrt_activations: types.Tensor) -> types.Tensor:
+                 """Gradient function."""
+                 # Since the output were flattened, the gradients are also
+                 # flattened. Pack them back into the correct nested structure.
+                 gradients = tf.nest.pack_sequence_as(
+                     self._placement_to_path_to_feature_config["sparsecore"],
+                     grad_wrt_activations,
+                 )
+                 tpu_embedding.apply_gradients(gradients, name=name)
+
+                 # This is the gradient for the input variable.
+                 return tf.zeros_like(dummy)
+
+             # Custom gradient functions don't like nested structures of tensors,
+             # so we flatten them here.
+             return tf.nest.flatten(activations), grad
+
+         activations_with_trap = gradient_trap(self._gradient_trap_dummy.value)
+         result: dict[str, types.Tensor] = tf.nest.pack_sequence_as(
+             self._placement_to_path_to_feature_config["sparsecore"],
+             activations_with_trap,
+         )
+         return result
+
+     def _tpu_embedding_lookup_v2(
+         self,
+         tpu_embedding: tf.tpu.experimental.embedding.TPUEmbeddingV2,
+         inputs: dict[str, types.Tensor],
+         weights: dict[str, types.Tensor] | None = None,
+     ) -> dict[str, types.Tensor]:
+         @tf.custom_gradient  # type: ignore
+         def gradient_trap(
+             dummy: types.Tensor,
+         ) -> tuple[
+             list[types.Tensor], Callable[[tuple[types.Tensor]], types.Tensor]
+         ]:
+             """Register a gradient function for activation."""
+             activations, preserved_result = tpu_embedding(inputs, weights)
+
+             def grad(*grad_wrt_activations: types.Tensor) -> types.Tensor:
+                 """Gradient function."""
+                 # Since the output were flattened, the gradients are also
+                 # flattened. Pack them back into the correct nested structure.
+                 gradients = tf.nest.pack_sequence_as(
+                     self._placement_to_path_to_feature_config["sparsecore"],
+                     grad_wrt_activations,
+                 )
+                 tpu_embedding.apply_gradients(
+                     gradients, preserved_outputs=preserved_result
+                 )
+                 # This is the gradient for the input variable.
+                 return tf.zeros_like(dummy)
+
+             # Custom gradient functions don't like nested structures of tensors,
+             # so we flatten them here.
+             return tf.nest.flatten(activations), grad
+
+         activations_with_trap = gradient_trap(self._gradient_trap_dummy)
+         result: dict[str, types.Tensor] = tf.nest.pack_sequence_as(
+             self._placement_to_path_to_feature_config["sparsecore"],
+             activations_with_trap,
+         )
+         return result
+
+     def _trackable_children(
+         self, save_type: str = "checkpoint", **kwargs: dict[str, Any]
+     ) -> dict[str, Any]:
+         # Remove dummy variable, we don't want it in checkpoints.
+         children: dict[str, Any] = super()._trackable_children(
+             save_type, **kwargs
+         )
+         children.pop(GRADIENT_TRAP_DUMMY_NAME, None)
+         return children
+
+
+ DistributedEmbedding.__doc__ = (
+     base_distributed_embedding.DistributedEmbedding.__doc__
+ )
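The non-obvious piece of the file above is the "gradient trap": the layer owns a single dummy trainable variable, and the embedding lookup is wrapped in a `tf.custom_gradient` whose backward function applies the embedding gradients as a side effect. Below is a standalone sketch of that pattern, not the layer's actual code: the dequeued activations and the apply step are stubbed out with placeholders.

```python
# Hedged sketch of the gradient-trap pattern used by the layer above.
import tensorflow as tf

dummy = tf.zeros([1])  # stands in for the layer's `_gradient_trap_dummy` weight


@tf.custom_gradient
def gradient_trap(d):
    # Stands in for `tpu_embedding.dequeue(...)` / the embedding lookup.
    activations = tf.constant([[1.0, 2.0, 3.0]])

    def grad(*grad_wrt_activations):
        # In the real layer this is where `apply_gradients(...)` runs.
        tf.print("would apply embedding gradients:", grad_wrt_activations)
        # Gradient returned for the dummy input itself.
        return tf.zeros_like(d)

    # Outputs are flattened because custom gradients dislike nested structures.
    return [activations], grad


with tf.GradientTape() as tape:
    tape.watch(dummy)
    (out,) = gradient_trap(dummy)
    loss = tf.reduce_sum(out)

# Backprop through the trap triggers `grad`, i.e. the embedding update.
tape.gradient(loss, dummy)
```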
File without changes
@@ -5,6 +5,7 @@ from keras import ops
 
  from keras_rs.src import types
  from keras_rs.src.api_export import keras_rs_export
+ from keras_rs.src.utils.keras_utils import check_shapes_compatible
 
 
  @keras_rs_export("keras_rs.layers.DotInteraction")
@@ -27,6 +28,54 @@ class DotInteraction(keras.layers.Layer):
              entries will be zeros. Otherwise, the output will be only the lower
              triangular part of the interaction matrix. The latter saves space
              but is much slower.
+         **kwargs: Args to pass to the base class.
+
+     Example:
+
+     ```python
+     # 1. Simple forward pass
+     batch_size = 2
+     embedding_dim = 32
+     feature1 = np.random.randn(batch_size, embedding_dim)
+     feature2 = np.random.randn(batch_size, embedding_dim)
+     feature3 = np.random.randn(batch_size, embedding_dim)
+     feature_interactions = keras_rs.layers.DotInteraction()(
+         [feature1, feature2, feature3]
+     )
+
+     # 2. After embedding layer in a model
+     vocabulary_size = 32
+     embedding_dim = 6
+
+     # Create a simple model containing the layer.
+     feature_input_1 = keras.Input(shape=(), name='indices_1', dtype="int32")
+     feature_input_2 = keras.Input(shape=(), name='indices_2', dtype="int32")
+     feature_input_3 = keras.Input(shape=(), name='indices_3', dtype="int32")
+     x1 = keras.layers.Embedding(
+         input_dim=vocabulary_size,
+         output_dim=embedding_dim
+     )(feature_input_1)
+     x2 = keras.layers.Embedding(
+         input_dim=vocabulary_size,
+         output_dim=embedding_dim
+     )(feature_input_2)
+     x3 = keras.layers.Embedding(
+         input_dim=vocabulary_size,
+         output_dim=embedding_dim
+     )(feature_input_3)
+     feature_interactions = keras_rs.layers.DotInteraction()([x1, x2, x3])
+     output = keras.layers.Dense(units=10)(x2)
+     model = keras.Model(
+         [feature_input_1, feature_input_2, feature_input_3], output
+     )
+
+     # Call the model on the inputs.
+     batch_size = 2
+     f1 = np.random.randint(0, vocabulary_size, size=(batch_size,))
+     f2 = np.random.randint(0, vocabulary_size, size=(batch_size,))
+     f3 = np.random.randint(0, vocabulary_size, size=(batch_size,))
+     outputs = model([f1, f2, f3])
+     ```
 
      References:
      - [M. Naumov et al.](https://arxiv.org/abs/1906.00091)
@@ -44,6 +93,44 @@ class DotInteraction(keras.layers.Layer):
          self.self_interaction = self_interaction
          self.skip_gather = skip_gather
 
+     def _generate_tril_mask(
+         self, pairwise_interaction_matrix: types.Tensor
+     ) -> types.Tensor:
+         """Generates lower triangular mask."""
+
+         # If `self.self_interaction` is `True`, keep the main diagonal.
+         k = -1
+         if self.self_interaction:
+             k = 0
+
+         # Typecast k from Python int to tensor, because `ops.tril` uses
+         # `tf.cond` (which requires tensors).
+         # TODO (abheesht): Remove typecast once fix is merged in core Keras.
+         if keras.config.backend() == "tensorflow":
+             k = ops.array(k)
+         tril_mask = ops.tril(
+             ops.ones_like(pairwise_interaction_matrix, dtype=bool),
+             k=k,
+         )
+
+         return tril_mask
+
+     def _get_lower_triangular_indices(self, num_features: int) -> list[int]:
+         """Python function which generates indices to get the lower triangular
+         matrix as if it were flattened.
+         """
+         flattened_indices = []
+         for i in range(num_features):
+             k = i
+             # if `self.self_interaction` is `True`, keep the main diagonal.
+             if self.self_interaction:
+                 k += 1
+             for j in range(k):
+                 flattened_index = i * num_features + j
+                 flattened_indices.append(flattened_index)
+
+         return flattened_indices
+
      def call(self, inputs: list[types.Tensor]) -> types.Tensor:
          """Forward pass of the dot interaction layer.
 
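For reference, here is a standalone re-implementation (not part of the diff) of what the new `_get_lower_triangular_indices` helper computes: the flattened positions of the lower-triangular entries in a `num_features x num_features` interaction matrix, which `call` then gathers with `ops.take` instead of boolean-mask indexing, as shown in the next hunk.

```python
# Sketch mirroring the helper above on plain Python ints.
def lower_triangular_indices(num_features: int, self_interaction: bool) -> list[int]:
    indices = []
    for i in range(num_features):
        # Include the diagonal entry (i, i) only when self_interaction is True.
        k = i + 1 if self_interaction else i
        for j in range(k):
            indices.append(i * num_features + j)
    return indices


print(lower_triangular_indices(3, self_interaction=False))  # [3, 6, 7]
print(lower_triangular_indices(3, self_interaction=True))   # [0, 3, 4, 6, 7, 8]
```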
@@ -64,23 +151,25 @@
          # Check if all feature tensors have the same shape and are of rank 2.
          shape = ops.shape(inputs[0])
          for idx, tensor in enumerate(inputs):
-             if ops.shape(tensor) != shape:
+             other_shape = ops.shape(tensor)
+
+             if len(shape) != 2:
+                 raise ValueError(
+                     "All feature tensors inside `inputs` should have rank 2. "
+                     f"Received rank {len(shape)} at index {idx}."
+                 )
+
+             if not check_shapes_compatible(shape, other_shape):
                  raise ValueError(
                      "All feature tensors in `inputs` should have the same "
                      f"shape. Found at least one conflict: shape = {shape} at "
-                     f"index 0 and shape = {ops.shape(tensor)} at index {idx}"
+                     f"index 0 and shape = {ops.shape(tensor)} at index {idx}."
                  )
 
-         if len(shape) != 2:
-             raise ValueError(
-                 "All feature tensors inside `inputs` should have rank 2. "
-                 f"Received rank {len(shape)}."
-             )
-
          # `(batch_size, num_features, feature_dim)`
          features = ops.stack(inputs, axis=1)
 
-         batch_size, _, _ = ops.shape(features)
+         batch_size, num_features, _ = ops.shape(features)
 
          # Compute the dot product to get feature interactions. The shape here is
          # `(batch_size, num_features, num_features)`.
@@ -88,34 +177,36 @@
              features, ops.transpose(features, axes=(0, 2, 1))
          )
 
-         # if `self.self_interaction` is `True`, keep the main diagonal.
-         k = -1
-         if self.self_interaction:
-             k = 0
-
-         tril_mask = ops.tril(
-             ops.ones_like(pairwise_interaction_matrix, dtype=bool),
-             k=k,
-         )
-
          # Set the upper triangle entries to 0, if `self.skip_gather` is True.
          # Else, "pick" only the lower triangle entries.
          if self.skip_gather:
+             tril_mask = self._generate_tril_mask(pairwise_interaction_matrix)
+
              activations = ops.multiply(
                  pairwise_interaction_matrix,
                  ops.cast(tril_mask, dtype=pairwise_interaction_matrix.dtype),
              )
+             # Rank-2 tensor.
+             activations = ops.reshape(
+                 activations, (batch_size, num_features * num_features)
+             )
          else:
-             activations = pairwise_interaction_matrix[tril_mask]
-
-         # Rank-2 tensor.
-         activations = ops.reshape(activations, (batch_size, -1))
+             flattened_indices = self._get_lower_triangular_indices(num_features)
+             pairwise_interaction_matrix_flattened = ops.reshape(
+                 pairwise_interaction_matrix,
+                 (batch_size, num_features * num_features),
+             )
+             activations = ops.take(
+                 pairwise_interaction_matrix_flattened,
+                 flattened_indices,
+                 axis=-1,
+             )
 
          return activations
 
      def compute_output_shape(
-         self, input_shape: list[types.TensorShape]
-     ) -> types.TensorShape:
+         self, input_shape: list[types.Shape]
+     ) -> types.Shape:
          num_features = len(input_shape)
          batch_size = input_shape[0][0]