keras-rs-nightly 0.0.1.dev2025042103__py3-none-any.whl → 0.0.1.dev2025042203__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- keras_rs/api/__init__.py +1 -0
- keras_rs/api/metrics/__init__.py +12 -0
- keras_rs/src/losses/pairwise_hinge_loss.py +1 -0
- keras_rs/src/losses/pairwise_logistic_loss.py +1 -0
- keras_rs/src/losses/pairwise_loss.py +36 -12
- keras_rs/src/losses/pairwise_loss_utils.py +39 -0
- keras_rs/src/losses/pairwise_mean_squared_error.py +2 -1
- keras_rs/src/losses/pairwise_soft_zero_one_loss.py +1 -0
- keras_rs/src/metrics/__init__.py +0 -0
- keras_rs/src/metrics/dcg.py +140 -0
- keras_rs/src/metrics/mean_average_precision.py +112 -0
- keras_rs/src/metrics/mean_reciprocal_rank.py +98 -0
- keras_rs/src/metrics/ndcg.py +184 -0
- keras_rs/src/metrics/precision_at_k.py +94 -0
- keras_rs/src/metrics/ranking_metric.py +252 -0
- keras_rs/src/metrics/ranking_metrics_utils.py +238 -0
- keras_rs/src/metrics/recall_at_k.py +85 -0
- keras_rs/src/metrics/utils.py +72 -0
- keras_rs/src/utils/doc_string_utils.py +48 -0
- keras_rs/src/utils/keras_utils.py +12 -0
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.0.1.dev2025042103.dist-info → keras_rs_nightly-0.0.1.dev2025042203.dist-info}/METADATA +2 -2
- keras_rs_nightly-0.0.1.dev2025042203.dist-info/RECORD +43 -0
- keras_rs/src/utils/pairwise_loss_utils.py +0 -102
- keras_rs_nightly-0.0.1.dev2025042103.dist-info/RECORD +0 -31
- {keras_rs_nightly-0.0.1.dev2025042103.dist-info → keras_rs_nightly-0.0.1.dev2025042203.dist-info}/WHEEL +0 -0
- {keras_rs_nightly-0.0.1.dev2025042103.dist-info → keras_rs_nightly-0.0.1.dev2025042203.dist-info}/top_level.txt +0 -0
keras_rs/api/metrics/__init__.py
CHANGED

```diff
@@ -0,0 +1,12 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from keras_rs.src.metrics.dcg import DCG
+from keras_rs.src.metrics.mean_average_precision import MeanAveragePrecision
+from keras_rs.src.metrics.mean_reciprocal_rank import MeanReciprocalRank
+from keras_rs.src.metrics.ndcg import NDCG
+from keras_rs.src.metrics.precision_at_k import PrecisionAtK
+from keras_rs.src.metrics.recall_at_k import RecallAtK
```
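These re-exports make the new ranking metrics importable from the public namespace. Based on the `keras_rs_export` decorators in the files added below, usage should look roughly like this (a sketch, assuming the nightly build is installed):

```python
import keras_rs

# `DCG` is registered under `keras_rs.metrics` via
# `@keras_rs_export("keras_rs.metrics.DCG")` (see the added file below).
metric = keras_rs.metrics.DCG(k=10)
```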
keras_rs/src/losses/pairwise_loss.py
CHANGED

```diff
@@ -1,12 +1,12 @@
 import abc
-from typing import Optional
+from typing import Any, Optional
 
 import keras
 from keras import ops
 
 from keras_rs.src import types
-from keras_rs.src.
-from keras_rs.src.utils
+from keras_rs.src.losses.pairwise_loss_utils import pairwise_comparison
+from keras_rs.src.metrics.utils import standardize_call_inputs_ranks
 
 
 class PairwiseLoss(keras.losses.Loss, abc.ABC):
@@ -22,14 +22,22 @@ class PairwiseLoss(keras.losses.Loss, abc.ABC):
     `pairwise_loss` method.
     """
 
-
+    def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+
+        if temperature <= 0.0:
+            raise ValueError(
+                f"`temperature` should be a positive float. Received: "
+                f"`temperature` = {temperature}."
+            )
+
+        self.temperature = temperature
+
+    # TODO(abheesht): Add `lambda_weights`.
 
     @abc.abstractmethod
     def pairwise_loss(self, pairwise_logits: types.Tensor) -> types.Tensor:
-        raise NotImplementedError(
-            "All subclasses of `keras_rs.losses.pairwise_loss.PairwiseLoss`"
-            "must implement the `pairwise_loss()` method."
-        )
+        pass
 
     def compute_unreduced_loss(
         self,
@@ -50,6 +58,10 @@ class PairwiseLoss(keras.losses.Loss, abc.ABC):
             mask=valid_mask,
             logits_op=ops.subtract,
         )
+        pairwise_logits = ops.divide(
+            pairwise_logits,
+            ops.cast(self.temperature, dtype=pairwise_logits.dtype),
+        )
 
         return self.pairwise_loss(pairwise_logits), pairwise_labels
 
@@ -66,8 +78,8 @@ class PairwiseLoss(keras.losses.Loss, abc.ABC):
                 in loss computation. If it is a dictionary, it should have two
                 keys: `"labels"` and `"mask"`. `"mask"` can be used to ignore
                 elements in loss computation, i.e., pairs will not be formed
-                with those items. Note that the final mask is an and of the
-                passed mask, and `labels
+                with those items. Note that the final mask is an `and` of the
+                passed mask, and `labels >= 0`.
             y_pred: tensor. The predicted values, of shape `(list_size)` for
                 unbatched inputs or `(batch_size, list_size)` for batched
                 inputs. Should be of the same shape as `y_true`.
@@ -83,7 +95,14 @@ class PairwiseLoss(keras.losses.Loss, abc.ABC):
             mask = y_true.get("mask", None)
            y_true = y_true["labels"]
 
-        y_true
+        y_true = ops.convert_to_tensor(y_true)
+        y_pred = ops.convert_to_tensor(y_pred)
+        if mask is not None:
+            mask = ops.convert_to_tensor(mask)
+
+        y_true, y_pred, mask, _ = standardize_call_inputs_ranks(
+            y_true, y_pred, mask
+        )
 
         losses, weights = self.compute_unreduced_loss(
             labels=y_true, logits=y_pred, mask=mask
@@ -92,9 +111,14 @@ class PairwiseLoss(keras.losses.Loss, abc.ABC):
         losses = ops.sum(losses, axis=-1)
         return losses
 
+    def get_config(self) -> dict[str, Any]:
+        config: dict[str, Any] = super().get_config()
+        config.update({"temperature": self.temperature})
+        return config
+
 
 pairwise_loss_subclass_doc_string = (
-    "Computes pairwise
+    "Computes pairwise {loss_name} between true labels and predicted scores."
     """
 This loss function is designed for ranking tasks, where the goal is to
 correctly order items within each list. It computes the loss by comparing
```
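The new `temperature` argument rescales the pairwise score differences before the subclass loss is applied, so temperatures above 1 soften the loss and values below 1 sharpen it. A minimal sketch of the effect using only `keras.ops`, with made-up scores (it mirrors the `ops.divide` call added above rather than invoking keras-rs itself):

```python
from keras import ops

# Scores for a 3-item list; pairwise logit [i, j] is s_i - s_j.
scores = ops.convert_to_tensor([2.0, 0.5, -1.0])
pairwise_logits = ops.subtract(
    ops.expand_dims(scores, axis=-1), ops.expand_dims(scores, axis=-2)
)

# Temperature scaling as added in `compute_unreduced_loss` above.
for temperature in (1.0, 2.0):
    scaled = ops.divide(
        pairwise_logits, ops.cast(temperature, dtype=pairwise_logits.dtype)
    )
    # The (0, 2) difference of 3.0 becomes 1.5 at temperature 2.0.
    print(temperature, ops.convert_to_numpy(scaled)[0, 2])
```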
keras_rs/src/losses/pairwise_loss_utils.py
ADDED

```diff
@@ -0,0 +1,39 @@
+from typing import Callable
+
+from keras import ops
+
+from keras_rs.src import types
+
+
+def apply_pairwise_op(
+    x: types.Tensor, op: Callable[[types.Tensor, types.Tensor], types.Tensor]
+) -> types.Tensor:
+    return op(
+        ops.expand_dims(x, axis=-1),
+        ops.expand_dims(x, axis=-2),
+    )
+
+
+def pairwise_comparison(
+    labels: types.Tensor,
+    logits: types.Tensor,
+    mask: types.Tensor,
+    logits_op: Callable[[types.Tensor, types.Tensor], types.Tensor],
+) -> tuple[types.Tensor, types.Tensor]:
+    # Compute the difference for all pairs in a list. The output is a tensor
+    # with shape `(batch_size, list_size, list_size)`, where `[:, i, j]` stores
+    # information for pair `(i, j)`.
+    pairwise_labels_diff = apply_pairwise_op(labels, ops.subtract)
+    pairwise_logits = apply_pairwise_op(logits, logits_op)
+
+    # Keep only those cases where `l_i < l_j`.
+    pairwise_labels = ops.cast(
+        ops.greater(pairwise_labels_diff, 0), dtype=labels.dtype
+    )
+    if mask is not None:
+        valid_pairs = apply_pairwise_op(mask, ops.logical_and)
+        pairwise_labels = ops.multiply(
+            pairwise_labels, ops.cast(valid_pairs, dtype=pairwise_labels.dtype)
+        )
+
+    return pairwise_labels, pairwise_logits
```
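A quick check of what the new helper returns, on made-up inputs (this assumes the package is installed so the module path above is importable):

```python
from keras import ops

from keras_rs.src.losses.pairwise_loss_utils import pairwise_comparison

# One list of three items: graded labels and predicted scores (toy values).
labels = ops.convert_to_tensor([[1.0, 0.0, 2.0]])
logits = ops.convert_to_tensor([[0.2, 0.9, 0.1]])

pairwise_labels, pairwise_logits = pairwise_comparison(
    labels=labels, logits=logits, mask=None, logits_op=ops.subtract
)
# `pairwise_labels[:, i, j]` is 1 where `labels[i] > labels[j]`, so pairs
# (0, 1), (2, 0) and (2, 1) are marked. `pairwise_logits[:, i, j]` holds
# the score differences `s_i - s_j` for the same pairs.
print(ops.convert_to_numpy(pairwise_labels))
```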
keras_rs/src/losses/pairwise_mean_squared_error.py
CHANGED

```diff
@@ -6,7 +6,7 @@ from keras_rs.src import types
 from keras_rs.src.api_export import keras_rs_export
 from keras_rs.src.losses.pairwise_loss import PairwiseLoss
 from keras_rs.src.losses.pairwise_loss import pairwise_loss_subclass_doc_string
-from keras_rs.src.
+from keras_rs.src.losses.pairwise_loss_utils import apply_pairwise_op
 
 
 @keras_rs_export("keras_rs.losses.PairwiseMeanSquaredError")
@@ -65,6 +65,7 @@ explanation = """
 """
 extra_args = ""
 PairwiseMeanSquaredError.__doc__ = pairwise_loss_subclass_doc_string.format(
+    loss_name="mean squared error",
     formula=formula,
     explanation=explanation,
     extra_args=extra_args,
```
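The `loss_name` field added here fills the shared docstring template defined in `pairwise_loss.py`; the mechanism is plain `str.format`:

```python
template = (
    "Computes pairwise {loss_name} between true labels and predicted scores."
)
# Prints: Computes pairwise mean squared error between true labels and
# predicted scores.
print(template.format(loss_name="mean squared error"))
```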
keras_rs/src/metrics/__init__.py
File without changes
keras_rs/src/metrics/dcg.py
ADDED

````diff
@@ -0,0 +1,140 @@
+from typing import Any, Callable, Optional
+
+from keras import ops
+from keras.saving import deserialize_keras_object
+from keras.saving import serialize_keras_object
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+from keras_rs.src.metrics.ranking_metric import RankingMetric
+from keras_rs.src.metrics.ranking_metric import (
+    ranking_metric_subclass_doc_string,
+)
+from keras_rs.src.metrics.ranking_metric import (
+    ranking_metric_subclass_doc_string_args,
+)
+from keras_rs.src.metrics.ranking_metrics_utils import compute_dcg
+from keras_rs.src.metrics.ranking_metrics_utils import default_gain_fn
+from keras_rs.src.metrics.ranking_metrics_utils import default_rank_discount_fn
+from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
+from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
+from keras_rs.src.utils.doc_string_utils import format_docstring
+
+
+@keras_rs_export("keras_rs.metrics.DCG")
+class DCG(RankingMetric):
+    def __init__(
+        self,
+        k: Optional[int] = None,
+        gain_fn: Callable[[types.Tensor], types.Tensor] = default_gain_fn,
+        rank_discount_fn: Callable[
+            [types.Tensor], types.Tensor
+        ] = default_rank_discount_fn,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(k=k, **kwargs)
+
+        self.gain_fn = gain_fn
+        self.rank_discount_fn = rank_discount_fn
+
+    def compute_metric(
+        self,
+        y_true: types.Tensor,
+        y_pred: types.Tensor,
+        mask: types.Tensor,
+        sample_weight: types.Tensor,
+    ) -> types.Tensor:
+        sorted_y_true, sorted_weights = sort_by_scores(
+            tensors_to_sort=[y_true, sample_weight],
+            scores=y_pred,
+            k=self.k,
+            mask=mask,
+            shuffle_ties=self.shuffle_ties,
+            seed=self.seed_generator,
+        )
+
+        dcg = compute_dcg(
+            y_true=sorted_y_true,
+            sample_weight=sorted_weights,
+            gain_fn=self.gain_fn,
+            rank_discount_fn=self.rank_discount_fn,
+        )
+
+        per_list_weights = get_list_weights(
+            weights=sample_weight, relevance=self.gain_fn(y_true)
+        )
+        # Since we have already multiplied with `sample_weight`, we need to
+        # divide by `per_list_weights` so as to nullify the multiplication
+        # which `keras.metrics.Mean` will do.
+        per_list_dcg = ops.divide_no_nan(dcg, per_list_weights)
+
+        return per_list_dcg, per_list_weights
+
+    def get_config(self) -> dict[str, Any]:
+        config: dict[str, Any] = super().get_config()
+        config.update(
+            {
+                "gain_fn": serialize_keras_object(self.gain_fn),
+                "rank_discount_fn": serialize_keras_object(
+                    self.rank_discount_fn
+                ),
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "DCG":
+        config["gain_fn"] = deserialize_keras_object(config["gain_fn"])
+        config["rank_discount_fn"] = deserialize_keras_object(
+            config["rank_discount_fn"]
+        )
+        return cls(**config)
+
+
+concept_sentence = (
+    "It computes the sum of the graded relevance scores of items, applying a "
+    "configurable discount based on position"
+)
+relevance_type = (
+    "graded relevance scores (non-negative numbers where higher values "
+    "indicate greater relevance)"
+)
+score_range_interpretation = (
+    "Scores are non-negative, with higher values indicating better ranking "
+    "quality (highly relevant items are ranked higher). The score for a single "
+    "list is not bounded or normalized, i.e., it does not lie in a range"
+)
+
+formula = """
+```
+DCG@k(y', w') = sum_{i=1}^{k} (gain_fn(y'_i) / rank_discount_fn(i))
+```
+
+where:
+- `y'_i` is the true relevance score of the item ranked at position `i`
+  (obtained by sorting `y_true` according to `y_pred`).
+- `gain_fn` is the user-provided function mapping relevance `y'_i` to a
+  gain value. The default function (`default_gain_fn`) is typically
+  equivalent to `lambda y: 2**y - 1`.
+- `rank_discount_fn` is the user-provided function mapping rank `i`
+  to a discount value. The default function (`default_rank_discount_fn`)
+  is typically equivalent to `lambda rank: 1 / log2(rank + 1)`.
+- The final result aggregates these per-list scores.
+"""
+extra_args = """
+    gain_fn: callable. Maps relevance scores (`y_true`) to gain values. The
+        default implements `2**y - 1`.
+    rank_discount_fn: function. Maps rank positions to discount
+        values. The default (`default_rank_discount_fn`) implements
+        `1 / log2(rank + 1)`."""
+
+DCG.__doc__ = format_docstring(
+    ranking_metric_subclass_doc_string,
+    width=80,
+    metric_name="Discounted Cumulative Gain",
+    metric_abbreviation="DCG",
+    concept_sentence=concept_sentence,
+    relevance_type=relevance_type,
+    score_range_interpretation=score_range_interpretation,
+    formula=formula,
+) + ranking_metric_subclass_doc_string_args.format(extra_args=extra_args)
````
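A hand computation of the `DCG@k` formula above with the stated defaults (`2**y - 1` gain, `1 / log2(rank + 1)` discount). This is a plain re-implementation for illustration, not the library's `compute_dcg`:

```python
import math


def dcg_at_k(sorted_relevance, k=None):
    # `sorted_relevance` holds true labels already ordered by predicted score.
    items = sorted_relevance if k is None else sorted_relevance[:k]
    return sum(
        (2.0**y - 1.0) / math.log2(rank + 1.0)
        for rank, y in enumerate(items, start=1)
    )


# Graded labels in ranked order: 7/1 + 3/log2(3) + 0 + 1/log2(5) ≈ 9.32.
print(dcg_at_k([3, 2, 0, 1]))
```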
keras_rs/src/metrics/mean_average_precision.py
ADDED

````diff
@@ -0,0 +1,112 @@
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+from keras_rs.src.metrics.ranking_metric import RankingMetric
+from keras_rs.src.metrics.ranking_metric import (
+    ranking_metric_subclass_doc_string,
+)
+from keras_rs.src.metrics.ranking_metric import (
+    ranking_metric_subclass_doc_string_args,
+)
+from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
+from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
+from keras_rs.src.utils.doc_string_utils import format_docstring
+
+
+@keras_rs_export("keras_rs.metrics.MeanAveragePrecision")
+class MeanAveragePrecision(RankingMetric):
+    def compute_metric(
+        self,
+        y_true: types.Tensor,
+        y_pred: types.Tensor,
+        mask: types.Tensor,
+        sample_weight: types.Tensor,
+    ) -> types.Tensor:
+        relevance = ops.cast(
+            ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
+            dtype="float32",
+        )
+        sorted_relevance, sorted_weights = sort_by_scores(
+            tensors_to_sort=[relevance, sample_weight],
+            scores=y_pred,
+            mask=mask,
+            k=self.k,
+            shuffle_ties=self.shuffle_ties,
+            seed=self.seed_generator,
+        )
+        per_list_relevant_counts = ops.cumsum(sorted_relevance, axis=1)
+        per_list_cutoffs = ops.cumsum(ops.ones_like(sorted_relevance), axis=1)
+        per_list_precisions = ops.divide_no_nan(
+            per_list_relevant_counts, per_list_cutoffs
+        )
+
+        total_precision = ops.sum(
+            ops.multiply(
+                per_list_precisions,
+                ops.multiply(sorted_weights, sorted_relevance),
+            ),
+            axis=1,
+            keepdims=True,
+        )
+
+        # Compute the total relevance.
+        total_relevance = ops.sum(
+            ops.multiply(sample_weight, relevance), axis=1, keepdims=True
+        )
+
+        per_list_map = ops.divide_no_nan(total_precision, total_relevance)
+
+        per_list_weights = get_list_weights(sample_weight, relevance)
+
+        return per_list_map, per_list_weights
+
+
+concept_sentence = (
+    "It calculates the average of precision values computed after each "
+    "relevant item present in the ranked list"
+)
+relevance_type = "binary indicators (0 or 1) of relevance"
+score_range_interpretation = (
+    "Scores range from 0 to 1, with higher values indicating that relevant "
+    "items are generally positioned higher in the ranking"
+)
+
+formula = """
+The formula for average precision is defined below. MAP is the mean over average
+precision computed for each list.
+
+```
+AP(y, s) = sum_j (P@j(y, s) * rel(j)) / sum_i y_i
+rel(j) = y_i if rank(s_i) = j
+```
+
+where:
+- `j` represents the rank position (starting from 1).
+- `sum_j` indicates a summation over all ranks `j` from 1 up to the list
+  size (or `k`).
+- `P@j(y, s)` denotes the Precision at rank `j`, calculated as the
+  number of relevant items found within the top `j` positions divided by
+  `j`.
+- `rel(j)` represents the relevance of the item specifically at rank
+  `j`. `rel(j)` is 1 if the item at rank `j` is relevant, and 0
+  otherwise.
+- `y_i` is the true relevance label of the original item `i` before
+  ranking.
+- `rank(s_i)` is the rank position assigned to item `i` based on its
+  score `s_i`.
+- `sum_i y_i` calculates the total number of relevant items in the
+  original list `y`.
+"""
+extra_args = ""
+
+MeanAveragePrecision.__doc__ = format_docstring(
+    ranking_metric_subclass_doc_string,
+    width=80,
+    metric_name="Mean Average Precision",
+    metric_abbreviation="MAP",
+    concept_sentence=concept_sentence,
+    relevance_type=relevance_type,
+    score_range_interpretation=score_range_interpretation,
+    formula=formula,
+) + ranking_metric_subclass_doc_string_args.format(extra_args=extra_args)
````
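A small worked instance of the AP formula above, again as a plain re-implementation for illustration:

```python
def average_precision(relevance):
    # `relevance`: binary labels in ranked order (sorted by predicted score).
    hits, total = 0, 0.0
    for j, rel in enumerate(relevance, start=1):
        if rel:
            hits += 1
            total += hits / j  # P@j, accumulated only at relevant positions.
    return total / hits if hits else 0.0


# Relevant items land at ranks 1 and 3: (1/1 + 2/3) / 2 = 0.8333...
print(average_precision([1, 0, 1, 0]))
```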
keras_rs/src/metrics/mean_reciprocal_rank.py
ADDED

````diff
@@ -0,0 +1,98 @@
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+from keras_rs.src.metrics.ranking_metric import RankingMetric
+from keras_rs.src.metrics.ranking_metric import (
+    ranking_metric_subclass_doc_string,
+)
+from keras_rs.src.metrics.ranking_metric import (
+    ranking_metric_subclass_doc_string_args,
+)
+from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
+from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
+from keras_rs.src.utils.doc_string_utils import format_docstring
+
+
+@keras_rs_export("keras_rs.metrics.MeanReciprocalRank")
+class MeanReciprocalRank(RankingMetric):
+    def compute_metric(
+        self,
+        y_true: types.Tensor,
+        y_pred: types.Tensor,
+        mask: types.Tensor,
+        sample_weight: types.Tensor,
+    ) -> types.Tensor:
+        # Assume: `y_true = [0, 0, 1]`, `y_pred = [0.1, 0.9, 0.2]`.
+        # `sorted_y_true = [0, 1, 0]` (sorted in descending order).
+        (sorted_y_true,) = sort_by_scores(
+            tensors_to_sort=[y_true],
+            scores=y_pred,
+            mask=mask,
+            k=self.k,
+            shuffle_ties=self.shuffle_ties,
+            seed=self.seed_generator,
+        )
+
+        # This will depend on `k`, i.e., it will not always be the same as
+        # `len(y_true)`.
+        list_length = ops.shape(sorted_y_true)[1]
+
+        # We consider only binary relevance here, anything above 1 is treated
+        # as 1. `relevance = [0., 1., 0.]`.
+        relevance = ops.cast(
+            ops.greater_equal(
+                sorted_y_true, ops.cast(1, dtype=sorted_y_true.dtype)
+            ),
+            dtype="float32",
+        )
+
+        # `reciprocal_rank = [1, 0.5, 0.33]`
+        reciprocal_rank = ops.divide(
+            ops.cast(1, dtype="float32"),
+            ops.arange(1, list_length + 1, dtype="float32"),
+        )
+
+        # `mrr` should be of shape `(batch_size, 1)`.
+        # `mrr = amax([0., 0.5, 0.]) = 0.5`
+        mrr = ops.amax(
+            ops.multiply(relevance, reciprocal_rank),
+            axis=1,
+            keepdims=True,
+        )
+
+        # Get weights.
+        overall_relevance = ops.cast(
+            ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
+            dtype="float32",
+        )
+        per_list_weights = get_list_weights(
+            weights=sample_weight, relevance=overall_relevance
+        )
+
+        return mrr, per_list_weights
+
+
+concept_sentence = (
+    "It focuses on the rank position of the single highest-scoring relevant "
+    "item"
+)
+relevance_type = "binary indicators (0 or 1) of relevance"
+score_range_interpretation = (
+    "Scores range from 0 to 1, with 1 indicating the first relevant item was "
+    "always ranked first"
+)
+formula = """```
+MRR(y, s) = max_{i} y_{i} / rank(s_{i})
+```"""
+extra_args = ""
+MeanReciprocalRank.__doc__ = format_docstring(
+    ranking_metric_subclass_doc_string,
+    width=80,
+    metric_name="Mean Reciprocal Rank",
+    metric_abbreviation="MRR",
+    concept_sentence=concept_sentence,
+    relevance_type=relevance_type,
+    score_range_interpretation=score_range_interpretation,
+    formula=formula,
+) + ranking_metric_subclass_doc_string_args.format(extra_args=extra_args)
````
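The inline comments in `compute_metric` trace one example end to end; here is the same computation in plain Python (illustrative only, not the library's code path):

```python
def reciprocal_rank(y_true, y_pred):
    # Sort items by descending score, then return 1 / rank of the first item
    # with a label >= 1 (binary relevance, as in the metric above).
    order = sorted(range(len(y_pred)), key=lambda i: -y_pred[i])
    for rank, i in enumerate(order, start=1):
        if y_true[i] >= 1:
            return 1.0 / rank
    return 0.0


# Matches the traced example: the relevant item lands at rank 2.
print(reciprocal_rank([0, 0, 1], [0.1, 0.9, 0.2]))  # 0.5
```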