keras-rs-nightly 0.0.1.dev2025042103__py3-none-any.whl → 0.0.1.dev2025042503__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their public registries.

Potentially problematic release.


This version of keras-rs-nightly might be problematic.

Files changed (31)
  1. keras_rs/__init__.py +9 -28
  2. keras_rs/{api/layers → layers}/__init__.py +9 -7
  3. keras_rs/losses/__init__.py +18 -0
  4. keras_rs/metrics/__init__.py +16 -0
  5. keras_rs/src/losses/pairwise_hinge_loss.py +1 -0
  6. keras_rs/src/losses/pairwise_logistic_loss.py +1 -0
  7. keras_rs/src/losses/pairwise_loss.py +36 -12
  8. keras_rs/src/losses/pairwise_loss_utils.py +39 -0
  9. keras_rs/src/losses/pairwise_mean_squared_error.py +2 -1
  10. keras_rs/src/losses/pairwise_soft_zero_one_loss.py +1 -0
  11. keras_rs/src/metrics/__init__.py +0 -0
  12. keras_rs/src/metrics/dcg.py +140 -0
  13. keras_rs/src/metrics/mean_average_precision.py +112 -0
  14. keras_rs/src/metrics/mean_reciprocal_rank.py +98 -0
  15. keras_rs/src/metrics/ndcg.py +184 -0
  16. keras_rs/src/metrics/precision_at_k.py +94 -0
  17. keras_rs/src/metrics/ranking_metric.py +252 -0
  18. keras_rs/src/metrics/ranking_metrics_utils.py +238 -0
  19. keras_rs/src/metrics/recall_at_k.py +85 -0
  20. keras_rs/src/metrics/utils.py +72 -0
  21. keras_rs/src/utils/doc_string_utils.py +48 -0
  22. keras_rs/src/utils/keras_utils.py +12 -0
  23. keras_rs/src/version.py +1 -1
  24. {keras_rs_nightly-0.0.1.dev2025042103.dist-info → keras_rs_nightly-0.0.1.dev2025042503.dist-info}/METADATA +4 -3
  25. keras_rs_nightly-0.0.1.dev2025042503.dist-info/RECORD +42 -0
  26. {keras_rs_nightly-0.0.1.dev2025042103.dist-info → keras_rs_nightly-0.0.1.dev2025042503.dist-info}/WHEEL +1 -1
  27. keras_rs/api/__init__.py +0 -10
  28. keras_rs/api/losses/__init__.py +0 -14
  29. keras_rs/src/utils/pairwise_loss_utils.py +0 -102
  30. keras_rs_nightly-0.0.1.dev2025042103.dist-info/RECORD +0 -31
  31. {keras_rs_nightly-0.0.1.dev2025042103.dist-info → keras_rs_nightly-0.0.1.dev2025042503.dist-info}/top_level.txt +0 -0

keras_rs/src/metrics/ndcg.py (new file)
@@ -0,0 +1,184 @@
+ from typing import Any, Callable, Optional
+
+ from keras import ops
+ from keras.saving import deserialize_keras_object
+ from keras.saving import serialize_keras_object
+
+ from keras_rs.src import types
+ from keras_rs.src.api_export import keras_rs_export
+ from keras_rs.src.metrics.ranking_metric import RankingMetric
+ from keras_rs.src.metrics.ranking_metric import (
+     ranking_metric_subclass_doc_string,
+ )
+ from keras_rs.src.metrics.ranking_metric import (
+     ranking_metric_subclass_doc_string_args,
+ )
+ from keras_rs.src.metrics.ranking_metrics_utils import compute_dcg
+ from keras_rs.src.metrics.ranking_metrics_utils import default_gain_fn
+ from keras_rs.src.metrics.ranking_metrics_utils import default_rank_discount_fn
+ from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
+ from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
+ from keras_rs.src.utils.doc_string_utils import format_docstring
+
+
+ @keras_rs_export("keras_rs.metrics.NDCG")
+ class NDCG(RankingMetric):
+     def __init__(
+         self,
+         k: Optional[int] = None,
+         gain_fn: Callable[[types.Tensor], types.Tensor] = default_gain_fn,
+         rank_discount_fn: Callable[
+             [types.Tensor], types.Tensor
+         ] = default_rank_discount_fn,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(k=k, **kwargs)
+
+         self.gain_fn = gain_fn
+         self.rank_discount_fn = rank_discount_fn
+
+     def compute_metric(
+         self,
+         y_true: types.Tensor,
+         y_pred: types.Tensor,
+         mask: types.Tensor,
+         sample_weight: types.Tensor,
+     ) -> types.Tensor:
+         sorted_y_true, sorted_weights = sort_by_scores(
+             tensors_to_sort=[y_true, sample_weight],
+             scores=y_pred,
+             k=self.k,
+             mask=mask,
+             shuffle_ties=self.shuffle_ties,
+             seed=self.seed_generator,
+         )
+
+         dcg = compute_dcg(
+             y_true=sorted_y_true,
+             sample_weight=sorted_weights,
+             gain_fn=self.gain_fn,
+             rank_discount_fn=self.rank_discount_fn,
+         )
+
+         weighted_gains = ops.multiply(
+             sample_weight,
+             self.gain_fn(y_true),
+         )
+         ideal_sorted_y_true, ideal_sorted_weights = sort_by_scores(
+             tensors_to_sort=[y_true, sample_weight],
+             scores=weighted_gains,
+             k=self.k,
+             mask=mask,
+             shuffle_ties=self.shuffle_ties,
+             seed=self.seed_generator,
+         )
+         ideal_dcg = compute_dcg(
+             y_true=ideal_sorted_y_true,
+             sample_weight=ideal_sorted_weights,
+             gain_fn=self.gain_fn,
+             rank_discount_fn=self.rank_discount_fn,
+         )
+         per_list_ndcg = ops.divide_no_nan(dcg, ideal_dcg)
+
+         per_list_weights = get_list_weights(
+             weights=sample_weight, relevance=self.gain_fn(y_true)
+         )
+
+         return per_list_ndcg, per_list_weights
+
+     def get_config(self) -> dict[str, Any]:
+         config: dict[str, Any] = super().get_config()
+         config.update(
+             {
+                 "gain_fn": serialize_keras_object(self.gain_fn),
+                 "rank_discount_fn": serialize_keras_object(
+                     self.rank_discount_fn
+                 ),
+             }
+         )
+         return config
+
+     @classmethod
+     def from_config(cls, config: dict[str, Any]) -> "NDCG":
+         config["gain_fn"] = deserialize_keras_object(config["gain_fn"])
+         config["rank_discount_fn"] = deserialize_keras_object(
+             config["rank_discount_fn"]
+         )
+         return cls(**config)
+
+
+ concept_sentence = (
+     "It normalizes the Discounted Cumulative Gain (DCG) with the Ideal "
+     "Discounted Cumulative Gain (IDCG) for each list."
+ )
+ relevance_type = (
+     "graded relevance scores (non-negative numbers where higher values "
+     "indicate greater relevance)"
+ )
+ score_range_interpretation = (
+     "A normalized score (between 0 and 1) is returned. A score of 1 "
+     "represents the perfect ranking according to true relevance (within the "
+     "top-k), while 0 typically represents a ranking with no relevant items. "
+     "Higher scores indicate better ranking quality relative to the best "
+     "possible ranking."
+ )
+
+ formula = """
+ The metric calculates a weighted average nDCG score per list.
+ For a single list, nDCG is computed as the ratio of the Discounted
+ Cumulative Gain (DCG) of the predicted ranking to the Ideal Discounted
+ Cumulative Gain (IDCG) of the best possible ranking:
+
+ ```
+ nDCG@k = DCG@k / IDCG@k
+ ```
+
+ where DCG@k is calculated based on the predicted ranking (`y_pred`):
+
+ ```
+ DCG@k(y') = sum_{i=1}^{k} (gain_fn(y'_i) / rank_discount_fn(i))
+ ```
+
+ And IDCG@k is the Ideal DCG, calculated using the same formula but on items
+ sorted perfectly by their *true relevance* (`y_true`):
+
+ ```
+ IDCG@k(y'') = sum_{i=1}^{k} (gain_fn(y''_i) / rank_discount_fn(i))
+ ```
+
+ where:
+ - `y'_i`: True relevance of the item at rank `i` in
+   the ranking induced by `y_pred`.
+ - `y''_i`: True relevance of the item at rank `i` in
+   the *ideal* ranking (sorted by `y_true` descending).
+ - `gain_fn` is the user-provided function mapping relevance to gain.
+   The default function (`default_gain_fn`) is typically equivalent to
+   `lambda y: 2**y - 1`.
+ - `rank_discount_fn` is the user-provided function mapping rank `i`
+   (1-based) to a discount value. The default function
+   (`default_rank_discount_fn`) is typically equivalent to
+   `lambda rank: 1 / log2(rank + 1)`.
+ - If IDCG@k is 0 (e.g., no relevant items), nDCG@k is defined as 0.
+ - The final result often aggregates these per-list nDCG scores,
+   potentially involving normalization by list-specific weights, to
+   produce a weighted average.
+ """
+ extra_args = """
+     gain_fn: callable. Maps relevance scores (`y_true`) to gain values. The
+         default implements `2**y - 1`. Used for both DCG and IDCG.
+     rank_discount_fn: callable. Maps rank positions (1-based) to discount
+         values. The default (`default_rank_discount_fn`) typically implements
+         `1 / log2(rank + 1)`. Used for both DCG and IDCG.
+ """
+
+ NDCG.__doc__ = format_docstring(
+     ranking_metric_subclass_doc_string,
+     width=80,
+     metric_name="Normalised Discounted Cumulative Gain",
+     metric_abbreviation="nDCG",
+     concept_sentence=concept_sentence,
+     relevance_type=relevance_type,
+     score_range_interpretation=score_range_interpretation,
+     formula=formula,
+     extra_args=extra_args,
+ ) + ranking_metric_subclass_doc_string_args.format(extra_args=extra_args)
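
The `formula` block above defines nDCG@k in terms of `gain_fn` and `rank_discount_fn`. A minimal usage sketch of the new metric is shown below; it assumes the public `keras_rs.metrics.NDCG` path declared by the `@keras_rs_export` decorator and the default gain/discount functions, and the exact value returned depends on the per-list weighting done in `get_list_weights`:

```
import keras
import keras_rs

# Two lists (batch_size=2, list_size=3) with graded relevance labels and
# predicted scores. Labels of -1 would be masked out by `update_state`.
y_true = keras.ops.convert_to_tensor([[3.0, 1.0, 0.0], [0.0, 2.0, 0.0]])
y_pred = keras.ops.convert_to_tensor([[0.9, 0.2, 0.3], [0.1, 0.8, 0.4]])

# nDCG over the top 2 ranked items, with the default `2**y - 1` gain and
# `1 / log2(rank + 1)` discount.
metric = keras_rs.metrics.NDCG(k=2)
metric.update_state(y_true, y_pred)
print(metric.result())  # Weighted mean of the per-list nDCG@2 scores.
```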

keras_rs/src/metrics/precision_at_k.py (new file)
@@ -0,0 +1,94 @@
+ from keras import ops
+
+ from keras_rs.src import types
+ from keras_rs.src.api_export import keras_rs_export
+ from keras_rs.src.metrics.ranking_metric import RankingMetric
+ from keras_rs.src.metrics.ranking_metric import (
+     ranking_metric_subclass_doc_string,
+ )
+ from keras_rs.src.metrics.ranking_metric import (
+     ranking_metric_subclass_doc_string_args,
+ )
+ from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
+ from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
+ from keras_rs.src.utils.doc_string_utils import format_docstring
+
+
+ @keras_rs_export("keras_rs.metrics.PrecisionAtK")
+ class PrecisionAtK(RankingMetric):
+     def compute_metric(
+         self,
+         y_true: types.Tensor,
+         y_pred: types.Tensor,
+         mask: types.Tensor,
+         sample_weight: types.Tensor,
+     ) -> types.Tensor:
+         # Assume: `y_true = [0, 0, 1]`, `y_pred = [0.1, 0.9, 0.2]`.
+         # `sorted_y_true = [0, 1, 0]` (sorted in descending order of scores).
+         (sorted_y_true,) = sort_by_scores(
+             tensors_to_sort=[y_true],
+             scores=y_pred,
+             mask=mask,
+             k=self.k,
+             shuffle_ties=self.shuffle_ties,
+             seed=self.seed_generator,
+         )
+
+         # We consider only binary relevance here; anything above 1 is treated
+         # as 1. `relevance = [0., 1., 0.]`.
+         relevance = ops.cast(
+             ops.greater_equal(
+                 sorted_y_true, ops.cast(1, dtype=sorted_y_true.dtype)
+             ),
+             dtype="float32",
+         )
+         list_length = ops.shape(sorted_y_true)[1]
+         # TODO: We do not do this for MRR and the other metrics. Do we need to
+         # do this there too?
+         valid_list_length = ops.minimum(
+             list_length,
+             ops.sum(ops.cast(mask, dtype="int32"), axis=1, keepdims=True),
+         )
+
+         per_list_precision = ops.divide_no_nan(
+             ops.sum(relevance, axis=1, keepdims=True),
+             ops.cast(valid_list_length, dtype="float32"),
+         )
+
+         # Get weights.
+         overall_relevance = ops.cast(
+             ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
+             dtype="float32",
+         )
+         per_list_weights = get_list_weights(
+             weights=sample_weight, relevance=overall_relevance
+         )
+
+         return per_list_precision, per_list_weights
+
+
+ concept_sentence = (
+     "It measures the proportion of relevant items among the top-k "
+     "recommendations"
+ )
+ relevance_type = "binary indicators (0 or 1) of relevance"
+ score_range_interpretation = (
+     "Scores range from 0 to 1, with 1 indicating all top-k items were relevant"
+ )
+ formula = """```
+ P@k(y, s) = 1/k sum_i I[rank(s_i) < k] y_i
+ ```
+
+ where `y_i` is the relevance label (0/1) of the item ranked at position
+ `i`, and `I[condition]` is 1 if the condition is met, otherwise 0."""
+ extra_args = ""
+ PrecisionAtK.__doc__ = format_docstring(
+     ranking_metric_subclass_doc_string,
+     width=80,
+     metric_name="Precision@k",
+     metric_abbreviation="P@k",
+     concept_sentence=concept_sentence,
+     relevance_type=relevance_type,
+     score_range_interpretation=score_range_interpretation,
+     formula=formula,
+ ) + ranking_metric_subclass_doc_string_args.format(extra_args=extra_args)
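
To make the `P@k` formula concrete, here is a small hand computation in the spirit of the toy example in the `compute_metric` comments above; it is an illustrative sketch of the formula itself, not a call into the library:

```
# Toy list from the inline comments above.
y_true = [0, 0, 1]        # binary relevance labels
y_pred = [0.1, 0.9, 0.2]  # predicted scores

k = 2
# Rank items by descending score: indices (1, 2, 0) -> labels (0, 1, 0).
ranked_labels = [label for _, label in sorted(zip(y_pred, y_true), reverse=True)]
# P@2 = (relevant items in the top 2) / 2 = 1 / 2.
print(sum(ranked_labels[:k]) / k)  # 0.5
```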

keras_rs/src/metrics/ranking_metric.py (new file)
@@ -0,0 +1,252 @@
+ import abc
+ from typing import Any, Optional, Union
+
+ import keras
+ from keras import ops
+
+ from keras_rs.src import types
+ from keras_rs.src.metrics.utils import standardize_call_inputs_ranks
+ from keras_rs.src.utils.keras_utils import check_rank
+ from keras_rs.src.utils.keras_utils import check_shapes_compatible
+
+
+ class RankingMetric(keras.metrics.Mean, abc.ABC):
+     """Base class for ranking evaluation metrics (e.g., MAP, MRR, DCG, nDCG).
+
+     Ranking metrics are used to evaluate the quality of ranked lists produced
+     by a ranking model. The primary goal in ranking tasks is to order items
+     according to their relevance or utility for a given query or context.
+     These metrics provide quantitative measures of how well a model achieves
+     this goal, typically by comparing the predicted order of items against
+     the ground truth relevance judgments for each list.
+
+     To define your own ranking metric, subclass this class and implement the
+     `compute_metric` method.
+
+     Args:
+         k: int. The number of top-ranked items to consider (the 'k' in
+             'top-k'). Must be a positive integer.
+         shuffle_ties: bool. Whether to randomly shuffle scores before
+             sorting. This is done to break ties. Defaults to `True`.
+         seed: int. Random seed used for shuffling.
+         name: Optional name for the metric instance.
+         dtype: The dtype of the metric's computations. Defaults to `None`,
+             which means using `keras.backend.floatx()`.
+             `keras.backend.floatx()` is `"float32"` unless set to a different
+             value (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy`
+             is provided, then the `compute_dtype` will be utilized.
+     """
+
+     def __init__(
+         self,
+         k: Optional[int] = None,
+         shuffle_ties: bool = True,
+         seed: Optional[Union[int, keras.random.SeedGenerator]] = None,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(**kwargs)
+
+         if k is not None and (not isinstance(k, int) or k < 1):
+             raise ValueError(
+                 f"`k` should be a positive integer. Received: `k` = {k}."
+             )
+
+         self.k = k
+         self.shuffle_ties = shuffle_ties
+         self.seed = seed
+
+         # Define `SeedGenerator`. JAX doesn't work otherwise.
+         self.seed_generator = keras.random.SeedGenerator(seed)
+
+     @abc.abstractmethod
+     def compute_metric(
+         self,
+         y_true: types.Tensor,
+         y_pred: types.Tensor,
+         mask: types.Tensor,
+         sample_weight: types.Tensor,
+     ) -> types.Tensor:
+         """Abstract method, should be implemented by subclasses."""
+         pass
+
+     def update_state(
+         self,
+         y_true: types.Tensor,
+         y_pred: types.Tensor,
+         sample_weight: Optional[types.Tensor] = None,
+     ) -> None:
+         """
+         Accumulates statistics for the ranking metric.
+
+         Args:
+             y_true: tensor or dict. Ground truth values. If tensor, of shape
+                 `(list_size)` for unbatched inputs or `(batch_size, list_size)`
+                 for batched inputs. If an item has a label of -1, it is ignored
+                 in metric computation. If it is a dictionary, it should have
+                 two keys: `"labels"` and `"mask"`. `"mask"` can be used to
+                 ignore elements in metric computation, i.e., pairs will not be
+                 formed with those items. Note that the final mask is an `and`
+                 of the passed mask, `labels >= 0`, and `sample_weight > 0`.
+             y_pred: tensor. The predicted values, of shape `(list_size)` for
+                 unbatched inputs or `(batch_size, list_size)` for batched
+                 inputs. Should be of the same shape as `y_true`.
+             sample_weight: float/tensor. Can be a float value, or a tensor of
+                 shape `(list_size)` or `(batch_size, list_size)`. Defaults to
+                 `None`.
+         """
+         # === Process `y_true`, if dict ===
+         passed_mask = None
+         if isinstance(y_true, dict):
+             if "labels" not in y_true:
+                 raise ValueError(
+                     '`"labels"` should be present in `y_true`. Received: '
+                     f"`y_true` = {y_true}"
+                 )
+
+             passed_mask = y_true.get("mask", None)
+             y_true = y_true["labels"]
+
+         # === Convert to tensors, if list ===
+         # TODO (abheesht): Figure out if we need to cast tensors to
+         # `self.dtype`.
+         y_true = ops.convert_to_tensor(y_true)
+         y_pred = ops.convert_to_tensor(y_pred)
+         if sample_weight is not None:
+             sample_weight = ops.convert_to_tensor(sample_weight)
+         if passed_mask is not None:
+             passed_mask = ops.convert_to_tensor(passed_mask)
+
+         # === Process `sample_weight` ===
+         if sample_weight is None:
+             sample_weight = ops.cast(1, dtype=y_pred.dtype)
+
+         y_true_shape = ops.shape(y_true)
+         y_true_rank = len(y_true_shape)
+         sample_weight_shape = ops.shape(sample_weight)
+         sample_weight_rank = len(sample_weight_shape)
+
+         # Check `y_true_rank` first. Can be 1 for unbatched inputs, 2 for
+         # batched.
+         check_rank(y_true_rank, allowed_ranks=(1, 2), tensor_name="y_true")
+
+         # Check `sample_weight` rank. Should be between 0 and `y_true_rank`.
+         check_rank(
+             sample_weight_rank,
+             allowed_ranks=tuple(range(y_true_rank + 1)),
+             tensor_name="sample_weight",
+         )
+
+         if y_true_rank == 2:
+             # If `sample_weight` rank is 1, it should be of shape
+             # `(batch_size,)`. Otherwise, it should be of shape
+             # `(batch_size, list_size)`.
+             if sample_weight_rank == 1:
+                 check_shapes_compatible(sample_weight_shape, (y_true_shape[0],))
+                 # Uprank this, so that we get per-list weights here.
+                 sample_weight = ops.expand_dims(sample_weight, axis=1)
+             elif sample_weight_rank == 2:
+                 check_shapes_compatible(sample_weight_shape, y_true_shape)
+
+         # Reshape `sample_weight` to the shape of `y_true`.
+         sample_weight = ops.multiply(ops.ones_like(y_true), sample_weight)
+
+         # Mask all values less than 0 (since less than 0 implies invalid
+         # labels).
+         valid_mask = ops.greater_equal(y_true, ops.cast(0.0, y_true.dtype))
+         if passed_mask is not None:
+             valid_mask = ops.logical_and(valid_mask, passed_mask)
+
+         # === Process inputs - shape checking, upranking, etc. ===
+         y_true, y_pred, valid_mask, batched = standardize_call_inputs_ranks(
+             y_true=y_true,
+             y_pred=y_pred,
+             mask=valid_mask,
+             check_y_true_rank=False,
+         )
+
+         # Uprank `sample_weight` if unbatched.
+         if not batched:
+             sample_weight = ops.expand_dims(sample_weight, axis=0)
+
+         # Get `mask` from `sample_weight`.
+         sample_weight_mask = ops.greater(
+             sample_weight, ops.cast(0, dtype=sample_weight.dtype)
+         )
+         mask = ops.logical_and(valid_mask, sample_weight_mask)
+
+         # === Update "invalid" `y_true`, `y_pred` entries based on mask ===
+
+         # `y_true`: assign 0 for invalid entries.
+         y_true = ops.where(mask, y_true, ops.zeros_like(y_true))
+         # `y_pred`: assign a value slightly smaller than the smallest value,
+         # so that invalid entries get sorted last.
+         y_pred = ops.where(
+             mask,
+             y_pred,
+             -keras.config.epsilon() * ops.ones_like(y_pred)
+             + ops.amin(y_pred, axis=1, keepdims=True),
+         )
+         sample_weight = ops.where(
+             mask, sample_weight, ops.cast(0, sample_weight.dtype)
+         )
+
+         # === Actual computation ===
+         per_list_metric_values, per_list_metric_weights = self.compute_metric(
+             y_true=y_true, y_pred=y_pred, mask=mask, sample_weight=sample_weight
+         )
+
+         # Chain to `super()` to get the mean metric.
+         # TODO(abheesht): Figure out if we want to return unaggregated metric
+         # values too, of shape `(batch_size,)`, from `result()`.
+         super().update_state(
+             per_list_metric_values, sample_weight=per_list_metric_weights
+         )
+
+     def get_config(self) -> dict[str, Any]:
+         config: dict[str, Any] = super().get_config()
+         config.update(
+             {"k": self.k, "shuffle_ties": self.shuffle_ties, "seed": self.seed}
+         )
+         return config
+
+
+ ranking_metric_subclass_doc_string = """
+ Computes {metric_name} ({metric_abbreviation}).
+
+ This metric evaluates ranking quality. {concept_sentence}. The metric processes
+ true relevance labels in `y_true` ({relevance_type}) against predicted scores in
+ `y_pred`. The scores in `y_pred` are used to determine the rank order of items,
+ by sorting in descending order. {score_range_interpretation}.
+
+ For each list of predicted scores `s` in `y_pred` and the corresponding list
+ of true labels `y` in `y_true`, the per-query {metric_abbreviation} score is
+ calculated as follows:
+
+ {formula}
+
+ The final {metric_abbreviation} score reported is typically the weighted
+ average of these per-query scores across all queries/lists in the dataset.
+
+ Note: `sample_weight` is handled differently for ranking metrics. For
+ batched inputs, `sample_weight` can be scalar, 1D or 2D. The scalar case and
+ the 1D case (list-wise weights) are straightforward. The 2D case (item-wise
+ weights) is different, in the sense that the sample weights are aggregated
+ to get 1D weights. For more details, refer to
+ `keras_rs.src.metrics.ranking_metrics_utils.get_list_weights`.
+ """
+
+ ranking_metric_subclass_doc_string_args = """
+
+ Args:{extra_args}
+     k: int. The number of top-ranked items to consider (the 'k' in 'top-k').
+         Must be a positive integer.
+     shuffle_ties: bool. Whether to randomly shuffle scores before sorting.
+         This is done to break ties. Defaults to `True`.
+     seed: int. Random seed used for shuffling.
+     name: Optional name for the metric instance.
+     dtype: The dtype of the metric's computations. Defaults to `None`, which
+         means using `keras.backend.floatx()`. `keras.backend.floatx()` is
+         `"float32"` unless set to a different value
+         (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
+         provided, then the `compute_dtype` will be utilized.
+ """