keras-rs-nightly 0.0.1.dev2025021903-py3-none-any.whl → 0.3.1.dev202512130338-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- keras_rs/__init__.py +9 -28
- keras_rs/layers/__init__.py +37 -0
- keras_rs/losses/__init__.py +19 -0
- keras_rs/metrics/__init__.py +16 -0
- keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
- keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
- keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
- keras_rs/src/layers/embedding/embed_reduce.py +309 -0
- keras_rs/src/layers/embedding/jax/__init__.py +0 -0
- keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
- keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
- keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
- keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
- keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
- keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
- keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
- keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
- keras_rs/src/layers/feature_interaction/__init__.py +0 -0
- keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
- keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
- keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
- keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
- keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
- keras_rs/src/layers/retrieval/retrieval.py +127 -0
- keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
- keras_rs/src/losses/__init__.py +0 -0
- keras_rs/src/losses/list_mle_loss.py +212 -0
- keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
- keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
- keras_rs/src/losses/pairwise_loss.py +165 -0
- keras_rs/src/losses/pairwise_loss_utils.py +39 -0
- keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
- keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
- keras_rs/src/metrics/__init__.py +0 -0
- keras_rs/src/metrics/dcg.py +161 -0
- keras_rs/src/metrics/mean_average_precision.py +130 -0
- keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
- keras_rs/src/metrics/ndcg.py +197 -0
- keras_rs/src/metrics/precision_at_k.py +117 -0
- keras_rs/src/metrics/ranking_metric.py +260 -0
- keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
- keras_rs/src/metrics/recall_at_k.py +108 -0
- keras_rs/src/metrics/utils.py +70 -0
- keras_rs/src/types.py +43 -14
- keras_rs/src/utils/doc_string_utils.py +53 -0
- keras_rs/src/utils/keras_utils.py +52 -3
- keras_rs/src/utils/tpu_test_utils.py +120 -0
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
- keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
- keras_rs/api/__init__.py +0 -9
- keras_rs/api/layers/__init__.py +0 -11
- keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
- keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
keras_rs/src/metrics/ndcg.py
@@ -0,0 +1,197 @@
+from typing import Any, Callable
+
+from keras import ops
+from keras.saving import deserialize_keras_object
+from keras.saving import serialize_keras_object
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+from keras_rs.src.metrics.ranking_metric import RankingMetric
+from keras_rs.src.metrics.ranking_metric import (
+    ranking_metric_subclass_doc_string,
+)
+from keras_rs.src.metrics.ranking_metric import (
+    ranking_metric_subclass_doc_string_post_desc,
+)
+from keras_rs.src.metrics.ranking_metrics_utils import compute_dcg
+from keras_rs.src.metrics.ranking_metrics_utils import default_gain_fn
+from keras_rs.src.metrics.ranking_metrics_utils import default_rank_discount_fn
+from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
+from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
+from keras_rs.src.utils.doc_string_utils import format_docstring
+
+
+@keras_rs_export("keras_rs.metrics.NDCG")
+class NDCG(RankingMetric):
+    def __init__(
+        self,
+        k: int | None = None,
+        gain_fn: Callable[[types.Tensor], types.Tensor] = default_gain_fn,
+        rank_discount_fn: Callable[
+            [types.Tensor], types.Tensor
+        ] = default_rank_discount_fn,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(k=k, **kwargs)
+
+        self.gain_fn = gain_fn
+        self.rank_discount_fn = rank_discount_fn
+
+    def compute_metric(
+        self,
+        y_true: types.Tensor,
+        y_pred: types.Tensor,
+        mask: types.Tensor,
+        sample_weight: types.Tensor,
+    ) -> types.Tensor:
+        sorted_y_true, sorted_weights = sort_by_scores(
+            tensors_to_sort=[y_true, sample_weight],
+            scores=y_pred,
+            k=self.k,
+            mask=mask,
+            shuffle_ties=self.shuffle_ties,
+            seed=self.seed_generator,
+        )
+
+        dcg = compute_dcg(
+            y_true=sorted_y_true,
+            sample_weight=sorted_weights,
+            gain_fn=self.gain_fn,
+            rank_discount_fn=self.rank_discount_fn,
+        )
+
+        weighted_gains = ops.multiply(
+            sample_weight,
+            self.gain_fn(y_true),
+        )
+        ideal_sorted_y_true, ideal_sorted_weights = sort_by_scores(
+            tensors_to_sort=[y_true, sample_weight],
+            scores=weighted_gains,
+            k=self.k,
+            mask=mask,
+            shuffle_ties=self.shuffle_ties,
+            seed=self.seed_generator,
+        )
+        ideal_dcg = compute_dcg(
+            y_true=ideal_sorted_y_true,
+            sample_weight=ideal_sorted_weights,
+            gain_fn=self.gain_fn,
+            rank_discount_fn=self.rank_discount_fn,
+        )
+        per_list_ndcg = ops.divide_no_nan(dcg, ideal_dcg)
+
+        per_list_weights = get_list_weights(
+            weights=sample_weight, relevance=self.gain_fn(y_true)
+        )
+
+        return per_list_ndcg, per_list_weights
+
+    def get_config(self) -> dict[str, Any]:
+        config: dict[str, Any] = super().get_config()
+        config.update(
+            {
+                "gain_fn": serialize_keras_object(self.gain_fn),
+                "rank_discount_fn": serialize_keras_object(
+                    self.rank_discount_fn
+                ),
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "NDCG":
+        config["gain_fn"] = deserialize_keras_object(config["gain_fn"])
+        config["rank_discount_fn"] = deserialize_keras_object(
+            config["rank_discount_fn"]
+        )
+        return cls(**config)
+
+
+concept_sentence = (
+    "It normalizes the Discounted Cumulative Gain (DCG) with the Ideal "
+    "Discounted Cumulative Gain (IDCG) for each list"
+)
+relevance_type = (
+    "graded relevance scores (non-negative numbers where higher values "
+    "indicate greater relevance)"
+)
+score_range_interpretation = (
+    "A normalized score (between 0 and 1) is returned. A score of 1 "
+    "represents the perfect ranking according to true relevance (within the "
+    "top-k), while 0 typically represents a ranking with no relevant items. "
+    "Higher scores indicate better ranking quality relative to the best "
+    "possible ranking"
+)
+
+formula = """
+```
+nDCG@k = DCG@k / IDCG@k
+```
+
+where DCG@k is calculated based on the predicted ranking (`y_pred`):
+
+```
+DCG@k(y') = sum_{i=1}^{k} (gain_fn(y'_i) / rank_discount_fn(i))
+```
+
+And IDCG@k is the Ideal DCG, calculated using the same formula but on items
+sorted perfectly by their *true relevance* (`y_true`):
+
+```
+IDCG@k(y'') = sum_{i=1}^{k} (gain_fn(y''_i) / rank_discount_fn(i))
+```
+
+where:
+- `y'_i`: True relevance of the item at rank `i` in the ranking induced by
+  `y_pred`.
+- `y''_i`: True relevance of the item at rank `i` in the *ideal* ranking (sorted
+  by `y_true` descending).
+- `gain_fn` is the user-provided function mapping relevance to gain. The default
+  function (`default_gain_fn`) is typically equivalent to `lambda y: 2**y - 1`.
+- `rank_discount_fn` is the user-provided function mapping rank `i` (1-based) to
+  a discount value. The default function (`default_rank_discount_fn`) is
+  typically equivalent to `lambda rank: 1 / log2(rank + 1)`.
+- If IDCG@k is 0 (e.g., no relevant items), nDCG@k is defined as 0.
+- The final result often aggregates these per-list nDCG scores, potentially
+  involving normalization by list-specific weights, to produce a weighted
+  average."""
+extra_args = """
+    gain_fn: callable. Maps relevance scores (`y_true`) to gain values. The
+        default implements `2**y - 1`.
+    rank_discount_fn: callable. Maps rank positions to discount
+        values. The default (`default_rank_discount_fn`) implements
+        `1 / log2(rank + 1)`."""
+example = """
+    >>> batch_size = 2
+    >>> list_size = 5
+    >>> labels = np.random.randint(0, 3, size=(batch_size, list_size))
+    >>> scores = np.random.random(size=(batch_size, list_size))
+    >>> metric = keras_rs.metrics.NDCG()(
+    ...     y_true=labels, y_pred=scores
+    ... )
+
+    Mask certain elements (can be used for uneven inputs):
+
+    >>> batch_size = 2
+    >>> list_size = 5
+    >>> labels = np.random.randint(0, 3, size=(batch_size, list_size))
+    >>> scores = np.random.random(size=(batch_size, list_size))
+    >>> mask = np.random.randint(0, 2, size=(batch_size, list_size), dtype=bool)
+    >>> metric = keras_rs.metrics.NDCG()(
+    ...     y_true={"labels": labels, "mask": mask}, y_pred=scores
+    ... )
+"""
+
+NDCG.__doc__ = format_docstring(
+    ranking_metric_subclass_doc_string,
+    width=80,
+    metric_name="Normalized Discounted Cumulative Gain",
+    metric_abbreviation="nDCG",
+    concept_sentence=concept_sentence,
+    relevance_type=relevance_type,
+    score_range_interpretation=score_range_interpretation,
+    formula=formula,
+    extra_args=extra_args,
+) + ranking_metric_subclass_doc_string_post_desc.format(
+    extra_args=extra_args, example=example
+)
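For intuition, the formula in the docstring above can be checked by hand. Below is a minimal NumPy sketch of nDCG under the documented defaults (`gain = 2**y - 1`, `discount = 1 / log2(rank + 1)`). It is an illustration only, not the keras_rs implementation: it handles a single unmasked, unweighted list and does no tie shuffling.

```python
import numpy as np


def ndcg_at_k(y_true, y_pred, k=None):
    """Illustrative nDCG for one list, using the default gain/discount."""
    y_true = np.asarray(y_true, dtype=float)
    order = np.argsort(-np.asarray(y_pred))  # item indices by descending score
    ranked = y_true[order][:k]               # labels in predicted order
    ideal = np.sort(y_true)[::-1][:k]        # labels in the ideal order

    def dcg(labels):
        ranks = np.arange(1, len(labels) + 1)
        return np.sum((2.0**labels - 1.0) / np.log2(ranks + 1.0))

    idcg = dcg(ideal)
    return dcg(ranked) / idcg if idcg > 0 else 0.0  # nDCG := 0 when IDCG is 0


print(ndcg_at_k([0, 0, 1], [0.5, 0.1, 0.9]))  # 1.0: relevant item ranked first
print(ndcg_at_k([1, 0, 0], [0.5, 0.1, 0.9]))  # ~0.631 = 1/log2(3): rank 2
```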
keras_rs/src/metrics/precision_at_k.py
@@ -0,0 +1,117 @@
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+from keras_rs.src.metrics.ranking_metric import RankingMetric
+from keras_rs.src.metrics.ranking_metric import (
+    ranking_metric_subclass_doc_string,
+)
+from keras_rs.src.metrics.ranking_metric import (
+    ranking_metric_subclass_doc_string_post_desc,
+)
+from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
+from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
+from keras_rs.src.utils.doc_string_utils import format_docstring
+
+
+@keras_rs_export("keras_rs.metrics.PrecisionAtK")
+class PrecisionAtK(RankingMetric):
+    def compute_metric(
+        self,
+        y_true: types.Tensor,
+        y_pred: types.Tensor,
+        mask: types.Tensor,
+        sample_weight: types.Tensor,
+    ) -> types.Tensor:
+        # Assume: `y_true = [0, 0, 1]`, `y_pred = [0.1, 0.9, 0.2]`.
+        # `sorted_y_true = [0, 1, 0]` (labels ordered by descending score).
+        (sorted_y_true,) = sort_by_scores(
+            tensors_to_sort=[y_true],
+            scores=y_pred,
+            mask=mask,
+            k=self.k,
+            shuffle_ties=self.shuffle_ties,
+            seed=self.seed_generator,
+        )
+
+        # We consider only binary relevance here; anything above 1 is treated
+        # as 1. `relevance = [0., 1., 0.]`.
+        relevance = ops.cast(
+            ops.greater_equal(
+                sorted_y_true, ops.cast(1, dtype=sorted_y_true.dtype)
+            ),
+            dtype=y_pred.dtype,
+        )
+        list_length = ops.shape(sorted_y_true)[1]
+        # TODO: We do not do this for MRR and the other metrics. Do we need to
+        # do this there too?
+        valid_list_length = ops.minimum(
+            list_length,
+            ops.sum(ops.cast(mask, dtype="int32"), axis=1, keepdims=True),
+        )
+
+        per_list_precision = ops.divide_no_nan(
+            ops.sum(relevance, axis=1, keepdims=True),
+            ops.cast(valid_list_length, dtype=y_pred.dtype),
+        )
+
+        # Get weights.
+        overall_relevance = ops.cast(
+            ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
+            dtype=y_pred.dtype,
+        )
+        per_list_weights = get_list_weights(
+            weights=sample_weight, relevance=overall_relevance
+        )
+
+        return per_list_precision, per_list_weights
+
+
+concept_sentence = (
+    "It measures the proportion of relevant items among the top-k "
+    "recommendations"
+)
+relevance_type = "binary indicators (0 or 1) of relevance"
+score_range_interpretation = (
+    "Scores range from 0 to 1, with 1 indicating all top-k items were relevant"
+)
+formula = """```
+P@k(y, s) = 1/k sum_i I[rank(s_i) < k] y_i
+```
+
+where `y_i` is the relevance label (0/1) of the item ranked at position
+`i`, and `I[condition]` is 1 if the condition is met, otherwise 0."""
+extra_args = ""
+example = """
+    >>> batch_size = 2
+    >>> list_size = 5
+    >>> labels = np.random.randint(0, 2, size=(batch_size, list_size))
+    >>> scores = np.random.random(size=(batch_size, list_size))
+    >>> metric = keras_rs.metrics.PrecisionAtK()(
+    ...     y_true=labels, y_pred=scores
+    ... )
+
+    Mask certain elements (can be used for uneven inputs):
+
+    >>> batch_size = 2
+    >>> list_size = 5
+    >>> labels = np.random.randint(0, 2, size=(batch_size, list_size))
+    >>> scores = np.random.random(size=(batch_size, list_size))
+    >>> mask = np.random.randint(0, 2, size=(batch_size, list_size), dtype=bool)
+    >>> metric = keras_rs.metrics.PrecisionAtK()(
+    ...     y_true={"labels": labels, "mask": mask}, y_pred=scores
+    ... )
+"""
+
+PrecisionAtK.__doc__ = format_docstring(
+    ranking_metric_subclass_doc_string,
+    width=80,
+    metric_name="Precision@k",
+    metric_abbreviation="P@k",
+    concept_sentence=concept_sentence,
+    relevance_type=relevance_type,
+    score_range_interpretation=score_range_interpretation,
+    formula=formula,
+) + ranking_metric_subclass_doc_string_post_desc.format(
+    extra_args=extra_args, example=example
+)
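The per-list computation above mirrors the P@k formula in the docstring: count the relevant items among the top-k, then divide. A minimal NumPy sketch for a single list, illustration only; unlike the class, it uses a fixed denominator `k` rather than clipping by the number of unmasked items, and it skips the per-list weighting.

```python
import numpy as np


def precision_at_k(y_true, y_pred, k):
    """Illustrative P@k for a single list of labels and scores."""
    order = np.argsort(-np.asarray(y_pred))             # indices by descending score
    top_k = np.asarray(y_true, dtype=float)[order][:k]  # labels of the top-k items
    return float((top_k >= 1).sum()) / k                # relevant in top-k, over k


# Top-3 predictions pick items 0, 1, 2; items 0 and 2 are relevant.
print(precision_at_k([1, 0, 1, 0, 0], [0.9, 0.8, 0.7, 0.2, 0.1], k=3))  # ~0.667
```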
keras_rs/src/metrics/ranking_metric.py
@@ -0,0 +1,260 @@
+import abc
+from typing import Any
+
+import keras
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.metrics.utils import standardize_call_inputs_ranks
+from keras_rs.src.utils.keras_utils import check_rank
+from keras_rs.src.utils.keras_utils import check_shapes_compatible
+
+
+class RankingMetric(keras.metrics.Mean, abc.ABC):
+    """Base class for ranking evaluation metrics (e.g., MAP, MRR, DCG, nDCG).
+
+    Ranking metrics are used to evaluate the quality of ranked lists produced
+    by a ranking model. The primary goal in ranking tasks is to order items
+    according to their relevance or utility for a given query or context.
+    These metrics provide quantitative measures of how well a model achieves
+    this goal, typically by comparing the predicted order of items against the
+    ground truth relevance judgments for each list.
+
+    To define your own ranking metric, subclass this class and implement the
+    `compute_metric` method.
+
+    Args:
+        k: int. The number of top-ranked items to consider (the 'k' in 'top-k').
+            Must be a positive integer.
+        shuffle_ties: bool. Whether to randomly shuffle scores before sorting.
+            This is done to break ties. Defaults to `True`.
+        seed: int. Random seed used for shuffling.
+        name: Optional name for the metric instance.
+        dtype: The dtype of the metric's computations. Defaults to `None`, which
+            means using `keras.backend.floatx()`. `keras.backend.floatx()` is
+            `"float32"` unless set to a different value
+            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
+            provided, then the `compute_dtype` will be utilized.
+    """
+
+    def __init__(
+        self,
+        k: int | None = None,
+        shuffle_ties: bool = True,
+        seed: int | keras.random.SeedGenerator | None = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        if k is not None and (not isinstance(k, int) or k < 1):
+            raise ValueError(
+                f"`k` should be a positive integer. Received: `k` = {k}."
+            )
+
+        self.k = k
+        self.shuffle_ties = shuffle_ties
+        self.seed = seed
+
+        # Define `SeedGenerator`. JAX doesn't work otherwise.
+        self.seed_generator = keras.random.SeedGenerator(seed)
+
+    @abc.abstractmethod
+    def compute_metric(
+        self,
+        y_true: types.Tensor,
+        y_pred: types.Tensor,
+        mask: types.Tensor,
+        sample_weight: types.Tensor,
+    ) -> types.Tensor:
+        """Abstract method, should be implemented by subclasses."""
+        pass
+
+    def update_state(
+        self,
+        y_true: types.Tensor,
+        y_pred: types.Tensor,
+        sample_weight: types.Tensor | None = None,
+    ) -> None:
+        """
+        Accumulates statistics for the ranking metric.
+
+        Args:
+            y_true: tensor or dict. Ground truth values. If tensor, of shape
+                `(list_size)` for unbatched inputs or `(batch_size, list_size)`
+                for batched inputs. If an item has a label of -1, it is ignored
+                in metric computation. If it is a dictionary, it should have
+                two keys: `"labels"` and `"mask"`. `"mask"` can be used to
+                ignore elements in metric computation, i.e., masked items are
+                excluded from the ranking. Note that the final mask is an `and`
+                of the passed mask, `labels >= 0`, and `sample_weight > 0`.
+            y_pred: tensor. The predicted values, of shape `(list_size)` for
+                unbatched inputs or `(batch_size, list_size)` for batched
+                inputs. Should be of the same shape as `y_true`.
+            sample_weight: float/tensor. Can be a float value, or a tensor of
+                shape `(list_size)` or `(batch_size, list_size)`. Defaults to
+                `None`.
+        """
+        # === Process `y_true`, if dict ===
+        passed_mask = None
+        if isinstance(y_true, dict):
+            if "labels" not in y_true:
+                raise ValueError(
+                    '`"labels"` should be present in `y_true`. Received: '
+                    f"`y_true` = {y_true}"
+                )
+
+            passed_mask = y_true.get("mask", None)
+            y_true = y_true["labels"]
+
+        # === Convert to tensors, if list ===
+        # TODO (abheesht): Figure out if we need to cast tensors to
+        # `self.dtype`.
+        y_true = ops.convert_to_tensor(y_true)
+        y_pred = ops.convert_to_tensor(y_pred)
+        if sample_weight is not None:
+            sample_weight = ops.convert_to_tensor(sample_weight)
+        if passed_mask is not None:
+            passed_mask = ops.convert_to_tensor(passed_mask)
+
+        # Cast to the correct dtype.
+        y_true = ops.cast(y_true, dtype=self.dtype)
+        y_pred = ops.cast(y_pred, dtype=self.dtype)
+        if sample_weight is not None:
+            sample_weight = ops.cast(sample_weight, dtype=self.dtype)
+
+        # === Process `sample_weight` ===
+        if sample_weight is None:
+            sample_weight = ops.cast(1, dtype=y_pred.dtype)
+
+        y_true_shape = ops.shape(y_true)
+        y_true_rank = len(y_true_shape)
+        sample_weight_shape = ops.shape(sample_weight)
+        sample_weight_rank = len(sample_weight_shape)
+
+        # Check `y_true_rank` first. Can be 1 for unbatched inputs, 2 for
+        # batched.
+        check_rank(y_true_rank, allowed_ranks=(1, 2), tensor_name="y_true")
+
+        # Check `sample_weight` rank. Should be between 0 and `y_true_rank`.
+        check_rank(
+            sample_weight_rank,
+            allowed_ranks=tuple(range(y_true_rank + 1)),
+            tensor_name="sample_weight",
+        )
+
+        if y_true_rank == 2:
+            # If `sample_weight` rank is 1, it should be of shape
+            # `(batch_size,)`. Otherwise, it should be of shape
+            # `(batch_size, list_size)`.
+            if sample_weight_rank == 1:
+                check_shapes_compatible(sample_weight_shape, (y_true_shape[0],))
+                # Uprank this, so that we get per-list weights here.
+                sample_weight = ops.expand_dims(sample_weight, axis=1)
+            elif sample_weight_rank == 2:
+                check_shapes_compatible(sample_weight_shape, y_true_shape)
+
+        # Reshape `sample_weight` to the shape of `y_true`.
+        sample_weight = ops.multiply(ops.ones_like(y_true), sample_weight)
+
+        # Mask all values less than 0 (since less than 0 implies invalid
+        # labels).
+        valid_mask = ops.greater_equal(y_true, ops.cast(0, y_true.dtype))
+        if passed_mask is not None:
+            valid_mask = ops.logical_and(valid_mask, passed_mask)
+
+        # === Process inputs - shape checking, upranking, etc. ===
+        y_true, y_pred, valid_mask, batched = standardize_call_inputs_ranks(
+            y_true=y_true,
+            y_pred=y_pred,
+            mask=valid_mask,
+            check_y_true_rank=False,
+        )
+
+        # Uprank `sample_weight` if unbatched.
+        if not batched:
+            sample_weight = ops.expand_dims(sample_weight, axis=0)
+
+        # Get `mask` from `sample_weight`.
+        sample_weight_mask = ops.greater(
+            sample_weight, ops.cast(0, dtype=sample_weight.dtype)
+        )
+        mask = ops.logical_and(valid_mask, sample_weight_mask)
+
+        # === Update "invalid" `y_true`, `y_pred` entries based on mask ===
+
+        # `y_true`: assign 0 for invalid entries.
+        y_true = ops.where(mask, y_true, ops.zeros_like(y_true))
+        # `y_pred`: assign a value slightly smaller than the smallest value
+        # so that it gets sorted last.
+        y_pred = ops.where(
+            mask,
+            y_pred,
+            -keras.config.epsilon() * ops.ones_like(y_pred)
+            + ops.amin(y_pred, axis=1, keepdims=True),
+        )
+        sample_weight = ops.where(
+            mask, sample_weight, ops.cast(0, sample_weight.dtype)
+        )
+
+        # === Actual computation ===
+        per_list_metric_values, per_list_metric_weights = self.compute_metric(
+            y_true=y_true, y_pred=y_pred, mask=mask, sample_weight=sample_weight
+        )
+
+        # Chain to `super()` to get the mean metric.
+        # TODO(abheesht): Figure out if we want to return unaggregated metric
+        # values too of shape `(batch_size,)` from `result()`.
+        super().update_state(
+            per_list_metric_values, sample_weight=per_list_metric_weights
+        )
+
+    def get_config(self) -> dict[str, Any]:
+        config: dict[str, Any] = super().get_config()
+        config.update(
+            {"k": self.k, "shuffle_ties": self.shuffle_ties, "seed": self.seed}
+        )
+        return config
+
+
+ranking_metric_subclass_doc_string = """
+Computes {metric_name} ({metric_abbreviation}).
+
+This metric evaluates ranking quality. {concept_sentence}. The metric processes
+true relevance labels in `y_true` ({relevance_type}) against predicted scores in
+`y_pred`. The scores in `y_pred` are used to determine the rank order of items,
+by sorting in descending order. {score_range_interpretation}.
+
+For each list of predicted scores `s` in `y_pred` and the corresponding list
+of true labels `y` in `y_true`, the per-query {metric_abbreviation} score is
+calculated as follows:
+{formula}
+
+The final {metric_abbreviation} score reported is typically the weighted
+average of these per-query scores across all queries/lists in the dataset.
+
+Note: `sample_weight` is handled differently for ranking metrics. For
+batched inputs, `sample_weight` can be scalar, 1D, or 2D. The scalar case and
+the 1D case (list-wise weights) are straightforward. The 2D case (item-wise
+weights) is different, in the sense that the sample weights are aggregated
+to get 1D weights. For more details, refer to
+`keras_rs.src.metrics.ranking_metrics_utils.get_list_weights`.
+"""
+
+ranking_metric_subclass_doc_string_post_desc = """
+
+Args:{extra_args}
+    k: int. The number of top-ranked items to consider (the 'k' in 'top-k').
+        Must be a positive integer.
+    shuffle_ties: bool. Whether to randomly shuffle scores before sorting.
+        This is done to break ties. Defaults to `True`.
+    seed: int. Random seed used for shuffling.
+    name: Optional name for the metric instance.
+    dtype: The dtype of the metric's computations. Defaults to `None`, which
+        means using `keras.backend.floatx()`. `keras.backend.floatx()` is
+        `"float32"` unless set to a different value
+        (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
+        provided, then the `compute_dtype` will be utilized.
+
+Example:
+{example}
+"""
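Since `update_state` above already standardizes shapes, builds the combined mask, and feeds per-list values into the running mean, a subclass only needs to implement `compute_metric`. The sketch below defines a hypothetical `HitAtK` metric (not part of the package), assuming the same helper interface (`sort_by_scores`, `self.k`, `self.shuffle_ties`, `self.seed_generator`) used by `NDCG` and `PrecisionAtK`.

```python
from keras import ops

from keras_rs.src.metrics.ranking_metric import RankingMetric
from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores


class HitAtK(RankingMetric):
    """Toy metric: 1.0 if any relevant item (label >= 1) is in the top-k."""

    def compute_metric(self, y_true, y_pred, mask, sample_weight):
        # Order labels by descending predicted score, truncated to top-k.
        (sorted_y_true,) = sort_by_scores(
            tensors_to_sort=[y_true],
            scores=y_pred,
            mask=mask,
            k=self.k,
            shuffle_ties=self.shuffle_ties,
            seed=self.seed_generator,
        )
        # Per-list hit indicator of shape `(batch_size, 1)`.
        hits = ops.max(
            ops.cast(
                ops.greater_equal(
                    sorted_y_true, ops.cast(1, sorted_y_true.dtype)
                ),
                y_pred.dtype,
            ),
            axis=1,
            keepdims=True,
        )
        # Weight every list equally; the base class averages the values.
        return hits, ops.ones_like(hits)
```

An instance would then behave like the built-in metrics, e.g. `HitAtK(k=5)(y_true=labels, y_pred=scores)`, including the `{"labels": ..., "mask": ...}` form of `y_true`.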