keras-rs-nightly 0.0.1.dev2025021903__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. keras_rs/__init__.py +9 -28
  2. keras_rs/layers/__init__.py +37 -0
  3. keras_rs/losses/__init__.py +19 -0
  4. keras_rs/metrics/__init__.py +16 -0
  5. keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
  6. keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
  7. keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
  8. keras_rs/src/layers/embedding/embed_reduce.py +309 -0
  9. keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  10. keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
  11. keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
  12. keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
  13. keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
  14. keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
  15. keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  16. keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
  17. keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
  18. keras_rs/src/layers/feature_interaction/__init__.py +0 -0
  19. keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
  20. keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
  21. keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
  22. keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
  23. keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
  24. keras_rs/src/layers/retrieval/retrieval.py +127 -0
  25. keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
  26. keras_rs/src/losses/__init__.py +0 -0
  27. keras_rs/src/losses/list_mle_loss.py +212 -0
  28. keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
  29. keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
  30. keras_rs/src/losses/pairwise_loss.py +165 -0
  31. keras_rs/src/losses/pairwise_loss_utils.py +39 -0
  32. keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
  33. keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
  34. keras_rs/src/metrics/__init__.py +0 -0
  35. keras_rs/src/metrics/dcg.py +161 -0
  36. keras_rs/src/metrics/mean_average_precision.py +130 -0
  37. keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
  38. keras_rs/src/metrics/ndcg.py +197 -0
  39. keras_rs/src/metrics/precision_at_k.py +117 -0
  40. keras_rs/src/metrics/ranking_metric.py +260 -0
  41. keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
  42. keras_rs/src/metrics/recall_at_k.py +108 -0
  43. keras_rs/src/metrics/utils.py +70 -0
  44. keras_rs/src/types.py +43 -14
  45. keras_rs/src/utils/doc_string_utils.py +53 -0
  46. keras_rs/src/utils/keras_utils.py +52 -3
  47. keras_rs/src/utils/tpu_test_utils.py +120 -0
  48. keras_rs/src/version.py +1 -1
  49. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
  50. keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
  51. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
  52. keras_rs/api/__init__.py +0 -9
  53. keras_rs/api/layers/__init__.py +0 -11
  54. keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
  55. /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
  56. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
--- a/keras_rs/src/layers/modeling/feature_cross.py
+++ b/keras_rs/src/layers/feature_interaction/feature_cross.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Text, Union
+from typing import Any
 
 import keras
 from keras import ops
@@ -52,17 +52,37 @@ class FeatureCross(keras.layers.Layer):
             Regularizer to use for the kernel matrix.
         bias_regularizer: string or `keras.regularizer` regularizer.
             Regularizer to use for the bias vector.
+        **kwargs: Args to pass to the base class.
 
     Example:
 
     ```python
-    # after embedding layer in a functional model
-    input = keras.Input(shape=(), name='indices', dtype="int64")
-    x0 = keras.layers.Embedding(input_dim=32, output_dim=6)(x0)
-    x1 = FeatureCross()(x0, x0)
-    x2 = FeatureCross()(x0, x1)
+    # 1. Simple forward pass
+    batch_size = 2
+    embedding_dim = 32
+    feature1 = np.random.randn(batch_size, embedding_dim)
+    feature2 = np.random.randn(batch_size, embedding_dim)
+    crossed_features = keras_rs.layers.FeatureCross()(feature1, feature2)
+
+    # 2. After embedding layer in a model
+    vocabulary_size = 32
+    embedding_dim = 6
+
+    # Create a simple model containing the layer.
+    inputs = keras.Input(shape=(), name='indices', dtype="int32")
+    x0 = keras.layers.Embedding(
+        input_dim=vocabulary_size,
+        output_dim=embedding_dim
+    )(inputs)
+    x1 = keras_rs.layers.FeatureCross()(x0, x0)
+    x2 = keras_rs.layers.FeatureCross()(x0, x1)
     logits = keras.layers.Dense(units=10)(x2)
-    model = keras.Model(input, logits)
+    model = keras.Model(inputs, logits)
+
+    # Call the model on the inputs.
+    batch_size = 2
+    input_data = np.random.randint(0, vocabulary_size, size=(batch_size,))
+    outputs = model(input_data)
     ```
 
     References:
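The rewritten docstring example shows how the layer is wired into a model; the cross operation itself is untouched by this diff. For orientation, a minimal NumPy sketch of that operation, assuming the standard Deep & Cross Network formulation `x0 * (x @ W + b) + x` that this layer family is based on (all names and shapes below are illustrative, not taken from the diff):

```python
import numpy as np

# Illustrative shapes only.
batch_size, dim = 2, 6
rng = np.random.default_rng(0)

x0 = rng.normal(size=(batch_size, dim))  # base features
x = rng.normal(size=(batch_size, dim))   # output of the previous cross layer
kernel = rng.normal(size=(dim, dim))     # the layer's dense kernel
bias = np.zeros(dim)                     # the layer's bias vector

# One cross step: element-wise interaction with the base features,
# plus a residual connection back to `x`.
output = x0 * (x @ kernel + bias) + x
assert output.shape == (batch_size, dim)
```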
@@ -72,20 +92,18 @@ class FeatureCross(keras.layers.Layer):
 
     def __init__(
         self,
-        projection_dim: Optional[int] = None,
-        diag_scale: Optional[float] = 0.0,
+        projection_dim: int | None = None,
+        diag_scale: float | None = 0.0,
         use_bias: bool = True,
-        pre_activation: Optional[Union[str, keras.layers.Activation]] = None,
-        kernel_initializer: Union[
-            Text, keras.initializers.Initializer
-        ] = "glorot_uniform",
-        bias_initializer: Union[Text, keras.initializers.Initializer] = "zeros",
-        kernel_regularizer: Union[
-            Text, None, keras.regularizers.Regularizer
-        ] = None,
-        bias_regularizer: Union[
-            Text, None, keras.regularizers.Regularizer
-        ] = None,
+        pre_activation: str | keras.layers.Activation | None = None,
+        kernel_initializer: (
+            str | keras.initializers.Initializer
+        ) = "glorot_uniform",
+        bias_initializer: str | keras.initializers.Initializer = "zeros",
+        kernel_regularizer: (
+            str | None | keras.regularizers.Regularizer
+        ) = None,
+        bias_regularizer: (str | None | keras.regularizers.Regularizer) = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -109,7 +127,7 @@ class FeatureCross(keras.layers.Layer):
                 f"`diag_scale={self.diag_scale}`"
             )
 
-    def build(self, input_shape: types.TensorShape) -> None:
+    def build(self, input_shape: types.Shape) -> None:
         last_dim = input_shape[-1]
 
         if self.projection_dim is not None:
@@ -135,7 +153,7 @@ class FeatureCross(keras.layers.Layer):
         self.built = True
 
     def call(
-        self, x0: types.Tensor, x: Optional[types.Tensor] = None
+        self, x0: types.Tensor, x: types.Tensor | None = None
     ) -> types.Tensor:
         """Forward pass of the cross layer.
 
--- a/keras_rs/src/layers/retrieval/brute_force_retrieval.py
+++ b/keras_rs/src/layers/retrieval/brute_force_retrieval.py
@@ -1,13 +1,14 @@
-from typing import Any, Optional, Union
+from typing import Any
 
 import keras
 
 from keras_rs.src import types
 from keras_rs.src.api_export import keras_rs_export
+from keras_rs.src.layers.retrieval.retrieval import Retrieval
 
 
 @keras_rs_export("keras_rs.layers.BruteForceRetrieval")
-class BruteForceRetrieval(keras.layers.Layer):
+class BruteForceRetrieval(Retrieval):
     """Brute force top-k retrieval.
 
     This layer maintains a set of candidates and is able to exactly retrieve the
@@ -54,17 +55,19 @@ class BruteForceRetrieval(keras.layers.Layer):
 
     def __init__(
         self,
-        candidate_embeddings: Optional[types.Tensor] = None,
-        candidate_ids: Optional[types.Tensor] = None,
+        candidate_embeddings: types.Tensor | None = None,
+        candidate_ids: types.Tensor | None = None,
         k: int = 10,
         return_scores: bool = True,
         **kwargs: Any,
     ) -> None:
-        super().__init__(**kwargs)
+        # Keep `k`, `return_scores` as separately passed args instead of keeping
+        # them in `kwargs`. This is to ensure the user does not have to hop
+        # to the base class to check which other args can be passed.
+        super().__init__(k=k, return_scores=return_scores, **kwargs)
+
         self.candidate_embeddings = None
         self.candidate_ids = None
-        self.k = k
-        self.return_scores = return_scores
 
         if candidate_embeddings is None:
             if candidate_ids is not None:
@@ -78,42 +81,18 @@ class BruteForceRetrieval(keras.layers.Layer):
     def update_candidates(
         self,
         candidate_embeddings: types.Tensor,
-        candidate_ids: Optional[types.Tensor] = None,
+        candidate_ids: types.Tensor | None = None,
     ) -> None:
         """Update the set of candidates and optionally their candidate IDs.
 
         Args:
            candidate_embeddings: The candidate embeddings.
-           candidate_ids: The identifiers for the candidates. If `None` the
+           candidate_ids: The identifiers for the candidates. If `None`, the
                indices of the candidates are returned instead.
        """
-        if candidate_embeddings is None:
-            raise ValueError("`candidate_embeddings` is required")
-
-        if len(candidate_embeddings.shape) != 2:
-            raise ValueError(
-                "`candidate_embeddings` must be a tensor of rank 2 "
-                "(num_candidates, embedding_size), received "
-                "`candidate_embeddings` with shape "
-                f"{candidate_embeddings.shape}"
-            )
-
-        if candidate_embeddings.shape[0] < self.k:
-            raise ValueError(
-                "The number of candidates provided "
-                f"({candidate_embeddings.shape[0]}) is less than the number of "
-                f"candidates to retrieve (k={self.k})."
-            )
-
-        if (
-            candidate_ids is not None
-            and candidate_ids.shape[0] != candidate_embeddings.shape[0]
-        ):
-            raise ValueError(
-                "The `candidate_embeddings` and `candidate_is` tensors must "
-                "have the same number of rows, got tensors of shape "
-                f"{candidate_embeddings.shape} and {candidate_ids.shape}."
-            )
+        self._validate_candidate_embeddings_and_ids(
+            candidate_embeddings, candidate_ids
+        )
 
         if self.candidate_embeddings is not None:
             # Update of existing variables.
@@ -146,7 +125,7 @@ class BruteForceRetrieval(keras.layers.Layer):
 
     def call(
         self, inputs: types.Tensor
-    ) -> Union[types.Tensor, tuple[types.Tensor, types.Tensor]]:
+    ) -> types.Tensor | tuple[types.Tensor, types.Tensor]:
         """Returns the top candidates for the query passed as input.
 
         Args:
@@ -167,31 +146,3 @@ class BruteForceRetrieval(keras.layers.Layer):
             return top_scores, top_ids
         else:
             return top_ids
-
-    def compute_score(
-        self, query_embedding: types.Tensor, candidate_embedding: types.Tensor
-    ) -> types.Tensor:
-        """Computes the standard dot product score from queries and candidates.
-
-        Args:
-            query_embedding: Tensor of query embedding corresponding to the
-                queries for which to retrieve top candidates.
-            candidate_embedding: Tensor of candidate embeddings.
-
-        Returns:
-            The dot product of queries and candidates.
-        """
-
-        return keras.ops.matmul(
-            query_embedding, keras.ops.transpose(candidate_embedding)
-        )
-
-    def get_config(self) -> dict[str, Any]:
-        config: dict[str, Any] = super().get_config()
-        config.update(
-            {
-                "k": self.k,
-                "return_scores": self.compute_score,
-            }
-        )
-        return config
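With the validation, scoring, and config logic now inherited from the new `Retrieval` base class (see `retrieval.py` below), calling the layer is unchanged. A small usage sketch with made-up data, based on the constructor signature shown in this diff:

```python
import numpy as np
import keras_rs

rng = np.random.default_rng(0)

# Hypothetical candidate set: 100 candidates with 16-dim embeddings.
candidate_embeddings = rng.normal(size=(100, 16)).astype("float32")
candidate_ids = np.arange(100)

retrieval = keras_rs.layers.BruteForceRetrieval(
    candidate_embeddings=candidate_embeddings,
    candidate_ids=candidate_ids,
    k=10,                # now stored by the Retrieval base class
    return_scores=True,  # likewise
)

# A batch of 2 query embeddings; returns (scores, ids), each of shape (2, 10).
queries = rng.normal(size=(2, 16)).astype("float32")
top_scores, top_ids = retrieval(queries)
```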
--- /dev/null
+++ b/keras_rs/src/layers/retrieval/hard_negative_mining.py
@@ -0,0 +1,94 @@
+from typing import Any
+
+import keras
+import ml_dtypes
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+
+MAX_FLOAT = ml_dtypes.finfo("float32").max / 100.0
+
+
+@keras_rs_export("keras_rs.layers.HardNegativeMining")
+class HardNegativeMining(keras.layers.Layer):
+    """Filter logits and labels to return hard negatives.
+
+    The output will include logits and labels for the requested number of hard
+    negatives as well as the positive candidate.
+
+    Args:
+        num_hard_negatives: How many hard negatives to return.
+        **kwargs: Args to pass to the base class.
+
+    Example:
+
+    ```python
+    # Create layer with the configured number of hard negatives to mine.
+    hard_negative_mining = keras_rs.layers.HardNegativeMining(
+        num_hard_negatives=10
+    )
+
+    # This will retrieve the top 10 negative candidates plus the positive
+    # candidate from `labels` for each row.
+    out_logits, out_labels = hard_negative_mining(in_logits, in_labels)
+    ```
+    """
+
+    def __init__(self, num_hard_negatives: int, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self._num_hard_negatives = num_hard_negatives
+        self.built = True
+
+    def call(
+        self, logits: types.Tensor, labels: types.Tensor
+    ) -> tuple[types.Tensor, types.Tensor]:
+        """Filters logits and labels with per-query hard negative mining.
+
+        The result will include logits and labels for `num_hard_negatives`
+        negatives as well as the positive candidate.
+
+        Args:
+            logits: The logits tensor, typically `[batch_size, num_candidates]`
+                but can have more dimensions or be 1D as `[num_candidates]`.
+            labels: The one-hot labels tensor, must be the same shape as
+                `logits`.
+
+        Returns:
+            A tuple containing two tensors with the last dimension of
+            `num_candidates` replaced with `num_hard_negatives + 1`.
+
+            * logits: `[..., num_hard_negatives + 1]` tensor of logits.
+            * labels: `[..., num_hard_negatives + 1]` one-hot tensor of labels.
+        """
+
+        # Number of sampled logits, i.e., the number of hard negatives to be
+        # sampled (k) + the true logit (1) per query, capped by the number
+        # of candidates.
+        num_logits = ops.shape(logits)[-1]
+        if isinstance(num_logits, int):
+            num_sampled = min(self._num_hard_negatives + 1, num_logits)
+        else:
+            num_sampled = ops.minimum(self._num_hard_negatives + 1, num_logits)
+        # To gather indices of the top k negative logits per row (query),
+        # true logits need to be excluded. First replace the true logits
+        # (corresponding to positive labels) with a large score value and then
+        # select the top k + 1 logits from each row so that selected indices
+        # include the indices of the true logit + top k negative logits. This
+        # approach avoids using inefficient masking when excluding the true
+        # logits.
+
+        # For each query, get the indices of the logits which have the highest
+        # k + 1 logit values, including the highest k negative logits and one
+        # true logit.
+        _, indices = ops.top_k(
+            ops.add(logits, ops.multiply(labels, MAX_FLOAT)),
+            k=num_sampled,
+            sorted=False,
+        )
+
+        # Gather sampled logits and corresponding labels.
+        logits = ops.take_along_axis(logits, indices, axis=-1)
+        labels = ops.take_along_axis(labels, indices, axis=-1)
+
+        return logits, labels
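A quick end-to-end sketch of the new layer on made-up tensors, to make the shape contract concrete (the last dimension shrinks from `num_candidates` to `num_hard_negatives + 1`):

```python
import numpy as np
import keras_rs

rng = np.random.default_rng(0)
batch_size, num_candidates = 3, 8

in_logits = rng.normal(size=(batch_size, num_candidates)).astype("float32")
# One-hot labels marking the positive candidate in each row.
positives = rng.integers(0, num_candidates, size=batch_size)
in_labels = np.eye(num_candidates, dtype="float32")[positives]

layer = keras_rs.layers.HardNegativeMining(num_hard_negatives=2)
out_logits, out_labels = layer(in_logits, in_labels)

# Each row keeps the positive plus its 2 hardest negatives.
assert tuple(out_logits.shape) == (batch_size, 3)
assert tuple(out_labels.shape) == (batch_size, 3)
```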
--- /dev/null
+++ b/keras_rs/src/layers/retrieval/remove_accidental_hits.py
@@ -0,0 +1,97 @@
+import keras
+import ml_dtypes
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+from keras_rs.src.utils import keras_utils
+
+SMALLEST_FLOAT = ml_dtypes.finfo("float32").smallest_normal / 100.0
+
+
+@keras_rs_export("keras_rs.layers.RemoveAccidentalHits")
+class RemoveAccidentalHits(keras.layers.Layer):
+    """Zeroes the logits of accidental negatives.
+
+    Zeroes the logits of negative candidates that have the same ID as the
+    positive candidate in that row.
+
+    Example:
+
+    ```python
+    # Create the layer.
+    remove_accidental_hits = keras_rs.layers.RemoveAccidentalHits()
+
+    # This will zero the logits of negative candidates that have the same ID as
+    # the positive candidate from `labels` so as to not negatively impact the
+    # true positive.
+    logits = remove_accidental_hits(logits, labels, candidate_ids)
+    ```
+    """
+
+    def call(
+        self,
+        logits: types.Tensor,
+        labels: types.Tensor,
+        candidate_ids: types.Tensor,
+    ) -> types.Tensor:
+        """Zeroes selected logits.
+
+        For each row in the batch, zeroes the logits of negative candidates
+        that have the same ID as the positive candidate in that row.
+
+        Args:
+            logits: The logits tensor, typically `[batch_size, num_candidates]`
+                but can have more dimensions or be 1D as `[num_candidates]`.
+            labels: The one-hot labels tensor, must be the same shape as
+                `logits`.
+            candidate_ids: The candidate identifiers tensor, can be
+                `[num_candidates]` or `[batch_size, num_candidates]` or have
+                more dimensions as long as they match the last dimensions of
+                `labels`.
+
+        Returns:
+            The modified logits with the same shape as the input logits.
+        """
+        # A more principled way is to implement
+        # `softmax_cross_entropy_with_logits` with an input mask. Here we
+        # approximate this by letting accidental hits have extremely small
+        # logits (SMALLEST_FLOAT) for ease of implementation.
+
+        labels_shape = ops.shape(labels)
+        labels_rank = len(labels_shape)
+        logits_shape = ops.shape(logits)
+        candidate_ids_shape = ops.shape(candidate_ids)
+        candidate_ids_rank = len(candidate_ids_shape)
+
+        if not keras_utils.check_shapes_compatible(labels_shape, logits_shape):
+            raise ValueError(
+                "`labels` and `logits` should have the same shape. Received: "
+                f"`labels.shape` = {labels_shape}, "
+                f"`logits.shape` = {logits_shape}."
+            )
+
+        if not keras_utils.check_shapes_compatible(
+            labels_shape[-candidate_ids_rank:], candidate_ids_shape
+        ):
+            raise ValueError(
+                "`candidate_ids` should have the same shape as the last "
+                "dimensions of `labels`. Received: "
+                f"`candidate_ids.shape` = {candidate_ids_shape}, "
+                f"`labels.shape` = {labels_shape}."
+            )
+
+        # Add dimensions to `candidate_ids` to have the same rank as `labels`.
+        if candidate_ids_rank < labels_rank:
+            candidate_ids = ops.expand_dims(
+                candidate_ids, list(range(labels_rank - candidate_ids_rank))
+            )
+        positive_indices = ops.expand_dims(ops.argmax(labels, axis=-1), -1)
+        positive_candidate_ids = ops.take(candidate_ids, positive_indices)
+
+        duplicate = ops.cast(
+            ops.equal(positive_candidate_ids, candidate_ids), labels.dtype
+        )
+        duplicate = ops.subtract(duplicate, labels)
+
+        return ops.add(logits, ops.multiply(duplicate, SMALLEST_FLOAT))
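A sketch of the accidental-hit scenario this layer handles, with made-up values: the positive candidate in row 0 has ID 7, and a negative in the same row happens to share that ID, so its logit is the one the layer adjusts:

```python
import numpy as np
import keras_rs

logits = np.zeros((2, 4), dtype="float32")
labels = np.array(
    [[1.0, 0.0, 0.0, 0.0],   # row 0: positive is candidate 0
     [0.0, 1.0, 0.0, 0.0]],  # row 1: positive is candidate 1
    dtype="float32",
)
# Candidate 2 shares ID 7 with the positive of row 0: an accidental hit.
candidate_ids = np.array([7, 3, 7, 5])

layer = keras_rs.layers.RemoveAccidentalHits()
adjusted_logits = layer(logits, labels, candidate_ids)  # shape (2, 4)
```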
--- /dev/null
+++ b/keras_rs/src/layers/retrieval/retrieval.py
@@ -0,0 +1,127 @@
+import abc
+from typing import Any
+
+import keras
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+
+
+@keras_rs_export("keras_rs.layers.Retrieval")
+class Retrieval(keras.layers.Layer, abc.ABC):
+    """Retrieval base abstract class.
+
+    This layer provides a common interface for all retrieval layers. In order
+    to implement a custom retrieval layer, this abstract class should be
+    subclassed.
+
+    Args:
+        k: int. Number of candidates to retrieve.
+        return_scores: bool. When `True`, this layer returns a tuple with the
+            top scores and the top identifiers. When `False`, this layer
+            returns a single tensor with the top identifiers.
+    """
+
+    def __init__(
+        self,
+        k: int = 10,
+        return_scores: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.k = k
+        self.return_scores = return_scores
+
+    def _validate_candidate_embeddings_and_ids(
+        self,
+        candidate_embeddings: types.Tensor,
+        candidate_ids: types.Tensor | None = None,
+    ) -> None:
+        """Validates inputs to `update_candidates()`."""
+
+        if candidate_embeddings is None:
+            raise ValueError("`candidate_embeddings` is required.")
+
+        if len(candidate_embeddings.shape) != 2:
+            raise ValueError(
+                "`candidate_embeddings` must be a tensor of rank 2 "
+                "(num_candidates, embedding_size), received "
+                "`candidate_embeddings` with shape "
+                f"{candidate_embeddings.shape}"
+            )
+
+        if candidate_embeddings.shape[0] < self.k:
+            raise ValueError(
+                "The number of candidates provided "
+                f"({candidate_embeddings.shape[0]}) is less than the number of "
+                f"candidates to retrieve (k={self.k})."
+            )
+
+        if (
+            candidate_ids is not None
+            and candidate_ids.shape[0] != candidate_embeddings.shape[0]
+        ):
+            raise ValueError(
+                "The `candidate_embeddings` and `candidate_ids` tensors must "
+                "have the same number of rows, got tensors of shape "
+                f"{candidate_embeddings.shape} and {candidate_ids.shape}."
+            )
+
+    @abc.abstractmethod
+    def update_candidates(
+        self,
+        candidate_embeddings: types.Tensor,
+        candidate_ids: types.Tensor | None = None,
+    ) -> None:
+        """Update the set of candidates and optionally their candidate IDs.
+
+        Args:
+            candidate_embeddings: The candidate embeddings.
+            candidate_ids: The identifiers for the candidates. If `None`, the
+                indices of the candidates are returned instead.
+        """
+        pass
+
+    @abc.abstractmethod
+    def call(
+        self, inputs: types.Tensor
+    ) -> types.Tensor | tuple[types.Tensor, types.Tensor]:
+        """Returns the top candidates for the query passed as input.
+
+        Args:
+            inputs: the query for which to return top candidates.
+
+        Returns:
+            A tuple with the top scores and the top identifiers if
+            `return_scores` is `True`, otherwise a tensor with the top
+            identifiers.
+        """
+        pass
+
+    def compute_score(
+        self, query_embedding: types.Tensor, candidate_embedding: types.Tensor
+    ) -> types.Tensor:
+        """Computes the standard dot product score from queries and candidates.
+
+        Args:
+            query_embedding: Tensor of query embedding corresponding to the
+                queries for which to retrieve top candidates.
+            candidate_embedding: Tensor of candidate embeddings.
+
+        Returns:
+            The dot product of queries and candidates.
+        """
+
+        return keras.ops.matmul(
+            query_embedding, keras.ops.transpose(candidate_embedding)
+        )
+
+    def get_config(self) -> dict[str, Any]:
+        config: dict[str, Any] = super().get_config()
+        config.update(
+            {
+                "k": self.k,
+                "return_scores": self.return_scores,
+            }
+        )
+        return config
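To illustrate the contract this abstract class imposes, here is a hypothetical minimal subclass (not part of the package): `update_candidates()` and `call()` are the two abstract methods, while `k`, `return_scores`, `compute_score()`, and `get_config()` come from the base class:

```python
from keras import ops

import keras_rs


class InMemoryRetrieval(keras_rs.layers.Retrieval):
    """Hypothetical subclass keeping candidates in a plain attribute."""

    def update_candidates(self, candidate_embeddings, candidate_ids=None):
        # Reuse the base class validation shown above.
        self._validate_candidate_embeddings_and_ids(
            candidate_embeddings, candidate_ids
        )
        self.candidate_embeddings = candidate_embeddings

    def call(self, inputs):
        # Inherited dot-product scoring, then top-k selection.
        scores = self.compute_score(inputs, self.candidate_embeddings)
        top_scores, top_ids = ops.top_k(scores, k=self.k)
        if self.return_scores:
            return top_scores, top_ids
        return top_ids
```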
--- /dev/null
+++ b/keras_rs/src/layers/retrieval/sampling_probability_correction.py
@@ -0,0 +1,63 @@
+from typing import Any
+
+import keras
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+
+
+@keras_rs_export("keras_rs.layers.SamplingProbabilityCorrection")
+class SamplingProbabilityCorrection(keras.layers.Layer):
+    """Sampling probability correction.
+
+    Corrects the logits to reflect the sampling probability of negatives.
+
+    Args:
+        epsilon: float. Small float used as a lower bound on the sampling
+            probability to avoid taking the log of zero. Defaults to 1e-6.
+        **kwargs: Args to pass to the base class.
+
+    Example:
+
+    ```python
+    # Create the layer.
+    sampling_probability_correction = (
+        keras_rs.layers.SamplingProbabilityCorrection()
+    )
+
+    # Correct the logits based on the provided candidate sampling probability.
+    logits = sampling_probability_correction(logits, probabilities)
+    ```
+    """
+
+    def __init__(self, epsilon: float = 1e-6, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.epsilon = epsilon
+        self.built = True
+
+    def call(
+        self,
+        logits: types.Tensor,
+        candidate_sampling_probability: types.Tensor,
+    ) -> types.Tensor:
+        """Corrects input logits to account for candidate sampling probability.
+
+        Args:
+            logits: The logits tensor to correct, typically
+                `[batch_size, num_candidates]` but can have more dimensions or
+                be 1D as `[num_candidates]`.
+            candidate_sampling_probability: The sampling probability with the
+                same shape as `logits`.
+
+        Returns:
+            The corrected logits with the same shape as the input logits.
+        """
+        return logits - ops.log(
+            ops.clip(candidate_sampling_probability, self.epsilon, 1.0)
+        )
+
+    def get_config(self) -> dict[str, Any]:
+        config: dict[str, Any] = super().get_config()
+        config.update({"epsilon": self.epsilon})
+        return config
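The correction itself is just `logits - log(clip(p, epsilon, 1.0))`, so rarely sampled negatives get their logits boosted. A worked example with assumed numbers:

```python
import numpy as np

epsilon = 1e-6
logit = 2.0
sampling_probability = 0.01  # this negative appears in 1% of batches

# Subtracting log(p) adds -log(0.01) ≈ 4.605 to the logit.
corrected = logit - np.log(np.clip(sampling_probability, epsilon, 1.0))
print(corrected)  # 2.0 - log(0.01) ≈ 6.605
```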