keras-rs-nightly 0.0.1.dev2025021903__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_rs/__init__.py +9 -28
- keras_rs/layers/__init__.py +37 -0
- keras_rs/losses/__init__.py +19 -0
- keras_rs/metrics/__init__.py +16 -0
- keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
- keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
- keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
- keras_rs/src/layers/embedding/embed_reduce.py +309 -0
- keras_rs/src/layers/embedding/jax/__init__.py +0 -0
- keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
- keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
- keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
- keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
- keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
- keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
- keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
- keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
- keras_rs/src/layers/feature_interaction/__init__.py +0 -0
- keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
- keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
- keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
- keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
- keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
- keras_rs/src/layers/retrieval/retrieval.py +127 -0
- keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
- keras_rs/src/losses/__init__.py +0 -0
- keras_rs/src/losses/list_mle_loss.py +212 -0
- keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
- keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
- keras_rs/src/losses/pairwise_loss.py +165 -0
- keras_rs/src/losses/pairwise_loss_utils.py +39 -0
- keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
- keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
- keras_rs/src/metrics/__init__.py +0 -0
- keras_rs/src/metrics/dcg.py +161 -0
- keras_rs/src/metrics/mean_average_precision.py +130 -0
- keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
- keras_rs/src/metrics/ndcg.py +197 -0
- keras_rs/src/metrics/precision_at_k.py +117 -0
- keras_rs/src/metrics/ranking_metric.py +260 -0
- keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
- keras_rs/src/metrics/recall_at_k.py +108 -0
- keras_rs/src/metrics/utils.py +70 -0
- keras_rs/src/types.py +43 -14
- keras_rs/src/utils/doc_string_utils.py +53 -0
- keras_rs/src/utils/keras_utils.py +52 -3
- keras_rs/src/utils/tpu_test_utils.py +120 -0
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
- keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
- keras_rs/api/__init__.py +0 -9
- keras_rs/api/layers/__init__.py +0 -11
- keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
- /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
from keras import ops
|
|
2
|
+
|
|
3
|
+
from keras_rs.src import types
|
|
4
|
+
from keras_rs.src.api_export import keras_rs_export
|
|
5
|
+
from keras_rs.src.losses.pairwise_loss import PairwiseLoss
|
|
6
|
+
from keras_rs.src.losses.pairwise_loss import pairwise_loss_subclass_doc_string
|
|
7
|
+
from keras_rs.src.losses.pairwise_loss_utils import apply_pairwise_op
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@keras_rs_export("keras_rs.losses.PairwiseMeanSquaredError")
class PairwiseMeanSquaredError(PairwiseLoss):
    # NOTE: the user-facing class docstring is built from a shared template
    # and assigned to `PairwiseMeanSquaredError.__doc__` further down in this
    # module, so no literal class docstring is written here.

    def pairwise_loss(self, pairwise_logits: types.Tensor) -> types.Tensor:
        """Unused hook; intentionally left without an implementation.

        This class overrides `compute_unreduced_loss` directly, so the
        per-pair loss hook is never needed here. The body is empty and
        therefore yields `None`.
        """
        pass

    def compute_unreduced_loss(
        self,
        labels: types.Tensor,
        logits: types.Tensor,
        mask: types.Tensor | None = None,
    ) -> tuple[types.Tensor, types.Tensor]:
        """Compute per-pair squared errors and the matching pair weights.

        Replaces `PairwiseLoss.compute_unreduced_loss` because the pairwise
        weights for MSE are derived differently from the other pairwise
        losses.

        Args:
            labels: True labels. Unpacked as `(batch_size, list_size)`;
                entries below 0 are treated as invalid.
            logits: Predicted scores, paired up the same way as `labels`.
            mask: Optional boolean tensor that is AND-ed with the
                "label >= 0" validity mask.

        Returns:
            A `(pairwise_mse, pairwise_weights)` tuple. Both tensors hold one
            entry per pair `(i, j)` of each list, i.e. they are indexed as
            `[:, i, j]`.
        """
        num_lists, num_items = ops.shape(labels)

        # Squared gap between the label difference and the score difference
        # of every pair `(i, j)`; shape `(batch_size, list_size, list_size)`.
        squared_error = ops.square(
            apply_pairwise_op(labels, ops.subtract)
            - apply_pairwise_op(logits, ops.subtract)
        )

        # An item is usable when its label is non-negative and, if a mask
        # was supplied, the mask also allows it.
        is_valid = ops.greater_equal(labels, ops.cast(0.0, labels.dtype))
        if mask is not None:
            is_valid = ops.logical_and(is_valid, mask)
        # A pair is usable only when both of its members are usable.
        both_valid = apply_pairwise_op(is_valid, ops.logical_and)

        # Start from all-ones and knock out the diagonal so that an item is
        # never paired with itself.
        self_pairs = ops.tile(
            ops.eye(num_items, num_items), (num_lists, 1, 1)
        )
        off_diagonal = ops.subtract(ops.ones_like(squared_error), self_pairs)

        # Finally zero out every pair that involves an invalid item.
        weights = ops.multiply(
            off_diagonal, ops.cast(both_valid, dtype=off_diagonal.dtype)
        )

        return squared_error, weights
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# Documented formula for the pairwise MSE loss. It mirrors
# `compute_unreduced_loss`: every pair of distinct items contributes the
# squared gap between its label difference and its score difference, with no
# label-order indicator.
formula = "loss = sum_{i} sum_{j != i} ((y_i - y_j) - (s_i - s_j))^2"
|
|
59
|
+
explanation = """
|
|
60
|
+
- `(s_i - s_j)^2` is the squared difference between the predicted scores
|
|
61
|
+
of items `i` and `j`, which penalizes discrepancies between the predicted
|
|
62
|
+
order of items relative to their true order.
|
|
63
|
+
"""
|
|
64
|
+
extra_args = ""
|
|
65
|
+
example = """
|
|
66
|
+
With `compile()` API:
|
|
67
|
+
|
|
68
|
+
```python
|
|
69
|
+
model.compile(
|
|
70
|
+
loss=keras_rs.losses.PairwiseMeanSquaredError(),
|
|
71
|
+
...
|
|
72
|
+
)
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
As a standalone function with unbatched inputs:
|
|
76
|
+
|
|
77
|
+
>>> y_true = np.array([1.0, 0.0, 1.0, 3.0, 2.0])
|
|
78
|
+
>>> y_pred = np.array([1.0, 3.0, 2.0, 4.0, 0.8])
|
|
79
|
+
>>> pairwise_mse = keras_rs.losses.PairwiseMeanSquaredError()
|
|
80
|
+
>>> pairwise_mse(y_true=y_true, y_pred=y_pred)
|
|
81
|
+
19.10400
|
|
82
|
+
|
|
83
|
+
With batched inputs using default 'auto'/'sum_over_batch_size' reduction:
|
|
84
|
+
|
|
85
|
+
>>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
|
|
86
|
+
>>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
|
|
87
|
+
>>> pairwise_mse = keras_rs.losses.PairwiseMeanSquaredError()
|
|
88
|
+
>>> pairwise_mse(y_true=y_true, y_pred=y_pred)
|
|
89
|
+
5.57999
|
|
90
|
+
|
|
91
|
+
With masked inputs (useful for ragged inputs):
|
|
92
|
+
|
|
93
|
+
>>> y_true = {
|
|
94
|
+
... "labels": np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]]),
|
|
95
|
+
... "mask": np.array(
|
|
96
|
+
... [[True, True, True, True], [True, True, False, False]]
|
|
97
|
+
... ),
|
|
98
|
+
... }
|
|
99
|
+
>>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
|
|
100
|
+
>>> pairwise_mse(y_true=y_true, y_pred=y_pred)
|
|
101
|
+
4.76000
|
|
102
|
+
|
|
103
|
+
With `sample_weight`:
|
|
104
|
+
|
|
105
|
+
>>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
|
|
106
|
+
>>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
|
|
107
|
+
>>> sample_weight = np.array(
|
|
108
|
+
... [[2.0, 3.0, 1.0, 1.0], [2.0, 1.0, 0.0, 0.0]]
|
|
109
|
+
... )
|
|
110
|
+
>>> pairwise_mse = keras_rs.losses.PairwiseMeanSquaredError()
|
|
111
|
+
>>> pairwise_mse(
|
|
112
|
+
... y_true=y_true, y_pred=y_pred, sample_weight=sample_weight
|
|
113
|
+
... )
|
|
114
|
+
11.0500
|
|
115
|
+
|
|
116
|
+
Using `'none'` reduction:
|
|
117
|
+
|
|
118
|
+
>>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
|
|
119
|
+
>>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
|
|
120
|
+
>>> pairwise_mse = keras_rs.losses.PairwiseMeanSquaredError(
|
|
121
|
+
... reduction="none"
|
|
122
|
+
... )
|
|
123
|
+
>>> pairwise_mse(y_true=y_true, y_pred=y_pred)
|
|
124
|
+
[[11., 17., 5., 5.], [2.04, 1.3199998, 1.6399999, 1.6399999]]
|
|
125
|
+
"""
|
|
126
|
+
|
|
127
|
+
PairwiseMeanSquaredError.__doc__ = pairwise_loss_subclass_doc_string.format(
|
|
128
|
+
loss_name="mean squared error",
|
|
129
|
+
formula=formula,
|
|
130
|
+
explanation=explanation,
|
|
131
|
+
extra_args=extra_args,
|
|
132
|
+
example=example,
|
|
133
|
+
)
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
from keras import ops
|
|
2
|
+
|
|
3
|
+
from keras_rs.src import types
|
|
4
|
+
from keras_rs.src.api_export import keras_rs_export
|
|
5
|
+
from keras_rs.src.losses.pairwise_loss import PairwiseLoss
|
|
6
|
+
from keras_rs.src.losses.pairwise_loss import pairwise_loss_subclass_doc_string
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@keras_rs_export("keras_rs.losses.PairwiseSoftZeroOneLoss")
class PairwiseSoftZeroOneLoss(PairwiseLoss):
    # NOTE: the user-facing class docstring is built from a shared template
    # and assigned to `PairwiseSoftZeroOneLoss.__doc__` further down in this
    # module, so no literal class docstring is written here.

    def pairwise_loss(self, pairwise_logits: types.Tensor) -> types.Tensor:
        """Return the soft zero-one loss for each pairwise logit.

        Both branches below are algebraically the same quantity,
        `1 - sigmoid(x) == sigmoid(-x)`; which form is evaluated depends on
        the sign of the logit (presumably for numerical reasons - the source
        does not say).

        Args:
            pairwise_logits: Pairwise score differences.

        Returns:
            A tensor with the element-wise loss for every pairwise logit.
        """
        is_positive = ops.greater(pairwise_logits, ops.array(0.0))
        # Form used where the logit is strictly positive.
        loss_if_positive = ops.subtract(
            ops.array(1.0), ops.sigmoid(pairwise_logits)
        )
        # Form used for zero or negative logits.
        loss_otherwise = ops.sigmoid(ops.negative(pairwise_logits))
        return ops.where(is_positive, loss_if_positive, loss_otherwise)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
