keras-rs-nightly 0.0.1.dev2025021903__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. keras_rs/__init__.py +9 -28
  2. keras_rs/layers/__init__.py +37 -0
  3. keras_rs/losses/__init__.py +19 -0
  4. keras_rs/metrics/__init__.py +16 -0
  5. keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
  6. keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
  7. keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
  8. keras_rs/src/layers/embedding/embed_reduce.py +309 -0
  9. keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  10. keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
  11. keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
  12. keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
  13. keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
  14. keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
  15. keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  16. keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
  17. keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
  18. keras_rs/src/layers/feature_interaction/__init__.py +0 -0
  19. keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
  20. keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
  21. keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
  22. keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
  23. keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
  24. keras_rs/src/layers/retrieval/retrieval.py +127 -0
  25. keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
  26. keras_rs/src/losses/__init__.py +0 -0
  27. keras_rs/src/losses/list_mle_loss.py +212 -0
  28. keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
  29. keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
  30. keras_rs/src/losses/pairwise_loss.py +165 -0
  31. keras_rs/src/losses/pairwise_loss_utils.py +39 -0
  32. keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
  33. keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
  34. keras_rs/src/metrics/__init__.py +0 -0
  35. keras_rs/src/metrics/dcg.py +161 -0
  36. keras_rs/src/metrics/mean_average_precision.py +130 -0
  37. keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
  38. keras_rs/src/metrics/ndcg.py +197 -0
  39. keras_rs/src/metrics/precision_at_k.py +117 -0
  40. keras_rs/src/metrics/ranking_metric.py +260 -0
  41. keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
  42. keras_rs/src/metrics/recall_at_k.py +108 -0
  43. keras_rs/src/metrics/utils.py +70 -0
  44. keras_rs/src/types.py +43 -14
  45. keras_rs/src/utils/doc_string_utils.py +53 -0
  46. keras_rs/src/utils/keras_utils.py +52 -3
  47. keras_rs/src/utils/tpu_test_utils.py +120 -0
  48. keras_rs/src/version.py +1 -1
  49. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
  50. keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
  51. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
  52. keras_rs/api/__init__.py +0 -9
  53. keras_rs/api/layers/__init__.py +0 -11
  54. keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
  55. /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
  56. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
keras_rs/src/losses/pairwise_mean_squared_error.py
@@ -0,0 +1,133 @@
+ from keras import ops
+
+ from keras_rs.src import types
+ from keras_rs.src.api_export import keras_rs_export
+ from keras_rs.src.losses.pairwise_loss import PairwiseLoss
+ from keras_rs.src.losses.pairwise_loss import pairwise_loss_subclass_doc_string
+ from keras_rs.src.losses.pairwise_loss_utils import apply_pairwise_op
+
+
+ @keras_rs_export("keras_rs.losses.PairwiseMeanSquaredError")
+ class PairwiseMeanSquaredError(PairwiseLoss):
+     def pairwise_loss(self, pairwise_logits: types.Tensor) -> types.Tensor:
+         # Since we override `compute_unreduced_loss`, we do not need to
+         # implement this method.
+         pass
+
+     def compute_unreduced_loss(
+         self,
+         labels: types.Tensor,
+         logits: types.Tensor,
+         mask: types.Tensor | None = None,
+     ) -> tuple[types.Tensor, types.Tensor]:
+         # Override `PairwiseLoss.compute_unreduced_loss` since pairwise
+         # weights for MSE are computed differently.
+
+         batch_size, list_size = ops.shape(labels)
+
+         # Mask all values less than 0 (since less than 0 implies invalid
+         # labels).
+         valid_mask = ops.greater_equal(labels, ops.cast(0.0, labels.dtype))
+
+         if mask is not None:
+             valid_mask = ops.logical_and(valid_mask, mask)
+
+         # Compute the difference for all pairs in a list. The output is a
+         # tensor with shape `(batch_size, list_size, list_size)`, where
+         # `[:, i, j]` stores information for pair `(i, j)`.
+         pairwise_labels_diff = apply_pairwise_op(labels, ops.subtract)
+         pairwise_logits_diff = apply_pairwise_op(logits, ops.subtract)
+         valid_pair = apply_pairwise_op(valid_mask, ops.logical_and)
+         pairwise_mse = ops.square(pairwise_labels_diff - pairwise_logits_diff)
+
+         # Compute weights.
+         pairwise_weights = ops.ones_like(pairwise_mse)
+         # Exclude self pairs.
+         pairwise_weights = ops.subtract(
+             pairwise_weights,
+             ops.tile(ops.eye(list_size, list_size), (batch_size, 1, 1)),
+         )
+         # Include only valid pairs.
+         pairwise_weights = ops.multiply(
+             pairwise_weights, ops.cast(valid_pair, dtype=pairwise_weights.dtype)
+         )
+
+         return pairwise_mse, pairwise_weights
+
+
+ formula = "loss = sum_{i} sum_{j} I(y_i > y_j) * (s_i - s_j)^2"
+ explanation = """
+     - `(s_i - s_j)^2` is the squared difference between the predicted scores
+       of items `i` and `j`, which penalizes discrepancies between the
+       predicted order of items relative to their true order.
+ """
+ extra_args = ""
+ example = """
+     With `compile()` API:
+
+     ```python
+     model.compile(
+         loss=keras_rs.losses.PairwiseMeanSquaredError(),
+         ...
+     )
+     ```
+
+     As a standalone function with unbatched inputs:
+
+     >>> y_true = np.array([1.0, 0.0, 1.0, 3.0, 2.0])
+     >>> y_pred = np.array([1.0, 3.0, 2.0, 4.0, 0.8])
+     >>> pairwise_mse = keras_rs.losses.PairwiseMeanSquaredError()
+     >>> pairwise_mse(y_true=y_true, y_pred=y_pred)
+     19.10400
+
+     With batched inputs using default 'auto'/'sum_over_batch_size' reduction:
+
+     >>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
+     >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
+     >>> pairwise_mse = keras_rs.losses.PairwiseMeanSquaredError()
+     >>> pairwise_mse(y_true=y_true, y_pred=y_pred)
+     5.57999
+
+     With masked inputs (useful for ragged inputs):
+
+     >>> y_true = {
+     ...     "labels": np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]]),
+     ...     "mask": np.array(
+     ...         [[True, True, True, True], [True, True, False, False]]
+     ...     ),
+     ... }
+     >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
+     >>> pairwise_mse(y_true=y_true, y_pred=y_pred)
+     4.76000
+
+     With `sample_weight`:
+
+     >>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
+     >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
+     >>> sample_weight = np.array(
+     ...     [[2.0, 3.0, 1.0, 1.0], [2.0, 1.0, 0.0, 0.0]]
+     ... )
+     >>> pairwise_mse = keras_rs.losses.PairwiseMeanSquaredError()
+     >>> pairwise_mse(
+     ...     y_true=y_true, y_pred=y_pred, sample_weight=sample_weight
+     ... )
+     11.0500
+
+     Using `'none'` reduction:
+
+     >>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
+     >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
+     >>> pairwise_mse = keras_rs.losses.PairwiseMeanSquaredError(
+     ...     reduction="none"
+     ... )
+     >>> pairwise_mse(y_true=y_true, y_pred=y_pred)
+     [[11., 17., 5., 5.], [2.04, 1.3199998, 1.6399999, 1.6399999]]
+ """
+
+ PairwiseMeanSquaredError.__doc__ = pairwise_loss_subclass_doc_string.format(
+     loss_name="mean squared error",
+     formula=formula,
+     explanation=explanation,
+     extra_args=extra_args,
+     example=example,
+ )
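Note: `apply_pairwise_op` lives in `keras_rs/src/losses/pairwise_loss_utils.py` (file 31 in the list above), which is not shown in this diff. A minimal sketch of its assumed behavior, inferred only from the shape comment in `compute_unreduced_loss`, is:

```python
from keras import ops


def apply_pairwise_op(x, op):
    # Sketch only (assumed behavior): broadcast `x` against itself so the
    # result has shape `(batch_size, list_size, list_size)` and entry
    # `[:, i, j]` holds `op(x[:, i], x[:, j])`.
    return op(
        ops.expand_dims(x, axis=-1),
        ops.expand_dims(x, axis=-2),
    )
```

Under that assumption, `apply_pairwise_op(labels, ops.subtract)[:, i, j]` is `labels_i - labels_j` for every pair in each list, which is exactly what the pairwise MSE above squares.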
keras_rs/src/losses/pairwise_soft_zero_one_loss.py
@@ -0,0 +1,98 @@
+ from keras import ops
+
+ from keras_rs.src import types
+ from keras_rs.src.api_export import keras_rs_export
+ from keras_rs.src.losses.pairwise_loss import PairwiseLoss
+ from keras_rs.src.losses.pairwise_loss import pairwise_loss_subclass_doc_string
+
+
+ @keras_rs_export("keras_rs.losses.PairwiseSoftZeroOneLoss")
+ class PairwiseSoftZeroOneLoss(PairwiseLoss):
+     def pairwise_loss(self, pairwise_logits: types.Tensor) -> types.Tensor:
+         return ops.where(
+             ops.greater(pairwise_logits, ops.array(0.0)),
+             ops.subtract(ops.array(1.0), ops.sigmoid(pairwise_logits)),
+             ops.sigmoid(ops.negative(pairwise_logits)),
+         )
+
+
+ formula = "loss = sum_{i} sum_{j} I(y_i > y_j) * (1 - sigmoid(s_i - s_j))"
+ explanation = """
+     - `(1 - sigmoid(s_i - s_j))` represents the soft zero-one loss, which
+       approximates the ideal zero-one loss (which would be 1 if `s_i < s_j`
+       and 0 otherwise) with a smooth, differentiable function. This makes it
+       suitable for gradient-based optimization.
+ """
+ extra_args = ""
+ example = """
+     With `compile()` API:
+
+     ```python
+     model.compile(
+         loss=keras_rs.losses.PairwiseSoftZeroOneLoss(),
+         ...
+     )
+     ```
+
+     As a standalone function with unbatched inputs:
+
+     >>> y_true = np.array([1.0, 0.0, 1.0, 3.0, 2.0])
+     >>> y_pred = np.array([1.0, 3.0, 2.0, 4.0, 0.8])
+     >>> pairwise_soft_zero_one_loss = keras_rs.losses.PairwiseSoftZeroOneLoss()
+     >>> pairwise_soft_zero_one_loss(y_true=y_true, y_pred=y_pred)
+     0.86103
+
+     With batched inputs using default 'auto'/'sum_over_batch_size' reduction:
+
+     >>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
+     >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
+     >>> pairwise_soft_zero_one_loss = keras_rs.losses.PairwiseSoftZeroOneLoss()
+     >>> pairwise_soft_zero_one_loss(y_true=y_true, y_pred=y_pred)
+     0.46202
+
+     With masked inputs (useful for ragged inputs):
+
+     >>> y_true = {
+     ...     "labels": np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]]),
+     ...     "mask": np.array(
+     ...         [[True, True, True, True], [True, True, False, False]]
+     ...     ),
+     ... }
+     >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
+     >>> pairwise_soft_zero_one_loss(y_true=y_true, y_pred=y_pred)
+     0.29468
+
+     With `sample_weight`:
+
+     >>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
+     >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
+     >>> sample_weight = np.array(
+     ...     [[2.0, 3.0, 1.0, 1.0], [2.0, 1.0, 0.0, 0.0]]
+     ... )
+     >>> pairwise_soft_zero_one_loss = keras_rs.losses.PairwiseSoftZeroOneLoss()
+     >>> pairwise_soft_zero_one_loss(
+     ...     y_true=y_true, y_pred=y_pred, sample_weight=sample_weight
+     ... )
+     0.40478
+
+     Using `'none'` reduction:
+
+     >>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
+     >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
+     >>> pairwise_soft_zero_one_loss = keras_rs.losses.PairwiseSoftZeroOneLoss(
+     ...     reduction="none"
+     ... )
+     >>> pairwise_soft_zero_one_loss(y_true=y_true, y_pred=y_pred)
+     [
+         [0.8807971 , 0., 0.73105854, 0.43557024],
+         [0., 0.31002545, 0.7191075 , 0.61961967]
+     ]
+ """
+
+ PairwiseSoftZeroOneLoss.__doc__ = pairwise_loss_subclass_doc_string.format(
+     loss_name="soft zero-one loss",
+     formula=formula,
+     explanation=explanation,
+     extra_args=extra_args,
+     example=example,
+ )
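Both branches of the `ops.where` above are algebraically equal to `1 - sigmoid(s_i - s_j)`; splitting on the sign of the pairwise logit is presumably done for numerical stability. A small NumPy sketch of the element-wise computation (illustrative only, not the library's code):

```python
import numpy as np


def soft_zero_one(x):
    # Mathematically both branches equal 1 - sigmoid(x); the sign-based
    # split mirrors the `ops.where` in the layer above.
    return np.where(
        x > 0.0,
        1.0 - 1.0 / (1.0 + np.exp(-x)),  # 1 - sigmoid(x) for positive x
        1.0 / (1.0 + np.exp(x)),         # sigmoid(-x) for non-positive x
    )


print(soft_zero_one(np.array([-2.0, 0.0, 2.0])))  # ~[0.881, 0.5, 0.119]
```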
keras_rs/src/metrics/__init__.py: File without changes
keras_rs/src/metrics/dcg.py
@@ -0,0 +1,161 @@
+ from typing import Any, Callable
+
+ from keras import ops
+ from keras.saving import deserialize_keras_object
+ from keras.saving import serialize_keras_object
+
+ from keras_rs.src import types
+ from keras_rs.src.api_export import keras_rs_export
+ from keras_rs.src.metrics.ranking_metric import RankingMetric
+ from keras_rs.src.metrics.ranking_metric import (
+     ranking_metric_subclass_doc_string,
+ )
+ from keras_rs.src.metrics.ranking_metric import (
+     ranking_metric_subclass_doc_string_post_desc,
+ )
+ from keras_rs.src.metrics.ranking_metrics_utils import compute_dcg
+ from keras_rs.src.metrics.ranking_metrics_utils import default_gain_fn
+ from keras_rs.src.metrics.ranking_metrics_utils import default_rank_discount_fn
+ from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
+ from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
+ from keras_rs.src.utils.doc_string_utils import format_docstring
+
+
+ @keras_rs_export("keras_rs.metrics.DCG")
+ class DCG(RankingMetric):
+     def __init__(
+         self,
+         k: int | None = None,
+         gain_fn: Callable[[types.Tensor], types.Tensor] = default_gain_fn,
+         rank_discount_fn: Callable[
+             [types.Tensor], types.Tensor
+         ] = default_rank_discount_fn,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(k=k, **kwargs)
+
+         self.gain_fn = gain_fn
+         self.rank_discount_fn = rank_discount_fn
+
+     def compute_metric(
+         self,
+         y_true: types.Tensor,
+         y_pred: types.Tensor,
+         mask: types.Tensor,
+         sample_weight: types.Tensor,
+     ) -> types.Tensor:
+         sorted_y_true, sorted_weights = sort_by_scores(
+             tensors_to_sort=[y_true, sample_weight],
+             scores=y_pred,
+             k=self.k,
+             mask=mask,
+             shuffle_ties=self.shuffle_ties,
+             seed=self.seed_generator,
+         )
+
+         dcg = compute_dcg(
+             y_true=sorted_y_true,
+             sample_weight=sorted_weights,
+             gain_fn=self.gain_fn,
+             rank_discount_fn=self.rank_discount_fn,
+         )
+
+         per_list_weights = get_list_weights(
+             weights=sample_weight, relevance=self.gain_fn(y_true)
+         )
+         # Since we have already multiplied with `sample_weight`, we need to
+         # divide by `per_list_weights` so as to nullify the multiplication
+         # which `keras.metrics.Mean` will do.
+         per_list_dcg = ops.divide_no_nan(dcg, per_list_weights)
+
+         return per_list_dcg, per_list_weights
+
+     def get_config(self) -> dict[str, Any]:
+         config: dict[str, Any] = super().get_config()
+         config.update(
+             {
+                 "gain_fn": serialize_keras_object(self.gain_fn),
+                 "rank_discount_fn": serialize_keras_object(
+                     self.rank_discount_fn
+                 ),
+             }
+         )
+         return config
+
+     @classmethod
+     def from_config(cls, config: dict[str, Any]) -> "DCG":
+         config["gain_fn"] = deserialize_keras_object(config["gain_fn"])
+         config["rank_discount_fn"] = deserialize_keras_object(
+             config["rank_discount_fn"]
+         )
+         return cls(**config)
+
+
+ concept_sentence = (
+     "It computes the sum of the graded relevance scores of items, applying a "
+     "configurable discount based on position"
+ )
+ relevance_type = (
+     "graded relevance scores (non-negative numbers where higher values "
+     "indicate greater relevance)"
+ )
+ score_range_interpretation = (
+     "Scores are non-negative, with higher values indicating better ranking "
+     "quality (highly relevant items are ranked higher). The score for a "
+     "single list is not bounded or normalized, i.e., it does not lie in a "
+     "fixed range"
+ )
+
+ formula = """
+     ```
+     DCG@k(y', w') = sum_{i=1}^{k} (gain_fn(y'_i) / rank_discount_fn(i))
+     ```
+
+     where:
+     - `y'_i` is the true relevance score of the item ranked at position `i`
+       (obtained by sorting `y_true` according to `y_pred`).
+     - `gain_fn` is the user-provided function mapping relevance `y'_i` to a
+       gain value. The default function (`default_gain_fn`) is typically
+       equivalent to `lambda y: 2**y - 1`.
+     - `rank_discount_fn` is the user-provided function mapping rank `i` to a
+       discount value. The default function (`default_rank_discount_fn`) is
+       typically equivalent to `lambda rank: 1 / log2(rank + 1)`.
+     - The final result aggregates these per-list scores."""
+ extra_args = """
+     gain_fn: callable. Maps relevance scores (`y_true`) to gain values. The
+         default implements `2**y - 1`.
+     rank_discount_fn: callable. Maps rank positions to discount values. The
+         default (`default_rank_discount_fn`) implements
+         `1 / log2(rank + 1)`."""
+ example = """
+     >>> batch_size = 2
+     >>> list_size = 5
+     >>> labels = np.random.randint(0, 3, size=(batch_size, list_size))
+     >>> scores = np.random.random(size=(batch_size, list_size))
+     >>> metric = keras_rs.metrics.DCG()(
+     ...     y_true=labels, y_pred=scores
+     ... )
+
+     Mask certain elements (can be used for uneven inputs):
+
+     >>> batch_size = 2
+     >>> list_size = 5
+     >>> labels = np.random.randint(0, 3, size=(batch_size, list_size))
+     >>> scores = np.random.random(size=(batch_size, list_size))
+     >>> mask = np.random.randint(0, 2, size=(batch_size, list_size), dtype=bool)
+     >>> metric = keras_rs.metrics.DCG()(
+     ...     y_true={"labels": labels, "mask": mask}, y_pred=scores
+     ... )
+ """
+
+ DCG.__doc__ = format_docstring(
+     ranking_metric_subclass_doc_string,
+     width=80,
+     metric_name="Discounted Cumulative Gain",
+     metric_abbreviation="DCG",
+     concept_sentence=concept_sentence,
+     relevance_type=relevance_type,
+     score_range_interpretation=score_range_interpretation,
+     formula=formula,
+ ) + ranking_metric_subclass_doc_string_post_desc.format(
+     extra_args=extra_args, example=example
+ )
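To make the DCG formula concrete, here is a hand computation for a single, already-sorted list, assuming the documented defaults (`gain = 2**y - 1`, `discount = 1 / log2(rank + 1)`) and unit sample weights. This is an illustration, not the library's `compute_dcg`:

```python
import numpy as np

# Relevance labels of the items in the order induced by the predicted scores.
relevance = np.array([3.0, 1.0, 0.0])
ranks = np.arange(1, relevance.size + 1)

gains = 2.0 ** relevance - 1.0          # [7., 1., 0.]
discounts = 1.0 / np.log2(ranks + 1.0)  # [1., 0.631, 0.5]
print(np.sum(gains * discounts))        # 7.0 + 0.631 + 0.0 = ~7.631
```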
keras_rs/src/metrics/mean_average_precision.py
@@ -0,0 +1,130 @@
+ from keras import ops
+
+ from keras_rs.src import types
+ from keras_rs.src.api_export import keras_rs_export
+ from keras_rs.src.metrics.ranking_metric import RankingMetric
+ from keras_rs.src.metrics.ranking_metric import (
+     ranking_metric_subclass_doc_string,
+ )
+ from keras_rs.src.metrics.ranking_metric import (
+     ranking_metric_subclass_doc_string_post_desc,
+ )
+ from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
+ from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
+ from keras_rs.src.utils.doc_string_utils import format_docstring
+
+
+ @keras_rs_export("keras_rs.metrics.MeanAveragePrecision")
+ class MeanAveragePrecision(RankingMetric):
+     def compute_metric(
+         self,
+         y_true: types.Tensor,
+         y_pred: types.Tensor,
+         mask: types.Tensor,
+         sample_weight: types.Tensor,
+     ) -> types.Tensor:
+         relevance = ops.cast(
+             ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
+             dtype=y_pred.dtype,
+         )
+         sorted_relevance, sorted_weights = sort_by_scores(
+             tensors_to_sort=[relevance, sample_weight],
+             scores=y_pred,
+             mask=mask,
+             k=self.k,
+             shuffle_ties=self.shuffle_ties,
+             seed=self.seed_generator,
+         )
+         per_list_relevant_counts = ops.cumsum(sorted_relevance, axis=1)
+         per_list_cutoffs = ops.cumsum(ops.ones_like(sorted_relevance), axis=1)
+         per_list_precisions = ops.divide_no_nan(
+             per_list_relevant_counts, per_list_cutoffs
+         )
+
+         total_precision = ops.sum(
+             ops.multiply(
+                 per_list_precisions,
+                 ops.multiply(sorted_weights, sorted_relevance),
+             ),
+             axis=1,
+             keepdims=True,
+         )
+
+         # Compute the total relevance.
+         total_relevance = ops.sum(
+             ops.multiply(sample_weight, relevance), axis=1, keepdims=True
+         )
+
+         per_list_map = ops.divide_no_nan(total_precision, total_relevance)
+
+         per_list_weights = get_list_weights(sample_weight, relevance)
+
+         return per_list_map, per_list_weights
+
+
+ concept_sentence = (
+     "It calculates the average of precision values computed after each "
+     "relevant item present in the ranked list"
+ )
+ relevance_type = "binary indicators (0 or 1) of relevance"
+ score_range_interpretation = (
+     "Scores range from 0 to 1, with higher values indicating that relevant "
+     "items are generally positioned higher in the ranking"
+ )
+
+ formula = """
+     The formula for average precision is defined below. MAP is the mean of
+     the average precision computed for each list.
+
+     ```
+     AP(y, s) = sum_j (P@j(y, s) * rel(j)) / sum_i y_i
+     rel(j) = y_i if rank(s_i) = j
+     ```
+
+     where:
+     - `j` represents the rank position (starting from 1).
+     - `sum_j` indicates a summation over all ranks `j` from 1 up to the list
+       size (or `k`).
+     - `P@j(y, s)` denotes the Precision at rank `j`, calculated as the
+       number of relevant items found within the top `j` positions divided
+       by `j`.
+     - `rel(j)` represents the relevance of the item specifically at rank
+       `j`. `rel(j)` is 1 if the item at rank `j` is relevant, and 0
+       otherwise.
+     - `y_i` is the true relevance label of the original item `i` before
+       ranking.
+     - `rank(s_i)` is the rank position assigned to item `i` based on its
+       score `s_i`.
+     - `sum_i y_i` calculates the total number of relevant items in the
+       original list `y`."""
+ extra_args = ""
+ example = """
+     >>> batch_size = 2
+     >>> list_size = 5
+     >>> labels = np.random.randint(0, 2, size=(batch_size, list_size))
+     >>> scores = np.random.random(size=(batch_size, list_size))
+     >>> metric = keras_rs.metrics.MeanAveragePrecision()(
+     ...     y_true=labels, y_pred=scores
+     ... )
+
+     Mask certain elements (can be used for uneven inputs):
+
+     >>> batch_size = 2
+     >>> list_size = 5
+     >>> labels = np.random.randint(0, 2, size=(batch_size, list_size))
+     >>> scores = np.random.random(size=(batch_size, list_size))
+     >>> mask = np.random.randint(0, 2, size=(batch_size, list_size), dtype=bool)
+     >>> metric = keras_rs.metrics.MeanAveragePrecision()(
+     ...     y_true={"labels": labels, "mask": mask}, y_pred=scores
+     ... )
+ """
+
+ MeanAveragePrecision.__doc__ = format_docstring(
+     ranking_metric_subclass_doc_string,
+     width=80,
+     metric_name="Mean Average Precision",
+     metric_abbreviation="MAP",
+     concept_sentence=concept_sentence,
+     relevance_type=relevance_type,
+     score_range_interpretation=score_range_interpretation,
+     formula=formula,
+ ) + ranking_metric_subclass_doc_string_post_desc.format(
+     extra_args=extra_args, example=example
+ )
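A worked instance of the AP formula above for one list with binary relevance and unit weights, mirroring the cumulative-sum structure of `compute_metric` (illustrative NumPy, not the layer's backend code):

```python
import numpy as np

# Relevance of the items in the order induced by the predicted scores.
sorted_relevance = np.array([1.0, 0.0, 1.0])
cutoffs = np.arange(1, sorted_relevance.size + 1)

precision_at_j = np.cumsum(sorted_relevance) / cutoffs  # [1., 0.5, 0.667]
ap = np.sum(precision_at_j * sorted_relevance) / np.sum(sorted_relevance)
print(ap)  # (1/1 + 2/3) / 2 = 0.8333...
```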
keras_rs/src/metrics/mean_reciprocal_rank.py
@@ -0,0 +1,121 @@
+ from keras import ops
+
+ from keras_rs.src import types
+ from keras_rs.src.api_export import keras_rs_export
+ from keras_rs.src.metrics.ranking_metric import RankingMetric
+ from keras_rs.src.metrics.ranking_metric import (
+     ranking_metric_subclass_doc_string,
+ )
+ from keras_rs.src.metrics.ranking_metric import (
+     ranking_metric_subclass_doc_string_post_desc,
+ )
+ from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
+ from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
+ from keras_rs.src.utils.doc_string_utils import format_docstring
+
+
+ @keras_rs_export("keras_rs.metrics.MeanReciprocalRank")
+ class MeanReciprocalRank(RankingMetric):
+     def compute_metric(
+         self,
+         y_true: types.Tensor,
+         y_pred: types.Tensor,
+         mask: types.Tensor,
+         sample_weight: types.Tensor,
+     ) -> types.Tensor:
+         # Assume: `y_true = [0, 0, 1]`, `y_pred = [0.1, 0.9, 0.2]`.
+         # `sorted_y_true = [0, 1, 0]` (sorted in descending order).
+         (sorted_y_true,) = sort_by_scores(
+             tensors_to_sort=[y_true],
+             scores=y_pred,
+             mask=mask,
+             k=self.k,
+             shuffle_ties=self.shuffle_ties,
+             seed=self.seed_generator,
+         )
+
+         # This will depend on `k`, i.e., it will not always be the same as
+         # `len(y_true)`.
+         list_length = ops.shape(sorted_y_true)[1]
+
+         # We consider only binary relevance here; anything above 1 is
+         # treated as 1. `relevance = [0., 1., 0.]`.
+         relevance = ops.cast(
+             ops.greater_equal(
+                 sorted_y_true, ops.cast(1, dtype=sorted_y_true.dtype)
+             ),
+             dtype=y_pred.dtype,
+         )
+
+         # `reciprocal_rank = [1, 0.5, 0.33]`
+         reciprocal_rank = ops.divide(
+             ops.cast(1, dtype=y_pred.dtype),
+             ops.arange(1, list_length + 1, dtype=y_pred.dtype),
+         )
+
+         # `mrr` should be of shape `(batch_size, 1)`.
+         # `mrr = amax([0., 0.5, 0.]) = 0.5`
+         mrr = ops.amax(
+             ops.multiply(relevance, reciprocal_rank),
+             axis=1,
+             keepdims=True,
+         )
+
+         # Get weights.
+         overall_relevance = ops.cast(
+             ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
+             dtype=y_pred.dtype,
+         )
+         per_list_weights = get_list_weights(
+             weights=sample_weight, relevance=overall_relevance
+         )
+
+         return mrr, per_list_weights
+
+
+ concept_sentence = (
+     "It focuses on the rank position of the single highest-scoring relevant "
+     "item"
+ )
+ relevance_type = "binary indicators (0 or 1) of relevance"
+ score_range_interpretation = (
+     "Scores range from 0 to 1, with 1 indicating the first relevant item was "
+     "always ranked first"
+ )
+ formula = """```
+ MRR(y, s) = max_{i} y_{i} / rank(s_{i})
+ ```"""
+ extra_args = ""
+ example = """
+     >>> batch_size = 2
+     >>> list_size = 5
+     >>> labels = np.random.randint(0, 2, size=(batch_size, list_size))
+     >>> scores = np.random.random(size=(batch_size, list_size))
+     >>> metric = keras_rs.metrics.MeanReciprocalRank()(
+     ...     y_true=labels, y_pred=scores
+     ... )
+
+     Mask certain elements (can be used for uneven inputs):
+
+     >>> batch_size = 2
+     >>> list_size = 5
+     >>> labels = np.random.randint(0, 2, size=(batch_size, list_size))
+     >>> scores = np.random.random(size=(batch_size, list_size))
+     >>> mask = np.random.randint(0, 2, size=(batch_size, list_size), dtype=bool)
+     >>> metric = keras_rs.metrics.MeanReciprocalRank()(
+     ...     y_true={"labels": labels, "mask": mask}, y_pred=scores
+     ... )
+ """
+
+ MeanReciprocalRank.__doc__ = format_docstring(
+     ranking_metric_subclass_doc_string,
+     width=80,
+     metric_name="Mean Reciprocal Rank",
+     metric_abbreviation="MRR",
+     concept_sentence=concept_sentence,
+     relevance_type=relevance_type,
+     score_range_interpretation=score_range_interpretation,
+     formula=formula,
+ ) + ranking_metric_subclass_doc_string_post_desc.format(
+     extra_args=extra_args, example=example
+ )
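The inline comments in `compute_metric` above walk through one list; the same arithmetic in plain NumPy, for illustration only:

```python
import numpy as np

y_true = np.array([0.0, 0.0, 1.0])
y_pred = np.array([0.1, 0.9, 0.2])

order = np.argsort(-y_pred)                             # sort by score, descending
sorted_y_true = y_true[order]                           # [0., 1., 0.]
reciprocal_rank = 1.0 / np.arange(1, y_true.size + 1)   # [1., 0.5, 0.333]
print(np.max(sorted_y_true * reciprocal_rank))          # 0.5: relevant item at rank 2
```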