keras-rs-nightly 0.0.1.dev2025021903__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. keras_rs/__init__.py +9 -28
  2. keras_rs/layers/__init__.py +37 -0
  3. keras_rs/losses/__init__.py +19 -0
  4. keras_rs/metrics/__init__.py +16 -0
  5. keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
  6. keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
  7. keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
  8. keras_rs/src/layers/embedding/embed_reduce.py +309 -0
  9. keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  10. keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
  11. keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
  12. keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
  13. keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
  14. keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
  15. keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  16. keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
  17. keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
  18. keras_rs/src/layers/feature_interaction/__init__.py +0 -0
  19. keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
  20. keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
  21. keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
  22. keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
  23. keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
  24. keras_rs/src/layers/retrieval/retrieval.py +127 -0
  25. keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
  26. keras_rs/src/losses/__init__.py +0 -0
  27. keras_rs/src/losses/list_mle_loss.py +212 -0
  28. keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
  29. keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
  30. keras_rs/src/losses/pairwise_loss.py +165 -0
  31. keras_rs/src/losses/pairwise_loss_utils.py +39 -0
  32. keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
  33. keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
  34. keras_rs/src/metrics/__init__.py +0 -0
  35. keras_rs/src/metrics/dcg.py +161 -0
  36. keras_rs/src/metrics/mean_average_precision.py +130 -0
  37. keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
  38. keras_rs/src/metrics/ndcg.py +197 -0
  39. keras_rs/src/metrics/precision_at_k.py +117 -0
  40. keras_rs/src/metrics/ranking_metric.py +260 -0
  41. keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
  42. keras_rs/src/metrics/recall_at_k.py +108 -0
  43. keras_rs/src/metrics/utils.py +70 -0
  44. keras_rs/src/types.py +43 -14
  45. keras_rs/src/utils/doc_string_utils.py +53 -0
  46. keras_rs/src/utils/keras_utils.py +52 -3
  47. keras_rs/src/utils/tpu_test_utils.py +120 -0
  48. keras_rs/src/version.py +1 -1
  49. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
  50. keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
  51. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
  52. keras_rs/api/__init__.py +0 -9
  53. keras_rs/api/layers/__init__.py +0 -11
  54. keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
  55. /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
  56. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
keras_rs/src/metrics/ndcg.py
@@ -0,0 +1,197 @@
+from typing import Any, Callable
+
+from keras import ops
+from keras.saving import deserialize_keras_object
+from keras.saving import serialize_keras_object
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+from keras_rs.src.metrics.ranking_metric import RankingMetric
+from keras_rs.src.metrics.ranking_metric import (
+    ranking_metric_subclass_doc_string,
+)
+from keras_rs.src.metrics.ranking_metric import (
+    ranking_metric_subclass_doc_string_post_desc,
+)
+from keras_rs.src.metrics.ranking_metrics_utils import compute_dcg
+from keras_rs.src.metrics.ranking_metrics_utils import default_gain_fn
+from keras_rs.src.metrics.ranking_metrics_utils import default_rank_discount_fn
+from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
+from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
+from keras_rs.src.utils.doc_string_utils import format_docstring
+
+
+@keras_rs_export("keras_rs.metrics.NDCG")
+class NDCG(RankingMetric):
+    def __init__(
+        self,
+        k: int | None = None,
+        gain_fn: Callable[[types.Tensor], types.Tensor] = default_gain_fn,
+        rank_discount_fn: Callable[
+            [types.Tensor], types.Tensor
+        ] = default_rank_discount_fn,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(k=k, **kwargs)
+
+        self.gain_fn = gain_fn
+        self.rank_discount_fn = rank_discount_fn
+
+    def compute_metric(
+        self,
+        y_true: types.Tensor,
+        y_pred: types.Tensor,
+        mask: types.Tensor,
+        sample_weight: types.Tensor,
+    ) -> types.Tensor:
+        sorted_y_true, sorted_weights = sort_by_scores(
+            tensors_to_sort=[y_true, sample_weight],
+            scores=y_pred,
+            k=self.k,
+            mask=mask,
+            shuffle_ties=self.shuffle_ties,
+            seed=self.seed_generator,
+        )
+
+        dcg = compute_dcg(
+            y_true=sorted_y_true,
+            sample_weight=sorted_weights,
+            gain_fn=self.gain_fn,
+            rank_discount_fn=self.rank_discount_fn,
+        )
+
+        weighted_gains = ops.multiply(
+            sample_weight,
+            self.gain_fn(y_true),
+        )
+        ideal_sorted_y_true, ideal_sorted_weights = sort_by_scores(
+            tensors_to_sort=[y_true, sample_weight],
+            scores=weighted_gains,
+            k=self.k,
+            mask=mask,
+            shuffle_ties=self.shuffle_ties,
+            seed=self.seed_generator,
+        )
+        ideal_dcg = compute_dcg(
+            y_true=ideal_sorted_y_true,
+            sample_weight=ideal_sorted_weights,
+            gain_fn=self.gain_fn,
+            rank_discount_fn=self.rank_discount_fn,
+        )
+        per_list_ndcg = ops.divide_no_nan(dcg, ideal_dcg)
+
+        per_list_weights = get_list_weights(
+            weights=sample_weight, relevance=self.gain_fn(y_true)
+        )
+
+        return per_list_ndcg, per_list_weights
+
+    def get_config(self) -> dict[str, Any]:
+        config: dict[str, Any] = super().get_config()
+        config.update(
+            {
+                "gain_fn": serialize_keras_object(self.gain_fn),
+                "rank_discount_fn": serialize_keras_object(
+                    self.rank_discount_fn
+                ),
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "NDCG":
+        config["gain_fn"] = deserialize_keras_object(config["gain_fn"])
+        config["rank_discount_fn"] = deserialize_keras_object(
+            config["rank_discount_fn"]
+        )
+        return cls(**config)
+
+
+concept_sentence = (
+    "It normalizes the Discounted Cumulative Gain (DCG) with the Ideal "
+    "Discounted Cumulative Gain (IDCG) for each list"
+)
+relevance_type = (
+    "graded relevance scores (non-negative numbers where higher values "
+    "indicate greater relevance)"
+)
+score_range_interpretation = (
+    "A normalized score (between 0 and 1) is returned. A score of 1 "
+    "represents the perfect ranking according to true relevance (within the "
+    "top-k), while 0 typically represents a ranking with no relevant items. "
+    "Higher scores indicate better ranking quality relative to the best "
+    "possible ranking"
+)
+
+formula = """
+```
+nDCG@k = DCG@k / IDCG@k
+```
+
+where DCG@k is calculated based on the predicted ranking (`y_pred`):
+
+```
+DCG@k(y') = sum_{i=1}^{k} (gain_fn(y'_i) / rank_discount_fn(i))
+```
+
+And IDCG@k is the Ideal DCG, calculated using the same formula but on items
+sorted perfectly by their *true relevance* (`y_true`):
+
+```
+IDCG@k(y'') = sum_{i=1}^{k} (gain_fn(y''_i) / rank_discount_fn(i))
+```
+
+where:
+- `y'_i`: True relevance of the item at rank `i` in the ranking induced by
+  `y_pred`.
+- `y''_i`: True relevance of the item at rank `i` in the *ideal* ranking (sorted
+  by `y_true` descending).
+- `gain_fn` is the user-provided function mapping relevance to gain. The default
+  function (`default_gain_fn`) is typically equivalent to `lambda y: 2**y - 1`.
+- `rank_discount_fn` is the user-provided function mapping rank `i` (1-based) to
+  a discount value. The default function (`default_rank_discount_fn`) is
+  typically equivalent to `lambda rank: 1 / log2(rank + 1)`.
+- If IDCG@k is 0 (e.g., no relevant items), nDCG@k is defined as 0.
+- The final result often aggregates these per-list nDCG scores, potentially
+  involving normalization by list-specific weights, to produce a weighted
+  average."""
+extra_args = """
+    gain_fn: callable. Maps relevance scores (`y_true`) to gain values. The
+        default implements `2**y - 1`.
+    rank_discount_fn: callable. Maps rank positions to discount
+        values. The default (`default_rank_discount_fn`) implements
+        `1 / log2(rank + 1)`."""
+example = """
+    >>> batch_size = 2
+    >>> list_size = 5
+    >>> labels = np.random.randint(0, 3, size=(batch_size, list_size))
+    >>> scores = np.random.random(size=(batch_size, list_size))
+    >>> metric = keras_rs.metrics.NDCG()(
+    ...     y_true=labels, y_pred=scores
+    ... )
+
+    Mask certain elements (can be used for uneven inputs):
+
+    >>> batch_size = 2
+    >>> list_size = 5
+    >>> labels = np.random.randint(0, 3, size=(batch_size, list_size))
+    >>> scores = np.random.random(size=(batch_size, list_size))
+    >>> mask = np.random.randint(0, 2, size=(batch_size, list_size), dtype=bool)
+    >>> metric = keras_rs.metrics.NDCG()(
+    ...     y_true={"labels": labels, "mask": mask}, y_pred=scores
+    ... )
+"""
+
+NDCG.__doc__ = format_docstring(
+    ranking_metric_subclass_doc_string,
+    width=80,
+    metric_name="Normalised Discounted Cumulative Gain",
+    metric_abbreviation="nDCG",
+    concept_sentence=concept_sentence,
+    relevance_type=relevance_type,
+    score_range_interpretation=score_range_interpretation,
+    formula=formula,
+    extra_args=extra_args,
+) + ranking_metric_subclass_doc_string_post_desc.format(
+    extra_args=extra_args, example=example
+)
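The docstring formula above can be sanity-checked with a few lines of NumPy. This is a minimal sketch, not the library implementation: it hard-codes the default `gain_fn` (`2**y - 1`) and `rank_discount_fn` (`1 / log2(rank + 1)`) described in the docstring, and ignores tie shuffling, masking, and sample weights.

```python
import numpy as np


def ndcg_at_k(y_true, y_pred, k=None):
    """Plain nDCG@k with the default gain and rank discount (sketch only)."""
    y_true = np.asarray(y_true, dtype=float)
    # Rank items by descending predicted score and keep the top-k.
    order = np.argsort(-np.asarray(y_pred))[:k]
    gains = 2.0 ** y_true[order] - 1.0
    # Discounts for 1-based ranks: 1 / log2(rank + 1).
    discounts = 1.0 / np.log2(np.arange(len(order)) + 2.0)
    dcg = np.sum(gains * discounts)

    # Ideal DCG: the same sum over labels sorted by true relevance.
    ideal = np.sort(y_true)[::-1][: len(order)]
    idcg = np.sum((2.0 ** ideal - 1.0) * discounts)
    # nDCG@k is defined as 0 when IDCG@k is 0 (no relevant items).
    return dcg / idcg if idcg > 0 else 0.0


print(ndcg_at_k([0, 0, 1], [0.1, 0.9, 0.2]))  # ~0.63: relevant item ranked 2nd
```

`keras_rs.metrics.NDCG()` on the same unbatched inputs should agree with this value, up to the weighting and tie-shuffling behavior described above.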
keras_rs/src/metrics/precision_at_k.py
@@ -0,0 +1,117 @@
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+from keras_rs.src.metrics.ranking_metric import RankingMetric
+from keras_rs.src.metrics.ranking_metric import (
+    ranking_metric_subclass_doc_string,
+)
+from keras_rs.src.metrics.ranking_metric import (
+    ranking_metric_subclass_doc_string_post_desc,
+)
+from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
+from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
+from keras_rs.src.utils.doc_string_utils import format_docstring
+
+
+@keras_rs_export("keras_rs.metrics.PrecisionAtK")
+class PrecisionAtK(RankingMetric):
+    def compute_metric(
+        self,
+        y_true: types.Tensor,
+        y_pred: types.Tensor,
+        mask: types.Tensor,
+        sample_weight: types.Tensor,
+    ) -> types.Tensor:
+        # Assume: `y_true = [0, 0, 1]`, `y_pred = [0.1, 0.9, 0.2]`.
+        # `sorted_y_true = [0, 1, 0]` (sorted in descending order of scores).
+        (sorted_y_true,) = sort_by_scores(
+            tensors_to_sort=[y_true],
+            scores=y_pred,
+            mask=mask,
+            k=self.k,
+            shuffle_ties=self.shuffle_ties,
+            seed=self.seed_generator,
+        )
+
+        # We consider only binary relevance here: anything above 1 is treated
+        # as 1. `relevance = [0., 1., 0.]`.
+        relevance = ops.cast(
+            ops.greater_equal(
+                sorted_y_true, ops.cast(1, dtype=sorted_y_true.dtype)
+            ),
+            dtype=y_pred.dtype,
+        )
+        list_length = ops.shape(sorted_y_true)[1]
+        # TODO: We do not do this for MRR, and the other metrics. Do we need to
+        # do this there too?
+        valid_list_length = ops.minimum(
+            list_length,
+            ops.sum(ops.cast(mask, dtype="int32"), axis=1, keepdims=True),
+        )
+
+        per_list_precision = ops.divide_no_nan(
+            ops.sum(relevance, axis=1, keepdims=True),
+            ops.cast(valid_list_length, dtype=y_pred.dtype),
+        )
+
+        # Get weights.
+        overall_relevance = ops.cast(
+            ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
+            dtype=y_pred.dtype,
+        )
+        per_list_weights = get_list_weights(
+            weights=sample_weight, relevance=overall_relevance
+        )
+
+        return per_list_precision, per_list_weights
+
+
+concept_sentence = (
+    "It measures the proportion of relevant items among the top-k "
+    "recommendations"
+)
+relevance_type = "binary indicators (0 or 1) of relevance"
+score_range_interpretation = (
+    "Scores range from 0 to 1, with 1 indicating all top-k items were relevant"
+)
+formula = """```
+P@k(y, s) = 1/k sum_i I[rank(s_i) < k] y_i
+```
+
+where `y_i` is the relevance label (0/1) of the item ranked at position
+`i`, and `I[condition]` is 1 if the condition is met, otherwise 0."""
+extra_args = ""
+example = """
+    >>> batch_size = 2
+    >>> list_size = 5
+    >>> labels = np.random.randint(0, 2, size=(batch_size, list_size))
+    >>> scores = np.random.random(size=(batch_size, list_size))
+    >>> metric = keras_rs.metrics.PrecisionAtK()(
+    ...     y_true=labels, y_pred=scores
+    ... )
+
+    Mask certain elements (can be used for uneven inputs):
+
+    >>> batch_size = 2
+    >>> list_size = 5
+    >>> labels = np.random.randint(0, 2, size=(batch_size, list_size))
+    >>> scores = np.random.random(size=(batch_size, list_size))
+    >>> mask = np.random.randint(0, 2, size=(batch_size, list_size), dtype=bool)
+    >>> metric = keras_rs.metrics.PrecisionAtK()(
+    ...     y_true={"labels": labels, "mask": mask}, y_pred=scores
+    ... )
+"""
+
+PrecisionAtK.__doc__ = format_docstring(
+    ranking_metric_subclass_doc_string,
+    width=80,
+    metric_name="Precision@k",
+    metric_abbreviation="P@k",
+    concept_sentence=concept_sentence,
+    relevance_type=relevance_type,
+    score_range_interpretation=score_range_interpretation,
+    formula=formula,
+) + ranking_metric_subclass_doc_string_post_desc.format(
+    extra_args=extra_args, example=example
+)
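To make the worked comment in `compute_metric` concrete (`y_true = [0, 0, 1]`, `y_pred = [0.1, 0.9, 0.2]`), here is a minimal NumPy sketch of plain P@k. It divides by the (possibly truncated) list length, which matches the layer's `valid_list_length` when all items are valid; masking and `sample_weight` handling are omitted.

```python
import numpy as np


def precision_at_k(y_true, y_pred, k=None):
    """Fraction of the top-k scored items that are relevant (sketch only)."""
    # Indices of the top-k items by descending predicted score.
    order = np.argsort(-np.asarray(y_pred))[:k]
    # Binarize relevance: labels >= 1 count as relevant.
    relevant = np.asarray(y_true)[order] >= 1
    return relevant.mean()


print(precision_at_k([0, 0, 1], [0.1, 0.9, 0.2], k=2))  # 0.5
```

With `k=2`, the top-scored items are indices 1 and 2, and only one of them is relevant, hence 0.5.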
keras_rs/src/metrics/ranking_metric.py
@@ -0,0 +1,260 @@
+import abc
+from typing import Any
+
+import keras
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.metrics.utils import standardize_call_inputs_ranks
+from keras_rs.src.utils.keras_utils import check_rank
+from keras_rs.src.utils.keras_utils import check_shapes_compatible
+
+
+class RankingMetric(keras.metrics.Mean, abc.ABC):
+    """Base class for ranking evaluation metrics (e.g., MAP, MRR, DCG, nDCG).
+
+    Ranking metrics are used to evaluate the quality of ranked lists produced
+    by a ranking model. The primary goal in ranking tasks is to order items
+    according to their relevance or utility for a given query or context.
+    These metrics provide quantitative measures of how well a model achieves
+    this goal, typically by comparing the predicted order of items against the
+    ground truth relevance judgments for each list.
+
+    To define your own ranking metric, subclass this class and implement the
+    `compute_metric` method.
+
+    Args:
+        k: int. The number of top-ranked items to consider (the 'k' in 'top-k').
+            Must be a positive integer.
+        shuffle_ties: bool. Whether to randomly shuffle scores before sorting.
+            This is done to break ties. Defaults to `True`.
+        seed: int. Random seed used for shuffling.
+        name: Optional name for the metric instance.
+        dtype: The dtype of the metric's computations. Defaults to `None`, which
+            means using `keras.backend.floatx()`. `keras.backend.floatx()` is
+            `"float32"` unless set to a different value
+            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
+            provided, then the `compute_dtype` will be utilized.
+    """
+
+    def __init__(
+        self,
+        k: int | None = None,
+        shuffle_ties: bool = True,
+        seed: int | keras.random.SeedGenerator | None = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        if k is not None and (not isinstance(k, int) or k < 1):
+            raise ValueError(
+                f"`k` should be a positive integer. Received: `k` = {k}."
+            )
+
+        self.k = k
+        self.shuffle_ties = shuffle_ties
+        self.seed = seed
+
+        # Define a `SeedGenerator` upfront; JAX does not work otherwise.
+        self.seed_generator = keras.random.SeedGenerator(seed)
+
+    @abc.abstractmethod
+    def compute_metric(
+        self,
+        y_true: types.Tensor,
+        y_pred: types.Tensor,
+        mask: types.Tensor,
+        sample_weight: types.Tensor,
+    ) -> types.Tensor:
+        """Abstract method, should be implemented by subclasses."""
+        pass
+
+    def update_state(
+        self,
+        y_true: types.Tensor,
+        y_pred: types.Tensor,
+        sample_weight: types.Tensor | None = None,
+    ) -> None:
+        """
+        Accumulates statistics for the ranking metric.
+
+        Args:
+            y_true: tensor or dict. Ground truth values. If tensor, of shape
+                `(list_size)` for unbatched inputs or `(batch_size, list_size)`
+                for batched inputs. If an item has a label of -1, it is ignored
+                in metric computation. If it is a dictionary, it should have
+                two keys: `"labels"` and `"mask"`. `"mask"` can be used to
+                ignore elements in metric computation, i.e., those elements do
+                not contribute to the metric. Note that the final mask is an
+                `and` of the passed mask, `labels >= 0`, and `sample_weight > 0`.
+            y_pred: tensor. The predicted values, of shape `(list_size)` for
+                unbatched inputs or `(batch_size, list_size)` for batched
+                inputs. Should be of the same shape as `y_true`.
+            sample_weight: float/tensor. Can be a float value, or a tensor of
+                shape `(list_size)` or `(batch_size, list_size)`. Defaults to
+                `None`.
+        """
+        # === Process `y_true`, if dict ===
+        passed_mask = None
+        if isinstance(y_true, dict):
+            if "labels" not in y_true:
+                raise ValueError(
+                    '`"labels"` should be present in `y_true`. Received: '
+                    f"`y_true` = {y_true}"
+                )
+
+            passed_mask = y_true.get("mask", None)
+            y_true = y_true["labels"]
+
+        # === Convert to tensors, if list ===
+        # TODO (abheesht): Figure out if we need to cast tensors to
+        # `self.dtype`.
+        y_true = ops.convert_to_tensor(y_true)
+        y_pred = ops.convert_to_tensor(y_pred)
+        if sample_weight is not None:
+            sample_weight = ops.convert_to_tensor(sample_weight)
+        if passed_mask is not None:
+            passed_mask = ops.convert_to_tensor(passed_mask)
+
+        # Cast to the correct dtype.
+        y_true = ops.cast(y_true, dtype=self.dtype)
+        y_pred = ops.cast(y_pred, dtype=self.dtype)
+        if sample_weight is not None:
+            sample_weight = ops.cast(sample_weight, dtype=self.dtype)
+
+        # === Process `sample_weight` ===
+        if sample_weight is None:
+            sample_weight = ops.cast(1, dtype=y_pred.dtype)
+
+        y_true_shape = ops.shape(y_true)
+        y_true_rank = len(y_true_shape)
+        sample_weight_shape = ops.shape(sample_weight)
+        sample_weight_rank = len(sample_weight_shape)
+
+        # Check `y_true_rank` first. Can be 1 for unbatched inputs, 2 for
+        # batched.
+        check_rank(y_true_rank, allowed_ranks=(1, 2), tensor_name="y_true")
+
+        # Check `sample_weight` rank. Should be between 0 and `y_true_rank`.
+        check_rank(
+            sample_weight_rank,
+            allowed_ranks=tuple(range(y_true_rank + 1)),
+            tensor_name="sample_weight",
+        )
+
+        if y_true_rank == 2:
+            # If `sample_weight` rank is 1, it should be of shape
+            # `(batch_size,)`. Otherwise, it should be of shape
+            # `(batch_size, list_size)`.
+            if sample_weight_rank == 1:
+                check_shapes_compatible(sample_weight_shape, (y_true_shape[0],))
+                # Uprank this, so that we get per-list weights here.
+                sample_weight = ops.expand_dims(sample_weight, axis=1)
+            elif sample_weight_rank == 2:
+                check_shapes_compatible(sample_weight_shape, y_true_shape)
+
+        # Broadcast `sample_weight` to the shape of `y_true`.
+        sample_weight = ops.multiply(ops.ones_like(y_true), sample_weight)
+
+        # Mask all values less than 0 (since less than 0 implies invalid
+        # labels).
+        valid_mask = ops.greater_equal(y_true, ops.cast(0, y_true.dtype))
+        if passed_mask is not None:
+            valid_mask = ops.logical_and(valid_mask, passed_mask)
+
+        # === Process inputs - shape checking, upranking, etc. ===
+        y_true, y_pred, valid_mask, batched = standardize_call_inputs_ranks(
+            y_true=y_true,
+            y_pred=y_pred,
+            mask=valid_mask,
+            check_y_true_rank=False,
+        )
+
+        # Uprank `sample_weight` if unbatched.
+        if not batched:
+            sample_weight = ops.expand_dims(sample_weight, axis=0)
+
+        # Get `mask` from `sample_weight`.
+        sample_weight_mask = ops.greater(
+            sample_weight, ops.cast(0, dtype=sample_weight.dtype)
+        )
+        mask = ops.logical_and(valid_mask, sample_weight_mask)
+
+        # === Update "invalid" `y_true`, `y_pred` entries based on mask ===
+
+        # `y_true`: assign 0 for invalid entries.
+        y_true = ops.where(mask, y_true, ops.zeros_like(y_true))
+        # `y_pred`: assign a value slightly smaller than the smallest value
+        # so that it gets sorted last.
+        y_pred = ops.where(
+            mask,
+            y_pred,
+            -keras.config.epsilon() * ops.ones_like(y_pred)
+            + ops.amin(y_pred, axis=1, keepdims=True),
+        )
+        sample_weight = ops.where(
+            mask, sample_weight, ops.cast(0, sample_weight.dtype)
+        )
+
+        # === Actual computation ===
+        per_list_metric_values, per_list_metric_weights = self.compute_metric(
+            y_true=y_true, y_pred=y_pred, mask=mask, sample_weight=sample_weight
+        )
+
+        # Chain to `super()` to get the mean metric.
+        # TODO(abheesht): Figure out if we want to return unaggregated metric
+        # values too, of shape `(batch_size,)`, from `result()`.
+        super().update_state(
+            per_list_metric_values, sample_weight=per_list_metric_weights
+        )
+
+    def get_config(self) -> dict[str, Any]:
+        config: dict[str, Any] = super().get_config()
+        config.update(
+            {"k": self.k, "shuffle_ties": self.shuffle_ties, "seed": self.seed}
+        )
+        return config
+
+
+ranking_metric_subclass_doc_string = """
+Computes {metric_name} ({metric_abbreviation}).
+
+This metric evaluates ranking quality. {concept_sentence}. The metric processes
+true relevance labels in `y_true` ({relevance_type}) against predicted scores in
+`y_pred`. The scores in `y_pred` are used to determine the rank order of items,
+by sorting in descending order. {score_range_interpretation}.
+
+For each list of predicted scores `s` in `y_pred` and the corresponding list
+of true labels `y` in `y_true`, the per-query {metric_abbreviation} score is
+calculated as follows:
+{formula}
+
+The final {metric_abbreviation} score reported is typically the weighted
+average of these per-query scores across all queries/lists in the dataset.
+
+Note: `sample_weight` is handled differently for ranking metrics. For
+batched inputs, `sample_weight` can be scalar, 1D or 2D. The scalar and
+1D (list-wise weights) cases are straightforward. In the 2D (item-wise
+weights) case, the sample weights are aggregated to produce 1D per-list
+weights. For more details, refer to
+`keras_rs.src.metrics.ranking_metrics_utils.get_list_weights`.
+"""
+
+ranking_metric_subclass_doc_string_post_desc = """
+
+Args:{extra_args}
+    k: int. The number of top-ranked items to consider (the 'k' in 'top-k').
+        Must be a positive integer.
+    shuffle_ties: bool. Whether to randomly shuffle scores before sorting.
+        This is done to break ties. Defaults to `True`.
+    seed: int. Random seed used for shuffling.
+    name: Optional name for the metric instance.
+    dtype: The dtype of the metric's computations. Defaults to `None`, which
+        means using `keras.backend.floatx()`. `keras.backend.floatx()` is
+        `"float32"` unless set to a different value
+        (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
+        provided, then the `compute_dtype` will be utilized.
+
+Example:
+    {example}
+"""