formula = "loss = sum_{i} sum_{j} I(y_i > y_j) * (1 - sigmoid(s_i - s_j))"
|
|
20
|
+
explanation = """
|
|
21
|
+
- `(1 - sigmoid(s_i - s_j))` represents the soft zero-one loss, which
|
|
22
|
+
approximates the ideal zero-one loss (which would be 1 if `s_i < s_j`
|
|
23
|
+
and 0 otherwise) with a smooth, differentiable function. This makes it
|
|
24
|
+
suitable for gradient-based optimization.
|
|
25
|
+
"""
|
|
26
|
+
extra_args = ""
|
|
27
|
+
example = """
|
|
28
|
+
With `compile()` API:
|
|
29
|
+
|
|
30
|
+
```python
|
|
31
|
+
model.compile(
|
|
32
|
+
loss=keras_rs.losses.PairwiseSoftZeroOneLoss(),
|
|
33
|
+
...
|
|
34
|
+
)
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
As a standalone function with unbatched inputs:
|
|
38
|
+
|
|
39
|
+
>>> y_true = np.array([1.0, 0.0, 1.0, 3.0, 2.0])
|
|
40
|
+
>>> y_pred = np.array([1.0, 3.0, 2.0, 4.0, 0.8])
|
|
41
|
+
>>> pairwise_soft_zero_one_loss = keras_rs.losses.PairwiseSoftZeroOneLoss()
|
|
42
|
+
>>> pairwise_soft_zero_one_loss(y_true=y_true, y_pred=y_pred)
|
|
43
|
+
0.86103
|
|
44
|
+
|
|
45
|
+
With batched inputs using default 'auto'/'sum_over_batch_size' reduction:
|
|
46
|
+
|
|
47
|
+
>>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
|
|
48
|
+
>>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
|
|
49
|
+
>>> pairwise_soft_zero_one_loss = keras_rs.losses.PairwiseSoftZeroOneLoss()
|
|
50
|
+
>>> pairwise_soft_zero_one_loss(y_true=y_true, y_pred=y_pred)
|
|
51
|
+
0.46202
|
|
52
|
+
|
|
53
|
+
With masked inputs (useful for ragged inputs):
|
|
54
|
+
|
|
55
|
+
>>> y_true = {
|
|
56
|
+
... "labels": np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]]),
|
|
57
|
+
... "mask": np.array(
|
|
58
|
+
... [[True, True, True, True], [True, True, False, False]]
|
|
59
|
+
... ),
|
|
60
|
+
... }
|
|
61
|
+
>>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
|
|
62
|
+
>>> pairwise_soft_zero_one_loss(y_true=y_true, y_pred=y_pred)
|
|
63
|
+
0.29468
|
|
64
|
+
|
|
65
|
+
With `sample_weight`:
|
|
66
|
+
|
|
67
|
+
>>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
|
|
68
|
+
>>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
|
|
69
|
+
>>> sample_weight = np.array(
|
|
70
|
+
... [[2.0, 3.0, 1.0, 1.0], [2.0, 1.0, 0.0, 0.0]]
|
|
71
|
+
... )
|
|
72
|
+
>>> pairwise_soft_zero_one_loss = keras_rs.losses.PairwiseSoftZeroOneLoss()
|
|
73
|
+
>>> pairwise_soft_zero_one_loss(
|
|
74
|
+
... y_true=y_true, y_pred=y_pred, sample_weight=sample_weight
|
|
75
|
+
... )
|
|
76
|
+
0.40478
|
|
77
|
+
|
|
78
|
+
Using `'none'` reduction:
|
|
79
|
+
|
|
80
|
+
>>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
|
|
81
|
+
>>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
|
|
82
|
+
>>> pairwise_soft_zero_one_loss = keras_rs.losses.PairwiseSoftZeroOneLoss(
|
|
83
|
+
... reduction="none"
|
|
84
|
+
... )
|
|
85
|
+
>>> pairwise_soft_zero_one_loss(y_true=y_true, y_pred=y_pred)
|
|
86
|
+
[
|
|
87
|
+
[0.8807971 , 0., 0.73105854, 0.43557024],
|
|
88
|
+
[0., 0.31002545, 0.7191075 , 0.61961967]
|
|
89
|
+
]
|
|
90
|
+
"""
|
|
91
|
+
|
|
92
|
+
PairwiseSoftZeroOneLoss.__doc__ = pairwise_loss_subclass_doc_string.format(
|
|
93
|
+
loss_name="soft zero-one loss",
|
|
94
|
+
formula=formula,
|
|
95
|
+
explanation=explanation,
|
|
96
|
+
extra_args=extra_args,
|
|
97
|
+
example=example,
|
|
98
|
+
)
|
|
File without changes
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
from typing import Any, Callable
|
|
2
|
+
|
|
3
|
+
from keras import ops
|
|
4
|
+
from keras.saving import deserialize_keras_object
|
|
5
|
+
from keras.saving import serialize_keras_object
|
|
6
|
+
|
|
7
|
+
from keras_rs.src import types
|
|
8
|
+
from keras_rs.src.api_export import keras_rs_export
|
|
9
|
+
from keras_rs.src.metrics.ranking_metric import RankingMetric
|
|
10
|
+
from keras_rs.src.metrics.ranking_metric import (
|
|
11
|
+
ranking_metric_subclass_doc_string,
|
|
12
|
+
)
|
|
13
|
+
from keras_rs.src.metrics.ranking_metric import (
|
|
14
|
+
ranking_metric_subclass_doc_string_post_desc,
|
|
15
|
+
)
|
|
16
|
+
from keras_rs.src.metrics.ranking_metrics_utils import compute_dcg
|
|
17
|
+
from keras_rs.src.metrics.ranking_metrics_utils import default_gain_fn
|
|
18
|
+
from keras_rs.src.metrics.ranking_metrics_utils import default_rank_discount_fn
|
|
19
|
+
from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
|
|
20
|
+
from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
|
|
21
|
+
from keras_rs.src.utils.doc_string_utils import format_docstring
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@keras_rs_export("keras_rs.metrics.DCG")
class DCG(RankingMetric):
    # NOTE: the user-facing class docstring is built from a shared template
    # and assigned to `DCG.__doc__` further down in this module, so no
    # literal class docstring is written here.

    def __init__(
        self,
        k: int | None = None,
        gain_fn: Callable[[types.Tensor], types.Tensor] = default_gain_fn,
        rank_discount_fn: Callable[
            [types.Tensor], types.Tensor
        ] = default_rank_discount_fn,
        **kwargs: Any,
    ) -> None:
        """Store the gain / discount callables and initialise the base metric.

        Args:
            k: Optional cutoff, forwarded to the base class.
            gain_fn: Callable applied to relevance labels to obtain gains.
            rank_discount_fn: Callable applied to rank positions to obtain
                discounts.
            **kwargs: Remaining keyword arguments, forwarded unchanged to
                the base class constructor.
        """
        super().__init__(k=k, **kwargs)

        self.gain_fn = gain_fn
        self.rank_discount_fn = rank_discount_fn

    def compute_metric(
        self,
        y_true: types.Tensor,
        y_pred: types.Tensor,
        mask: types.Tensor,
        sample_weight: types.Tensor,
    ) -> types.Tensor:
        """Compute the per-list DCG and the weight of each list.

        Args:
            y_true: True relevance labels.
            y_pred: Predicted scores used to order the items.
            mask: Mask handed to `sort_by_scores`.
            sample_weight: Per-item weights.

        Returns:
            A `(per_list_dcg, per_list_weights)` pair.
        """
        # NOTE(review): the return annotation says a single tensor, but a
        # 2-tuple is returned; confirm against the base class contract before
        # tightening the annotation.

        # Order labels and weights by predicted score, honouring `k`, the
        # mask and the tie-shuffling settings held on the instance.
        sorted_y_true, sorted_weights = sort_by_scores(
            tensors_to_sort=[y_true, sample_weight],
            scores=y_pred,
            k=self.k,
            mask=mask,
            shuffle_ties=self.shuffle_ties,
            seed=self.seed_generator,
        )

        dcg = compute_dcg(
            y_true=sorted_y_true,
            sample_weight=sorted_weights,
            gain_fn=self.gain_fn,
            rank_discount_fn=self.rank_discount_fn,
        )

        per_list_weights = get_list_weights(
            weights=sample_weight, relevance=self.gain_fn(y_true)
        )
        # `dcg` already carries `sample_weight`. Dividing by the per-list
        # weights cancels the multiplication that `keras.metrics.Mean`
        # applies later; `divide_no_nan` keeps zero-weight lists at 0.
        per_list_dcg = ops.divide_no_nan(dcg, per_list_weights)

        return per_list_dcg, per_list_weights

    def get_config(self) -> dict[str, Any]:
        """Return the base config extended with the serialized callables."""
        config: dict[str, Any] = super().get_config()
        config.update(
            {
                "gain_fn": serialize_keras_object(self.gain_fn),
                "rank_discount_fn": serialize_keras_object(
                    self.rank_discount_fn
                ),
            }
        )
        return config

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "DCG":
        """Rebuild a `DCG` instance from a `get_config()` dictionary.

        Args:
            config: Configuration with serialized `gain_fn` and
                `rank_discount_fn` entries. It is not modified.

        Returns:
            A new instance constructed from the deserialized configuration.
        """
        # Work on a shallow copy so the caller's dictionary keeps its
        # serialized entries and can be reused (e.g. to build a second
        # instance or to be written to disk).
        config = dict(config)
        config["gain_fn"] = deserialize_keras_object(config["gain_fn"])
        config["rank_discount_fn"] = deserialize_keras_object(
            config["rank_discount_fn"]
        )
        return cls(**config)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
