keras-rs-nightly 0.0.1.dev2025043003__py3-none-any.whl → 0.2.2.dev202506100336__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Potentially problematic release: this version of keras-rs-nightly might be problematic.

Files changed (37)
  1. keras_rs/layers/__init__.py +12 -0
  2. keras_rs/src/layers/embedding/__init__.py +0 -0
  3. keras_rs/src/layers/embedding/base_distributed_embedding.py +1124 -0
  4. keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
  5. keras_rs/src/layers/embedding/distributed_embedding_config.py +129 -0
  6. keras_rs/src/layers/embedding/embed_reduce.py +309 -0
  7. keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  8. keras_rs/src/layers/embedding/jax/config_conversion.py +398 -0
  9. keras_rs/src/layers/embedding/jax/distributed_embedding.py +892 -0
  10. keras_rs/src/layers/embedding/jax/embedding_lookup.py +255 -0
  11. keras_rs/src/layers/embedding/jax/embedding_utils.py +596 -0
  12. keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  13. keras_rs/src/layers/embedding/tensorflow/config_conversion.py +323 -0
  14. keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +424 -0
  15. keras_rs/src/layers/feature_interaction/dot_interaction.py +2 -2
  16. keras_rs/src/layers/feature_interaction/feature_cross.py +14 -16
  17. keras_rs/src/layers/retrieval/brute_force_retrieval.py +5 -5
  18. keras_rs/src/layers/retrieval/retrieval.py +4 -4
  19. keras_rs/src/losses/pairwise_loss.py +2 -2
  20. keras_rs/src/losses/pairwise_mean_squared_error.py +1 -3
  21. keras_rs/src/metrics/dcg.py +2 -2
  22. keras_rs/src/metrics/mean_average_precision.py +1 -1
  23. keras_rs/src/metrics/mean_reciprocal_rank.py +4 -4
  24. keras_rs/src/metrics/ndcg.py +2 -2
  25. keras_rs/src/metrics/precision_at_k.py +3 -3
  26. keras_rs/src/metrics/ranking_metric.py +11 -5
  27. keras_rs/src/metrics/ranking_metrics_utils.py +10 -10
  28. keras_rs/src/metrics/recall_at_k.py +2 -2
  29. keras_rs/src/metrics/utils.py +2 -4
  30. keras_rs/src/types.py +43 -14
  31. keras_rs/src/utils/keras_utils.py +26 -6
  32. keras_rs/src/version.py +1 -1
  33. {keras_rs_nightly-0.0.1.dev2025043003.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/METADATA +6 -3
  34. keras_rs_nightly-0.2.2.dev202506100336.dist-info/RECORD +55 -0
  35. {keras_rs_nightly-0.0.1.dev2025043003.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/WHEEL +1 -1
  36. keras_rs_nightly-0.0.1.dev2025043003.dist-info/RECORD +0 -42
  37. {keras_rs_nightly-0.0.1.dev2025043003.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/top_level.txt +0 -0
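
The headline change is the new DistributedEmbedding stack (items 3-14 above): a backend-agnostic base layer, its configuration classes, and JAX and TensorFlow implementations. As a rough orientation for the TensorFlow backend shown in the first hunk below, here is a hypothetical usage sketch. It assumes a TPU runtime with SparseCore support; the table and feature names, sizes, and optimizer choice are illustrative, and only the `tf.tpu.experimental.embedding` config classes and constructor arguments visible in the diff are taken from the package itself.

import tensorflow as tf

from keras_rs.src.layers.embedding.tensorflow import distributed_embedding

# Hypothetical feature/table setup using the TensorFlow config classes that
# the constructor below accepts directly (all names and sizes are illustrative).
user_table = tf.tpu.experimental.embedding.TableConfig(
    vocabulary_size=10_000, dim=16, name="user_table"
)
feature_configs = {
    "user_id": tf.tpu.experimental.embedding.FeatureConfig(
        table=user_table, name="user_id"
    )
}

resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
    embedding = distributed_embedding.DistributedEmbedding(
        feature_configs,
        # `optimizer` is a TensorFlow-only kwarg popped in `__init__` below;
        # a TPU embedding optimizer is assumed to be accepted here.
        optimizer=tf.tpu.experimental.embedding.Adagrad(learning_rate=0.1),
    )

@tf.function
def lookup(inputs):
    # Sparsecore lookups must run under `strategy.run`, as `_sparsecore_call`
    # below enforces.
    return strategy.run(embedding, args=(inputs,))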
keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py
@@ -0,0 +1,424 @@
+from typing import Any, Callable, Sequence, TypeAlias
+
+import keras
+import tensorflow as tf
+
+from keras_rs.src import types
+from keras_rs.src.layers.embedding import base_distributed_embedding
+from keras_rs.src.layers.embedding import distributed_embedding_config
+from keras_rs.src.layers.embedding.tensorflow import config_conversion
+from keras_rs.src.utils import keras_utils
+
+FeatureConfig = distributed_embedding_config.FeatureConfig
+TableConfig = distributed_embedding_config.TableConfig
+
+# Placeholder of tf.tpu.experimental.embedding._Optimizer which is not exposed.
+TfTpuOptimizer: TypeAlias = Any
+
+
+GRADIENT_TRAP_DUMMY_NAME = "_gradient_trap_dummy"
+
+EMBEDDING_FEATURE_V1 = tf.tpu.experimental.HardwareFeature.EmbeddingFeature.V1
+EMBEDDING_FEATURE_V2 = tf.tpu.experimental.HardwareFeature.EmbeddingFeature.V2
+UNSUPPORTED = tf.tpu.experimental.HardwareFeature.EmbeddingFeature.UNSUPPORTED
+
+
+class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
+    """TensorFlow implementation of the TPU embedding layer."""
+
+    def __init__(
+        self,
+        feature_configs: types.Nested[
+            FeatureConfig | tf.tpu.experimental.embedding.FeatureConfig
+        ],
+        *,
+        table_stacking: (
+            str | Sequence[str] | Sequence[Sequence[str]]
+        ) = "auto",
+        **kwargs: Any,
+    ) -> None:
+        # Intercept arguments that are supported only on TensorFlow.
+        self._optimizer = kwargs.pop("optimizer", None)
+        self._pipeline_execution_with_tensor_core = kwargs.pop(
+            "pipeline_execution_with_tensor_core", False
+        )
+        self._sparse_core_embedding_config = kwargs.pop(
+            "sparse_core_embedding_config", None
+        )
+
+        # Mark as True by default for `_verify_input_shapes`. This will be
+        # updated in `_sparsecore_init` if applicable.
+        self._using_keras_rs_configuration = True
+
+        super().__init__(
+            feature_configs, table_stacking=table_stacking, **kwargs
+        )
+
+    def _is_tpu_strategy(self, strategy: tf.distribute.Strategy) -> bool:
+        return isinstance(
+            strategy,
+            (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy),
+        )
+
+    def _has_sparsecore(self) -> bool:
+        strategy = tf.distribute.get_strategy()
+        if self._is_tpu_strategy(strategy):
+            tpu_embedding_feature = (
+                strategy.extended.tpu_hardware_feature.embedding_feature
+            )
+            return tpu_embedding_feature in (
+                EMBEDDING_FEATURE_V2,
+                EMBEDDING_FEATURE_V1,
+            )
+        return False
+
+    @keras_utils.no_automatic_dependency_tracking
+    def _sparsecore_init(
+        self,
+        feature_configs: dict[
+            str,
+            FeatureConfig | tf.tpu.experimental.embedding.FeatureConfig,
+        ],
+        table_stacking: str | Sequence[str] | Sequence[Sequence[str]],
+    ) -> None:
+        self._table_stacking = table_stacking
+
+        strategy = tf.distribute.get_strategy()
+        if not self._is_tpu_strategy(strategy):
+            raise ValueError(
+                "Placement to sparsecore was requested, however, we are not "
+                "running under a TPU strategy."
+            )
+
+        tpu_embedding_feature = (
+            strategy.extended.tpu_hardware_feature.embedding_feature
+        )
+
+        self._using_keras_rs_configuration = isinstance(
+            next(iter(feature_configs.values())), FeatureConfig
+        )
+        if self._using_keras_rs_configuration:
+            if self._sparse_core_embedding_config is not None:
+                raise ValueError(
+                    "The `sparse_core_embedding_config` argument is only "
+                    "supported when using "
+                    "`tf.tpu.experimental.embedding.FeatureConfig` instances "
+                    "for the configuration."
+                )
+            self._tpu_feature_configs, self._sparse_core_embedding_config = (
+                config_conversion.translate_keras_rs_configuration(
+                    feature_configs, table_stacking
+                )
+            )
+            if tpu_embedding_feature == EMBEDDING_FEATURE_V1:
+                # Remove auto-generated SparseCoreEmbeddingConfig, which is not
+                # used.
+                self._sparse_core_embedding_config = None
+        else:
+            if table_stacking != "auto":
+                raise ValueError(
+                    "The `table_stacking` argument is not supported when using "
+                    "`tf.tpu.experimental.embedding.FeatureConfig` for the "
+                    "configuration. You can use the `disable_table_stacking` "
+                    "attribute of "
+                    "`tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig` "
+                    "to disable table stacking."
+                )
+            if (
+                tpu_embedding_feature == EMBEDDING_FEATURE_V1
+                and self._sparse_core_embedding_config is not None
+            ):
+                raise ValueError(
+                    "The `sparse_core_embedding_config` argument is not "
+                    "supported with this TPU generation."
+                )
+            self._tpu_feature_configs = (
+                config_conversion.clone_tf_feature_configs(feature_configs)
+            )
+
+        self._tpu_optimizer = config_conversion.translate_optimizer(
+            self._optimizer
+        )
+
+        if tpu_embedding_feature == EMBEDDING_FEATURE_V1:
+            self._tpu_embedding = tf.tpu.experimental.embedding.TPUEmbedding(
+                self._tpu_feature_configs,
+                self._tpu_optimizer,
+                self._pipeline_execution_with_tensor_core,
+            )
+            self._v1_call_id = 0
+        elif tpu_embedding_feature == EMBEDDING_FEATURE_V2:
+            self._tpu_embedding = tf.tpu.experimental.embedding.TPUEmbeddingV2(
+                self._tpu_feature_configs,
+                self._tpu_optimizer,
+                self._pipeline_execution_with_tensor_core,
+                self._sparse_core_embedding_config,
+            )
+        elif tpu_embedding_feature == UNSUPPORTED:
+            raise ValueError(
+                "Placement to sparsecore was requested, however, this TPU does "
+                "not support it."
+            )
+        elif tpu_embedding_feature != UNSUPPORTED:
+            raise ValueError(
+                f"Unsupported TPU embedding feature: {tpu_embedding_feature}."
+            )
+
+        # We need at least one trainable variable for the gradient trap to work.
+        # Note that the Python attribute name "_gradient_trap_dummy" should
+        # match the name of the variable GRADIENT_TRAP_DUMMY_NAME.
+        self._gradient_trap_dummy = self.add_weight(
+            name=GRADIENT_TRAP_DUMMY_NAME,
+            shape=(1,),
+            initializer=tf.zeros_initializer(),
+            trainable=True,
+            dtype=tf.float32,
+        )
+
+    def compute_output_shape(
+        self, input_shapes: types.Nested[types.Shape]
+    ) -> types.Nested[types.Shape]:
+        if self._using_keras_rs_configuration:
+            return super().compute_output_shape(input_shapes)
+
+        def _compute_output_shape(
+            feature_config: tf.tpu.experimental.embedding.FeatureConfig,
+            input_shape: types.Shape,
+        ) -> types.Shape:
+            if len(input_shape) < 1:
+                raise ValueError(
+                    f"Received input shape {input_shape}. Rank must be 1 or "
+                    "above."
+                )
+            max_sequence_length: int = feature_config.max_sequence_length
+            embed_dim = feature_config.table.dim
+            if (
+                feature_config.output_shape is not None
+                and feature_config.output_shape.rank is not None
+            ):
+                return tuple(feature_config.output_shape.as_list())
+            elif (
+                len(input_shape) == 2
+                and input_shape[-1] != 1
+                and max_sequence_length > 0
+            ):
+                # Update the input shape with the max sequence length. Only
+                # update when:
+                # 1. Input feature is 2D ragged or sparse tensor.
+                # 2. Output shape is not set and max sequence length is set.
+                return tuple(input_shape[:-1]) + (
+                    max_sequence_length,
+                    embed_dim,
+                )
+            elif len(input_shape) == 1:
+                return tuple(input_shape) + (embed_dim,)
+            else:
+                return tuple(input_shape[:-1]) + (embed_dim,)
+
+        output_shapes: types.Nested[types.Shape] = (
+            keras.tree.map_structure_up_to(
+                self._feature_configs,
+                _compute_output_shape,
+                self._feature_configs,
+                input_shapes,
+            )
+        )
+        return output_shapes
+
+    def _sparsecore_build(self, input_shapes: dict[str, types.Shape]) -> None:
+        if isinstance(
+            self._tpu_embedding, tf.tpu.experimental.embedding.TPUEmbedding
+        ):
+            tf_input_shapes = keras.tree.map_shape_structure(
+                tf.TensorShape, input_shapes
+            )
+            tpu_embedding_build = tf.autograph.to_graph(
+                self._tpu_embedding.build, recursive=False
+            )
+            tpu_embedding_build(
+                self._tpu_embedding, per_replica_input_shapes=tf_input_shapes
+            )
+        elif isinstance(
+            self._tpu_embedding, tf.tpu.experimental.embedding.TPUEmbeddingV2
+        ):
+            self._tpu_embedding.build()
+
+    def _sparsecore_call(
+        self,
+        inputs: dict[str, types.Tensor],
+        weights: dict[str, types.Tensor] | None = None,
+        training: bool = False,
+    ) -> dict[str, types.Tensor]:
+        del training  # Unused.
+        strategy = tf.distribute.get_strategy()
+        if not self._is_tpu_strategy(strategy):
+            raise RuntimeError(
+                "DistributedEmbedding needs to be called under a TPUStrategy "
+                "for features placed on the embedding feature but is being "
+                f"called under strategy {strategy}. Please use `strategy.run` "
+                "when calling this layer."
+            )
+        if isinstance(
+            self._tpu_embedding, tf.tpu.experimental.embedding.TPUEmbedding
+        ):
+            return self._tpu_embedding_lookup_v1(
+                self._tpu_embedding, inputs, weights
+            )
+        elif isinstance(
+            self._tpu_embedding, tf.tpu.experimental.embedding.TPUEmbeddingV2
+        ):
+            return self._tpu_embedding_lookup_v2(
+                self._tpu_embedding, inputs, weights
+            )
+        else:
+            raise ValueError(
+                "DistributedEmbedding is receiving features to lookup on the "
+                "TPU embedding feature but no such feature was configured."
+            )
+
+    def _sparsecore_get_embedding_tables(self) -> dict[str, types.Tensor]:
+        tables: dict[str, types.Tensor] = {}
+        strategy = tf.distribute.get_strategy()
+        # 4 is the number of sparsecores per chip
+        num_shards = strategy.num_replicas_in_sync * 4
+
+        def populate_table(
+            feature_config: tf.tpu.experimental.embedding.FeatureConfig,
+        ) -> None:
+            table_name = feature_config.table.name
+            if table_name in tables:
+                return
+
+            embedding_dim = feature_config.table.dim
+            table = self._tpu_embedding.embedding_tables[table_name]
+
+            # This table has num_sparse_cores mod shards, so we need to slice,
+            # reconcat and reshape.
+            table_shards = [
+                shard.numpy()[:, :embedding_dim] for shard in table.values
+            ]
+            full_table = keras.ops.concatenate(table_shards, axis=0)
+            full_table = keras.ops.concatenate(
+                keras.ops.split(full_table, num_shards, axis=0), axis=1
+            )
+            full_table = keras.ops.reshape(full_table, [-1, embedding_dim])
+            tables[table_name] = full_table[
+                : feature_config.table.vocabulary_size, :
+            ]
+
+        keras.tree.map_structure(populate_table, self._tpu_feature_configs)
+        return tables
+
+    def _verify_input_shapes(
+        self, input_shapes: types.Nested[types.Shape]
+    ) -> None:
+        if self._using_keras_rs_configuration:
+            return super()._verify_input_shapes(input_shapes)
+        # `tf.tpu.experimental.embedding.FeatureConfig` does not provide any
+        # information about the input shape, so there is nothing to verify.
+
+    def _tpu_embedding_lookup_v1(
+        self,
+        tpu_embedding: tf.tpu.experimental.embedding.TPUEmbedding,
+        inputs: dict[str, types.Tensor],
+        weights: dict[str, types.Tensor] | None = None,
+    ) -> dict[str, types.Tensor]:
+        # Each call to this function increments the _v1_call_id by 1, this
+        # allows us to tag each of the main embedding ops with this call id so
+        # that we know during graph rewriting passes which ops correspond to the
+        # same layer call.
+        self._v1_call_id += 1
+        name = str(self._v1_call_id)
+
+        # Set training to true, even during eval. When name is set, this will
+        # trigger a pass that updates the training based on if there is a send
+        # gradients with the same name.
+        tpu_embedding.enqueue(inputs, weights, training=True, name=name)
+
+        @tf.custom_gradient  # type: ignore
+        def gradient_trap(
+            dummy: types.Tensor,
+        ) -> tuple[
+            list[types.Tensor], Callable[[tuple[types.Tensor]], types.Tensor]
+        ]:
+            """Register a gradient function for activation."""
+            activations = tpu_embedding.dequeue(name=name)
+
+            def grad(*grad_wrt_activations: types.Tensor) -> types.Tensor:
+                """Gradient function."""
+                # Since the output were flattened, the gradients are also
+                # flattened. Pack them back into the correct nested structure.
+                gradients = tf.nest.pack_sequence_as(
+                    self._placement_to_path_to_feature_config["sparsecore"],
+                    grad_wrt_activations,
+                )
+                tpu_embedding.apply_gradients(gradients, name=name)
+
+                # This is the gradient for the input variable.
+                return tf.zeros_like(dummy)
+
+            # Custom gradient functions don't like nested structures of tensors,
+            # so we flatten them here.
+            return tf.nest.flatten(activations), grad
+
+        activations_with_trap = gradient_trap(self._gradient_trap_dummy.value)
+        result: dict[str, types.Tensor] = tf.nest.pack_sequence_as(
+            self._placement_to_path_to_feature_config["sparsecore"],
+            activations_with_trap,
+        )
+        return result
+
+    def _tpu_embedding_lookup_v2(
+        self,
+        tpu_embedding: tf.tpu.experimental.embedding.TPUEmbeddingV2,
+        inputs: dict[str, types.Tensor],
+        weights: dict[str, types.Tensor] | None = None,
+    ) -> dict[str, types.Tensor]:
+        @tf.custom_gradient  # type: ignore
+        def gradient_trap(
+            dummy: types.Tensor,
+        ) -> tuple[
+            list[types.Tensor], Callable[[tuple[types.Tensor]], types.Tensor]
+        ]:
+            """Register a gradient function for activation."""
+            activations, preserved_result = tpu_embedding(inputs, weights)
+
+            def grad(*grad_wrt_activations: types.Tensor) -> types.Tensor:
+                """Gradient function."""
+                # Since the output were flattened, the gradients are also
+                # flattened. Pack them back into the correct nested structure.
+                gradients = tf.nest.pack_sequence_as(
+                    self._placement_to_path_to_feature_config["sparsecore"],
+                    grad_wrt_activations,
+                )
+                tpu_embedding.apply_gradients(
+                    gradients, preserved_outputs=preserved_result
+                )
+                # This is the gradient for the input variable.
+                return tf.zeros_like(dummy)
+
+            # Custom gradient functions don't like nested structures of tensors,
+            # so we flatten them here.
+            return tf.nest.flatten(activations), grad
+
+        activations_with_trap = gradient_trap(self._gradient_trap_dummy)
+        result: dict[str, types.Tensor] = tf.nest.pack_sequence_as(
+            self._placement_to_path_to_feature_config["sparsecore"],
+            activations_with_trap,
+        )
+        return result
+
+    def _trackable_children(
+        self, save_type: str = "checkpoint", **kwargs: dict[str, Any]
+    ) -> dict[str, Any]:
+        # Remove dummy variable, we don't want it in checkpoints.
+        children: dict[str, Any] = super()._trackable_children(
+            save_type, **kwargs
+        )
+        children.pop(GRADIENT_TRAP_DUMMY_NAME, None)
+        return children
+
+
+DistributedEmbedding.__doc__ = (
+    base_distributed_embedding.DistributedEmbedding.__doc__
+)
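
The trickiest part of the file above is the "gradient trap": the layer owns a single dummy trainable variable, and a `tf.custom_gradient` function returns the externally produced embedding activations while its gradient function applies the incoming activation gradients to the TPU tables as a side effect, returning only a zero gradient for the dummy. The following standalone sketch (not keras-rs code; a plain `tf.gather` stands in for the TPU lookup) shows the mechanism with ordinary TensorFlow:

import tensorflow as tf

table = tf.ones([4, 2])   # stands in for an embedding table
received_grads = []       # stands in for `tpu_embedding.apply_gradients(...)`

@tf.custom_gradient
def gradient_trap(dummy):
    # Activations come from a side channel; autodiff only sees `dummy`.
    activations = tf.gather(table, [0, 2])

    def grad(grad_wrt_activations):
        # Side effect: route the activation gradients elsewhere.
        received_grads.append(grad_wrt_activations)
        # The only gradient reported to the tape is a zero for the dummy.
        return tf.zeros_like(dummy)

    return activations, grad

dummy = tf.zeros([1])
with tf.GradientTape() as tape:
    tape.watch(dummy)
    loss = tf.reduce_sum(gradient_trap(dummy))

tape.gradient(loss, dummy)   # triggers `grad`, filling `received_grads`
print(received_grads[0])     # ones with the activations' shape (2, 2)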
keras_rs/src/layers/feature_interaction/dot_interaction.py
@@ -205,8 +205,8 @@ class DotInteraction(keras.layers.Layer):
         return activations
 
     def compute_output_shape(
-        self, input_shape: list[types.TensorShape]
-    ) -> types.TensorShape:
+        self, input_shape: list[types.Shape]
+    ) -> types.Shape:
         num_features = len(input_shape)
         batch_size = input_shape[0][0]
 
keras_rs/src/layers/feature_interaction/feature_cross.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Text, Union
+from typing import Any
 
 import keras
 from keras import ops
@@ -92,20 +92,18 @@ class FeatureCross(keras.layers.Layer):
 
     def __init__(
         self,
-        projection_dim: Optional[int] = None,
-        diag_scale: Optional[float] = 0.0,
+        projection_dim: int | None = None,
+        diag_scale: float | None = 0.0,
         use_bias: bool = True,
-        pre_activation: Optional[Union[str, keras.layers.Activation]] = None,
-        kernel_initializer: Union[
-            Text, keras.initializers.Initializer
-        ] = "glorot_uniform",
-        bias_initializer: Union[Text, keras.initializers.Initializer] = "zeros",
-        kernel_regularizer: Union[
-            Text, None, keras.regularizers.Regularizer
-        ] = None,
-        bias_regularizer: Union[
-            Text, None, keras.regularizers.Regularizer
-        ] = None,
+        pre_activation: str | keras.layers.Activation | None = None,
+        kernel_initializer: (
+            str | keras.initializers.Initializer
+        ) = "glorot_uniform",
+        bias_initializer: str | keras.initializers.Initializer = "zeros",
+        kernel_regularizer: (
+            str | None | keras.regularizers.Regularizer
+        ) = None,
+        bias_regularizer: (str | None | keras.regularizers.Regularizer) = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -129,7 +127,7 @@ class FeatureCross(keras.layers.Layer):
                 f"`diag_scale={self.diag_scale}`"
             )
 
-    def build(self, input_shape: types.TensorShape) -> None:
+    def build(self, input_shape: types.Shape) -> None:
         last_dim = input_shape[-1]
 
         if self.projection_dim is not None:
@@ -155,7 +153,7 @@ class FeatureCross(keras.layers.Layer):
         self.built = True
 
     def call(
-        self, x0: types.Tensor, x: Optional[types.Tensor] = None
+        self, x0: types.Tensor, x: types.Tensor | None = None
    ) -> types.Tensor:
         """Forward pass of the cross layer.
 
keras_rs/src/layers/retrieval/brute_force_retrieval.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any
 
 import keras
 
@@ -55,8 +55,8 @@ class BruteForceRetrieval(Retrieval):
 
     def __init__(
         self,
-        candidate_embeddings: Optional[types.Tensor] = None,
-        candidate_ids: Optional[types.Tensor] = None,
+        candidate_embeddings: types.Tensor | None = None,
+        candidate_ids: types.Tensor | None = None,
         k: int = 10,
         return_scores: bool = True,
         **kwargs: Any,
@@ -81,7 +81,7 @@ class BruteForceRetrieval(Retrieval):
     def update_candidates(
         self,
         candidate_embeddings: types.Tensor,
-        candidate_ids: Optional[types.Tensor] = None,
+        candidate_ids: types.Tensor | None = None,
     ) -> None:
         """Update the set of candidates and optionally their candidate IDs.
 
@@ -125,7 +125,7 @@ class BruteForceRetrieval(Retrieval):
 
     def call(
         self, inputs: types.Tensor
-    ) -> Union[types.Tensor, tuple[types.Tensor, types.Tensor]]:
+    ) -> types.Tensor | tuple[types.Tensor, types.Tensor]:
         """Returns the top candidates for the query passed as input.
 
         Args:
keras_rs/src/layers/retrieval/retrieval.py
@@ -1,5 +1,5 @@
 import abc
-from typing import Any, Optional, Union
+from typing import Any
 
 import keras
 
@@ -35,7 +35,7 @@ class Retrieval(keras.layers.Layer, abc.ABC):
     def _validate_candidate_embeddings_and_ids(
         self,
         candidate_embeddings: types.Tensor,
-        candidate_ids: Optional[types.Tensor] = None,
+        candidate_ids: types.Tensor | None = None,
     ) -> None:
         """Validates inputs to `update_candidates()`."""
 
@@ -71,7 +71,7 @@ class Retrieval(keras.layers.Layer, abc.ABC):
     def update_candidates(
         self,
         candidate_embeddings: types.Tensor,
-        candidate_ids: Optional[types.Tensor] = None,
+        candidate_ids: types.Tensor | None = None,
     ) -> None:
         """Update the set of candidates and optionally their candidate IDs.
 
@@ -85,7 +85,7 @@ class Retrieval(keras.layers.Layer, abc.ABC):
     @abc.abstractmethod
     def call(
         self, inputs: types.Tensor
-    ) -> Union[types.Tensor, tuple[types.Tensor, types.Tensor]]:
+    ) -> types.Tensor | tuple[types.Tensor, types.Tensor]:
         """Returns the top candidates for the query passed as input.
 
         Args:
keras_rs/src/losses/pairwise_loss.py
@@ -1,5 +1,5 @@
 import abc
-from typing import Any, Optional
+from typing import Any
 
 import keras
 from keras import ops
@@ -43,7 +43,7 @@ class PairwiseLoss(keras.losses.Loss, abc.ABC):
         self,
         labels: types.Tensor,
         logits: types.Tensor,
-        mask: Optional[types.Tensor] = None,
+        mask: types.Tensor | None = None,
     ) -> tuple[types.Tensor, types.Tensor]:
         # Mask all values less than 0 (since less than 0 implies invalid
         # labels).
keras_rs/src/losses/pairwise_mean_squared_error.py
@@ -1,5 +1,3 @@
-from typing import Optional
-
 from keras import ops
 
 from keras_rs.src import types
@@ -20,7 +18,7 @@ class PairwiseMeanSquaredError(PairwiseLoss):
         self,
         labels: types.Tensor,
         logits: types.Tensor,
-        mask: Optional[types.Tensor] = None,
+        mask: types.Tensor | None = None,
     ) -> tuple[types.Tensor, types.Tensor]:
         # Override `PairwiseLoss.compute_unreduced_loss` since pairwise weights
         # for MSE are computed differently.
keras_rs/src/metrics/dcg.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Optional
+from typing import Any, Callable
 
 from keras import ops
 from keras.saving import deserialize_keras_object
@@ -25,7 +25,7 @@ from keras_rs.src.utils.doc_string_utils import format_docstring
 class DCG(RankingMetric):
     def __init__(
         self,
-        k: Optional[int] = None,
+        k: int | None = None,
         gain_fn: Callable[[types.Tensor], types.Tensor] = default_gain_fn,
         rank_discount_fn: Callable[
             [types.Tensor], types.Tensor
keras_rs/src/metrics/mean_average_precision.py
@@ -25,7 +25,7 @@ class MeanAveragePrecision(RankingMetric):
     ) -> types.Tensor:
         relevance = ops.cast(
             ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
-            dtype="float32",
+            dtype=y_pred.dtype,
         )
         sorted_relevance, sorted_weights = sort_by_scores(
             tensors_to_sort=[relevance, sample_weight],
keras_rs/src/metrics/mean_reciprocal_rank.py
@@ -44,13 +44,13 @@ class MeanReciprocalRank(RankingMetric):
             ops.greater_equal(
                 sorted_y_true, ops.cast(1, dtype=sorted_y_true.dtype)
             ),
-            dtype="float32",
+            dtype=y_pred.dtype,
         )
 
         # `reciprocal_rank = [1, 0.5, 0.33]`
         reciprocal_rank = ops.divide(
-            ops.cast(1, dtype="float32"),
-            ops.arange(1, list_length + 1, dtype="float32"),
+            ops.cast(1, dtype=y_pred.dtype),
+            ops.arange(1, list_length + 1, dtype=y_pred.dtype),
         )
 
         # `mrr` should be of shape `(batch_size, 1)`.
@@ -64,7 +64,7 @@ class MeanReciprocalRank(RankingMetric):
         # Get weights.
         overall_relevance = ops.cast(
             ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
-            dtype="float32",
+            dtype=y_pred.dtype,
         )
         per_list_weights = get_list_weights(
             weights=sample_weight, relevance=overall_relevance
keras_rs/src/metrics/ndcg.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Optional
+from typing import Any, Callable
 
 from keras import ops
 from keras.saving import deserialize_keras_object
@@ -25,7 +25,7 @@ from keras_rs.src.utils.doc_string_utils import format_docstring
 class NDCG(RankingMetric):
     def __init__(
         self,
-        k: Optional[int] = None,
+        k: int | None = None,
         gain_fn: Callable[[types.Tensor], types.Tensor] = default_gain_fn,
         rank_discount_fn: Callable[
             [types.Tensor], types.Tensor
keras_rs/src/metrics/precision_at_k.py
@@ -40,7 +40,7 @@ class PrecisionAtK(RankingMetric):
             ops.greater_equal(
                 sorted_y_true, ops.cast(1, dtype=sorted_y_true.dtype)
             ),
-            dtype="float32",
+            dtype=y_pred.dtype,
         )
         list_length = ops.shape(sorted_y_true)[1]
         # TODO: We do not do this for MRR, and the other metrics. Do we need to
@@ -52,13 +52,13 @@ class PrecisionAtK(RankingMetric):
 
         per_list_precision = ops.divide_no_nan(
             ops.sum(relevance, axis=1, keepdims=True),
-            ops.cast(valid_list_length, dtype="float32"),
+            ops.cast(valid_list_length, dtype=y_pred.dtype),
         )
 
         # Get weights.
         overall_relevance = ops.cast(
             ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
-            dtype="float32",
+            dtype=y_pred.dtype,
        )
         per_list_weights = get_list_weights(
             weights=sample_weight, relevance=overall_relevance
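
The recurring change in the ranking metrics above replaces hard-coded `dtype="float32"` casts with `dtype=y_pred.dtype`, so intermediate relevance and weight tensors follow the prediction dtype (for example under mixed precision). A minimal sketch of the new behaviour, with illustrative values:

from keras import ops

y_true = ops.array([[0, 1, 2]], dtype="int32")
y_pred = ops.array([[0.1, 0.9, 0.3]], dtype="float16")  # e.g. mixed precision

# Previously cast to a fixed "float32"; now the cast follows `y_pred.dtype`.
relevance = ops.cast(
    ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
    dtype=y_pred.dtype,
)
print(relevance.dtype)  # float16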