concept_sentence = (
|
|
95
|
+
"It computes the sum of the graded relevance scores of items, applying a "
|
|
96
|
+
"configurable discount based on position"
|
|
97
|
+
)
|
|
98
|
+
relevance_type = (
|
|
99
|
+
"graded relevance scores (non-negative numbers where higher values "
|
|
100
|
+
"indicate greater relevance)"
|
|
101
|
+
)
|
|
102
|
+
score_range_interpretation = (
|
|
103
|
+
"Scores are non-negative, with higher values indicating better ranking "
|
|
104
|
+
"quality (highly relevant items are ranked higher). The score for a single "
|
|
105
|
+
"list is not bounded or normalized, i.e., it does not lie in a range"
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
formula = """
|
|
109
|
+
```
|
|
110
|
+
DCG@k(y', w') = sum_{i=1}^{k} (gain_fn(y'_i) * rank_discount_fn(i))
|
|
111
|
+
```
|
|
112
|
+
|
|
113
|
+
where:
|
|
114
|
+
- `y'_i` is the true relevance score of the item ranked at position `i`
|
|
115
|
+
(obtained by sorting `y_true` according to `y_pred`).
|
|
116
|
+
- `gain_fn` is the user-provided function mapping relevance `y'_i` to a
|
|
117
|
+
gain value. The default function (`default_gain_fn`) is typically
|
|
118
|
+
equivalent to `lambda y: 2**y - 1`.
|
|
119
|
+
- `rank_discount_fn` is the user-provided function mapping rank `i`
|
|
120
|
+
to a discount value. The default function (`default_rank_discount_fn`)
|
|
121
|
+
is typically equivalent to `lambda rank: 1 / log2(rank + 1)`.
|
|
122
|
+
- The final result aggregates these per-list scores."""
|
|
123
|
+
extra_args = """
|
|
124
|
+
gain_fn: callable. Maps relevance scores (`y_true`) to gain values. The
|
|
125
|
+
default implements `2**y - 1`.
|
|
126
|
+
rank_discount_fn: callable. Maps rank positions to discount
|
|
127
|
+
values. The default (`default_rank_discount_fn`) implements
|
|
128
|
+
`1 / log2(rank + 1)`."""
|
|
129
|
+
example = """
|
|
130
|
+
>>> batch_size = 2
|
|
131
|
+
>>> list_size = 5
|
|
132
|
+
>>> labels = np.random.randint(0, 3, size=(batch_size, list_size))
|
|
133
|
+
>>> scores = np.random.random(size=(batch_size, list_size))
|
|
134
|
+
>>> metric = keras_rs.metrics.DCG()(
|
|
135
|
+
... y_true=labels, y_pred=scores
|
|
136
|
+
... )
|
|
137
|
+
|
|
138
|
+
Mask certain elements (can be used for uneven inputs):
|
|
139
|
+
|
|
140
|
+
>>> batch_size = 2
|
|
141
|
+
>>> list_size = 5
|
|
142
|
+
>>> labels = np.random.randint(0, 3, size=(batch_size, list_size))
|
|
143
|
+
>>> scores = np.random.random(size=(batch_size, list_size))
|
|
144
|
+
>>> mask = np.random.randint(0, 2, size=(batch_size, list_size), dtype=bool)
|
|
145
|
+
>>> metric = keras_rs.metrics.DCG()(
|
|
146
|
+
... y_true={"labels": labels, "mask": mask}, y_pred=scores
|
|
147
|
+
... )
|
|
148
|
+
"""
|
|
149
|
+
|
|
150
|
+
DCG.__doc__ = format_docstring(
|
|
151
|
+
ranking_metric_subclass_doc_string,
|
|
152
|
+
width=80,
|
|
153
|
+
metric_name="Discounted Cumulative Gain",
|
|
154
|
+
metric_abbreviation="DCG",
|
|
155
|
+
concept_sentence=concept_sentence,
|
|
156
|
+
relevance_type=relevance_type,
|
|
157
|
+
score_range_interpretation=score_range_interpretation,
|
|
158
|
+
formula=formula,
|
|
159
|
+
) + ranking_metric_subclass_doc_string_post_desc.format(
|
|
160
|
+
extra_args=extra_args, example=example
|
|
161
|
+
)
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
from keras import ops
|
|
2
|
+
|
|
3
|
+
from keras_rs.src import types
|
|
4
|
+
from keras_rs.src.api_export import keras_rs_export
|
|
5
|
+
from keras_rs.src.metrics.ranking_metric import RankingMetric
|
|
6
|
+
from keras_rs.src.metrics.ranking_metric import (
|
|
7
|
+
ranking_metric_subclass_doc_string,
|
|
8
|
+
)
|
|
9
|
+
from keras_rs.src.metrics.ranking_metric import (
|
|
10
|
+
ranking_metric_subclass_doc_string_post_desc,
|
|
11
|
+
)
|
|
12
|
+
from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
|
|
13
|
+
from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
|
|
14
|
+
from keras_rs.src.utils.doc_string_utils import format_docstring
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@keras_rs_export("keras_rs.metrics.MeanAveragePrecision")
class MeanAveragePrecision(RankingMetric):
    # NOTE: the user-facing class docstring is built from a shared template
    # and assigned to `MeanAveragePrecision.__doc__` further down in this
    # module, so no literal class docstring is written here.

    def compute_metric(
        self,
        y_true: types.Tensor,
        y_pred: types.Tensor,
        mask: types.Tensor,
        sample_weight: types.Tensor,
    ) -> types.Tensor:
        """Compute the per-list average precision and each list's weight.

        Args:
            y_true: True relevance labels; any value `>= 1` counts as
                relevant.
            y_pred: Predicted scores used to order the items.
            mask: Mask handed to `sort_by_scores`.
            sample_weight: Per-item weights.

        Returns:
            A `(per_list_map, per_list_weights)` pair.
        """
        # Binarise the labels: 1.0 where the label is at least 1, else 0.0.
        relevance_threshold = ops.cast(1, dtype=y_true.dtype)
        relevance = ops.cast(
            ops.greater_equal(y_true, relevance_threshold),
            dtype=y_pred.dtype,
        )

        # Order relevance flags and weights by predicted score.
        ranked_relevance, ranked_weights = sort_by_scores(
            tensors_to_sort=[relevance, sample_weight],
            scores=y_pred,
            mask=mask,
            k=self.k,
            shuffle_ties=self.shuffle_ties,
            seed=self.seed_generator,
        )

        # Precision at every cutoff: running count of relevant items divided
        # by the 1-based position.
        hits_so_far = ops.cumsum(ranked_relevance, axis=1)
        positions = ops.cumsum(ops.ones_like(ranked_relevance), axis=1)
        precision_at_position = ops.divide_no_nan(hits_so_far, positions)

        # Only positions that hold a relevant item contribute, each scaled by
        # its weight.
        weighted_hits = ops.multiply(ranked_weights, ranked_relevance)
        precision_sum = ops.sum(
            ops.multiply(precision_at_position, weighted_hits),
            axis=1,
            keepdims=True,
        )

        # Weighted number of relevant items, taken over the unsorted inputs.
        relevance_sum = ops.sum(
            ops.multiply(sample_weight, relevance), axis=1, keepdims=True
        )

        per_list_map = ops.divide_no_nan(precision_sum, relevance_sum)
        per_list_weights = get_list_weights(sample_weight, relevance)

        return per_list_map, per_list_weights
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
concept_sentence = (
|
|
66
|
+
"It calculates the average of precision values computed after each "
|
|
67
|
+
"relevant item present in the ranked list"
|
|
68
|
+
)
|
|
69
|
+
relevance_type = "binary indicators (0 or 1) of relevance"
|
|
70
|
+
score_range_interpretation = (
|
|
71
|
+
"Scores range from 0 to 1, with higher values indicating that relevant "
|
|
72
|
+
"items are generally positioned higher in the ranking"
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
formula = """
|
|
76
|
+
The formula for average precision is defined below. MAP is the mean over average
|
|
77
|
+
precision computed for each list.
|
|
78
|
+
|
|
79
|
+
```
|
|
80
|
+
AP(y, s) = sum_j (P@j(y, s) * rel(j)) / sum_i y_i
|
|
81
|
+
rel(j) = y_i if rank(s_i) = j
|
|
82
|
+
```
|
|
83
|
+
|
|
84
|
+
where:
|
|
85
|
+
- `j` represents the rank position (starting from 1).
|
|
86
|
+
- `sum_j` indicates a summation over all ranks `j` from 1 up to the list
|
|
87
|
+
size (or `k`).
|
|
88
|
+
- `P@j(y, s)` denotes the Precision at rank `j`, calculated as the
|
|
89
|
+
number of relevant items found within the top `j` positions divided by `j`.
|
|
90
|
+
- `rel(j)` represents the relevance of the item specifically at rank
|
|
91
|
+
`j`. `rel(j)` is 1 if the item at rank `j` is relevant, and 0 otherwise.
|
|
92
|
+
- `y_i` is the true relevance label of the original item `i` before ranking.
|
|
93
|
+
- `rank(s_i)` is the rank position assigned to item `i` based on its score
|
|
94
|
+
`s_i`.
|
|
95
|
+
- `sum_i y_i` calculates the total number of relevant items in the original
|
|
96
|
+
list `y`."""
|
|
97
|
+
extra_args = ""
|
|
98
|
+
example = """
|
|
99
|
+
>>> batch_size = 2
|
|
100
|
+
>>> list_size = 5
|
|
101
|
+
>>> labels = np.random.randint(0, 2, size=(batch_size, list_size))
|
|
102
|
+
>>> scores = np.random.random(size=(batch_size, list_size))
|
|
103
|
+
>>> metric = keras_rs.metrics.MeanAveragePrecision()(
|
|
104
|
+
... y_true=labels, y_pred=scores
|
|
105
|
+
... )
|
|
106
|
+
|
|
107
|
+
Mask certain elements (can be used for uneven inputs):
|
|
108
|
+
|
|
109
|
+
>>> batch_size = 2
|
|
110
|
+
>>> list_size = 5
|
|
111
|
+
>>> labels = np.random.randint(0, 2, size=(batch_size, list_size))
|
|
112
|
+
>>> scores = np.random.random(size=(batch_size, list_size))
|
|
113
|
+
>>> mask = np.random.randint(0, 2, size=(batch_size, list_size), dtype=bool)
|
|
114
|
+
>>> metric = keras_rs.metrics.MeanAveragePrecision()(
|
|
115
|
+
... y_true={"labels": labels, "mask": mask}, y_pred=scores
|
|
116
|
+
... )
|
|
117
|
+
"""
|
|
118
|
+
|
|
119
|
+
MeanAveragePrecision.__doc__ = format_docstring(
|
|
120
|
+
ranking_metric_subclass_doc_string,
|
|
121
|
+
width=80,
|
|
122
|
+
metric_name="Mean Average Precision",
|
|
123
|
+
metric_abbreviation="MAP",
|
|
124
|
+
concept_sentence=concept_sentence,
|
|
125
|
+
relevance_type=relevance_type,
|
|
126
|
+
score_range_interpretation=score_range_interpretation,
|
|
127
|
+
formula=formula,
|
|
128
|
+
) + ranking_metric_subclass_doc_string_post_desc.format(
|
|
129
|
+
extra_args=extra_args, example=example
|
|
130
|
+
)
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
from keras import ops
|
|
2
|
+
|
|
3
|
+
from keras_rs.src import types
|
|
4
|
+
from keras_rs.src.api_export import keras_rs_export
|
|
5
|
+
from keras_rs.src.metrics.ranking_metric import RankingMetric
|
|
6
|
+
from keras_rs.src.metrics.ranking_metric import (
|
|
7
|
+
ranking_metric_subclass_doc_string,
|
|
8
|
+
)
|
|
9
|
+
from keras_rs.src.metrics.ranking_metric import (
|
|
10
|
+
ranking_metric_subclass_doc_string_post_desc,
|
|
11
|
+
)
|
|
12
|
+
from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
|
|
13
|
+
from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
|
|
14
|
+
from keras_rs.src.utils.doc_string_utils import format_docstring
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@keras_rs_export("keras_rs.metrics.MeanReciprocalRank")
class MeanReciprocalRank(RankingMetric):
    # NOTE: the user-facing class docstring is built from a shared template
    # and assigned to `MeanReciprocalRank.__doc__` further down in this
    # module, so no literal class docstring is written here.

    def compute_metric(
        self,
        y_true: types.Tensor,
        y_pred: types.Tensor,
        mask: types.Tensor,
        sample_weight: types.Tensor,
    ) -> types.Tensor:
        """Compute the per-list reciprocal rank and each list's weight.

        Running example used in the comments below:
        `y_true = [0, 0, 1]`, `y_pred = [0.1, 0.9, 0.2]`.

        Args:
            y_true: True relevance labels; any value `>= 1` counts as
                relevant.
            y_pred: Predicted scores used to order the items.
            mask: Mask handed to `sort_by_scores`.
            sample_weight: Per-item weights.

        Returns:
            A `(mrr, per_list_weights)` pair; `mrr` keeps a trailing axis of
            size 1 per list.
        """
        # Labels ordered by descending score: `[0, 1, 0]` in the example.
        (ranked_labels,) = sort_by_scores(
            tensors_to_sort=[y_true],
            scores=y_pred,
            mask=mask,
            k=self.k,
            shuffle_ties=self.shuffle_ties,
            seed=self.seed_generator,
        )

        # Width of the ranked list. It depends on `k`, so it can be smaller
        # than the original list length.
        num_ranked = ops.shape(ranked_labels)[1]

        # Binary relevance (labels above 1 are treated like 1):
        # `[0., 1., 0.]` in the example.
        ranked_threshold = ops.cast(1, dtype=ranked_labels.dtype)
        is_relevant = ops.cast(
            ops.greater_equal(ranked_labels, ranked_threshold),
            dtype=y_pred.dtype,
        )

        # `1 / rank` for ranks 1..num_ranked: `[1, 0.5, 0.33]`.
        ranks = ops.arange(1, num_ranked + 1, dtype=y_pred.dtype)
        inverse_ranks = ops.divide(ops.cast(1, dtype=y_pred.dtype), ranks)

        # The best-ranked relevant item has the largest reciprocal rank:
        # `amax([0., 0.5, 0.]) = 0.5`. Shape is `(batch_size, 1)`.
        mrr = ops.amax(
            ops.multiply(is_relevant, inverse_ranks),
            axis=1,
            keepdims=True,
        )

        # List weights are derived from the relevance of the full, unsorted
        # label tensor.
        label_threshold = ops.cast(1, dtype=y_true.dtype)
        list_relevance = ops.cast(
            ops.greater_equal(y_true, label_threshold),
            dtype=y_pred.dtype,
        )
        per_list_weights = get_list_weights(
            weights=sample_weight, relevance=list_relevance
        )

        return mrr, per_list_weights
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
concept_sentence = (
|
|
77
|
+
"It focuses on the rank position of the single highest-scoring relevant "
|
|
78
|
+
"item"
|
|
79
|
+
)
|
|
80
|
+
relevance_type = "binary indicators (0 or 1) of relevance"
|
|
81
|
+
score_range_interpretation = (
|
|
82
|
+
"Scores range from 0 to 1, with 1 indicating the first relevant item was "
|
|
83
|
+
"always ranked first"
|
|
84
|
+
)
|
|
85
|
+
formula = """```
|
|
86
|
+
MRR(y, s) = max_{i} y_{i} / rank(s_{i})
|
|
87
|
+
```"""
|
|
88
|
+
extra_args = ""
|
|
89
|
+
example = """
|
|
90
|
+
>>> batch_size = 2
|
|
91
|
+
>>> list_size = 5
|
|
92
|
+
>>> labels = np.random.randint(0, 2, size=(batch_size, list_size))
|
|
93
|
+
>>> scores = np.random.random(size=(batch_size, list_size))
|
|
94
|
+
>>> metric = keras_rs.metrics.MeanReciprocalRank()(
|
|
95
|
+
... y_true=labels, y_pred=scores
|
|
96
|
+
... )
|
|
97
|
+
|
|
98
|
+
Mask certain elements (can be used for uneven inputs):
|
|
99
|
+
|
|
100
|
+
>>> batch_size = 2
|
|
101
|
+
>>> list_size = 5
|
|
102
|
+
>>> labels = np.random.randint(0, 2, size=(batch_size, list_size))
|
|
103
|
+
>>> scores = np.random.random(size=(batch_size, list_size))
|
|
104
|
+
>>> mask = np.random.randint(0, 2, size=(batch_size, list_size), dtype=bool)
|
|
105
|
+
>>> metric = keras_rs.metrics.MeanReciprocalRank()(
|
|
106
|
+
... y_true={"labels": labels, "mask": mask}, y_pred=scores
|
|
107
|
+
... )
|
|
108
|
+
"""
|
|
109
|
+
|
|
110
|
+
MeanReciprocalRank.__doc__ = format_docstring(
|
|
111
|
+
ranking_metric_subclass_doc_string,
|
|
112
|
+
width=80,
|
|
113
|
+
metric_name="Mean Reciprocal Rank",
|
|
114
|
+
metric_abbreviation="MRR",
|
|
115
|
+
concept_sentence=concept_sentence,
|
|
116
|
+
relevance_type=relevance_type,
|
|
117
|
+
score_range_interpretation=score_range_interpretation,
|
|
118
|
+
formula=formula,
|
|
119
|
+
) + ranking_metric_subclass_doc_string_post_desc.format(
|
|
120
|
+
extra_args=extra_args, example=example
|
|
121
|
+
)
|