keras-rs-nightly 0.0.1.dev2025021903-py3-none-any.whl → 0.3.1.dev202512130338-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_rs/__init__.py +9 -28
- keras_rs/layers/__init__.py +37 -0
- keras_rs/losses/__init__.py +19 -0
- keras_rs/metrics/__init__.py +16 -0
- keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
- keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
- keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
- keras_rs/src/layers/embedding/embed_reduce.py +309 -0
- keras_rs/src/layers/embedding/jax/__init__.py +0 -0
- keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
- keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
- keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
- keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
- keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
- keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
- keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
- keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
- keras_rs/src/layers/feature_interaction/__init__.py +0 -0
- keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
- keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
- keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
- keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
- keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
- keras_rs/src/layers/retrieval/retrieval.py +127 -0
- keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
- keras_rs/src/losses/__init__.py +0 -0
- keras_rs/src/losses/list_mle_loss.py +212 -0
- keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
- keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
- keras_rs/src/losses/pairwise_loss.py +165 -0
- keras_rs/src/losses/pairwise_loss_utils.py +39 -0
- keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
- keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
- keras_rs/src/metrics/__init__.py +0 -0
- keras_rs/src/metrics/dcg.py +161 -0
- keras_rs/src/metrics/mean_average_precision.py +130 -0
- keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
- keras_rs/src/metrics/ndcg.py +197 -0
- keras_rs/src/metrics/precision_at_k.py +117 -0
- keras_rs/src/metrics/ranking_metric.py +260 -0
- keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
- keras_rs/src/metrics/recall_at_k.py +108 -0
- keras_rs/src/metrics/utils.py +70 -0
- keras_rs/src/types.py +43 -14
- keras_rs/src/utils/doc_string_utils.py +53 -0
- keras_rs/src/utils/keras_utils.py +52 -3
- keras_rs/src/utils/tpu_test_utils.py +120 -0
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
- keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
- keras_rs/api/__init__.py +0 -9
- keras_rs/api/layers/__init__.py +0 -11
- keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
- /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
keras_rs/src/losses/list_mle_loss.py (new file, +212):

```diff
@@ -0,0 +1,212 @@
+from typing import Any
+
+import keras
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
+from keras_rs.src.metrics.utils import standardize_call_inputs_ranks
+
+
+@keras_rs_export("keras_rs.losses.ListMLELoss")
+class ListMLELoss(keras.losses.Loss):
+    """Implements ListMLE (Maximum Likelihood Estimation) loss for ranking.
+
+    ListMLE loss is a listwise ranking loss that maximizes the likelihood of
+    the ground truth ranking. It works by:
+    1. Sorting items by their relevance scores (labels)
+    2. Computing the probability of observing this ranking given the
+       predicted scores
+    3. Maximizing this likelihood (minimizing negative log-likelihood)
+
+    The loss is computed as the negative log-likelihood of the ground truth
+    ranking given the predicted scores:
+
+    ```
+    loss = -sum(log(exp(s_i) / sum(exp(s_j) for j >= i)))
+    ```
+
+    where s_i is the predicted score for item i in the sorted order.
+
+    Args:
+        temperature: Temperature parameter for scaling logits. Higher values
+            make the probability distribution more uniform. Defaults to 1.0.
+        reduction: Type of reduction to apply to the loss. In almost all cases
+            this should be `"sum_over_batch_size"`. Supported options are
+            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
+            `"mean_with_sample_weight"` or `None`. Defaults to
+            `"sum_over_batch_size"`.
+        name: Optional name for the loss instance.
+        dtype: The dtype of the loss's computations. Defaults to `None`.
+
+    Examples:
+    ```python
+    # Basic usage
+    loss_fn = ListMLELoss()
+
+    # With temperature scaling
+    loss_fn = ListMLELoss(temperature=0.5)
+
+    # Example with synthetic data
+    y_true = [[3, 2, 1, 0]]  # Relevance scores
+    y_pred = [[0.8, 0.6, 0.4, 0.2]]  # Predicted scores
+    loss = loss_fn(y_true, y_pred)
+    ```
+    """
+
+    def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+
+        if temperature <= 0.0:
+            raise ValueError(
+                f"`temperature` should be a positive float. Received: "
+                f"`temperature` = {temperature}."
+            )
+
+        self.temperature = temperature
+        self._epsilon = 1e-10
+
+    def compute_unreduced_loss(
+        self,
+        labels: types.Tensor,
+        logits: types.Tensor,
+        mask: types.Tensor | None = None,
+    ) -> tuple[types.Tensor, types.Tensor]:
+        """Compute the unreduced ListMLE loss.
+
+        Args:
+            labels: Ground truth relevance scores of
+                shape [batch_size, list_size].
+            logits: Predicted scores of shape [batch_size, list_size].
+            mask: Optional mask of shape [batch_size, list_size].
+
+        Returns:
+            Tuple of (losses, weights) where losses has shape [batch_size, 1]
+            and weights has the same shape.
+        """
+
+        valid_mask = ops.greater_equal(labels, ops.cast(0.0, labels.dtype))
+
+        if mask is not None:
+            valid_mask = ops.logical_and(
+                valid_mask, ops.cast(mask, dtype="bool")
+            )
+
+        num_valid_items = ops.sum(
+            ops.cast(valid_mask, dtype=labels.dtype), axis=1, keepdims=True
+        )
+
+        batch_has_valid_items = ops.greater(num_valid_items, 0.0)
+
+        labels_for_sorting = ops.where(
+            valid_mask, labels, ops.full_like(labels, -1e9)
+        )
+        logits_masked = ops.where(
+            valid_mask, logits, ops.full_like(logits, -1e9)
+        )
+
+        sorted_logits, sorted_valid_mask = sort_by_scores(
+            tensors_to_sort=[logits_masked, valid_mask],
+            scores=labels_for_sorting,
+            mask=None,
+            shuffle_ties=False,
+            seed=None,
+        )
+        sorted_logits = ops.divide(
+            sorted_logits, ops.cast(self.temperature, dtype=sorted_logits.dtype)
+        )
+
+        valid_logits_for_max = ops.where(
+            sorted_valid_mask, sorted_logits, ops.full_like(sorted_logits, -1e9)
+        )
+        raw_max = ops.max(valid_logits_for_max, axis=1, keepdims=True)
+        raw_max = ops.where(
+            batch_has_valid_items, raw_max, ops.zeros_like(raw_max)
+        )
+        sorted_logits = ops.subtract(sorted_logits, raw_max)
+
+        # Set invalid positions to very negative BEFORE exp.
+        sorted_logits = ops.where(
+            sorted_valid_mask, sorted_logits, ops.full_like(sorted_logits, -1e9)
+        )
+        exp_logits = ops.exp(sorted_logits)
+
+        reversed_exp = ops.flip(exp_logits, axis=1)
+        reversed_cumsum = ops.cumsum(reversed_exp, axis=1)
+        cumsum_from_right = ops.flip(reversed_cumsum, axis=1)
+
+        log_normalizers = ops.log(cumsum_from_right + self._epsilon)
+        log_probs = ops.subtract(sorted_logits, log_normalizers)
+
+        log_probs = ops.where(
+            sorted_valid_mask, log_probs, ops.zeros_like(log_probs)
+        )
+
+        negative_log_likelihood = ops.negative(
+            ops.sum(log_probs, axis=1, keepdims=True)
+        )
+
+        negative_log_likelihood = ops.where(
+            batch_has_valid_items,
+            negative_log_likelihood,
+            ops.zeros_like(negative_log_likelihood),
+        )
+
+        weights = ops.ones_like(negative_log_likelihood)
+
+        return negative_log_likelihood, weights
+
+    def call(
+        self,
+        y_true: types.Tensor,
+        y_pred: types.Tensor,
+    ) -> types.Tensor:
+        """Compute the ListMLE loss.
+
+        Args:
+            y_true: tensor or dict. Ground truth values. If tensor, of shape
+                `(list_size)` for unbatched inputs or `(batch_size, list_size)`
+                for batched inputs. If an item has a label of -1, it is ignored
+                in loss computation. If it is a dictionary, it should have two
+                keys: `"labels"` and `"mask"`. `"mask"` can be used to ignore
+                elements in loss computation.
+            y_pred: tensor. The predicted values, of shape `(list_size)` for
+                unbatched inputs or `(batch_size, list_size)` for batched
+                inputs. Should be of the same shape as `y_true`.
+
+        Returns:
+            The loss tensor of shape [batch_size].
+        """
+        mask = None
+        if isinstance(y_true, dict):
+            if "labels" not in y_true:
+                raise ValueError(
+                    '`"labels"` should be present in `y_true`. Received: '
+                    f"`y_true` = {y_true}"
+                )
+
+            mask = y_true.get("mask", None)
+            y_true = y_true["labels"]
+
+        y_true = ops.convert_to_tensor(y_true)
+        y_pred = ops.convert_to_tensor(y_pred)
+        if mask is not None:
+            mask = ops.convert_to_tensor(mask)
+
+        y_true, y_pred, mask, _ = standardize_call_inputs_ranks(
+            y_true, y_pred, mask
+        )
+
+        losses, weights = self.compute_unreduced_loss(
+            labels=y_true, logits=y_pred, mask=mask
+        )
+        losses = ops.multiply(losses, weights)
+        losses = ops.squeeze(losses, axis=-1)
+        return losses
+
+    # getting config
+    def get_config(self) -> dict[str, Any]:
+        config: dict[str, Any] = super().get_config()
+        config.update({"temperature": self.temperature})
+        return config
```
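For readers following the math: a minimal NumPy sketch (not part of the diff; the array names are illustrative) that reproduces the unreduced ListMLE value for the docstring's synthetic example. It sorts by label and builds the `sum(exp(s_j) for j >= i)` normalizers from the right, the same `flip`/`cumsum` trick the layer uses:

```python
import numpy as np

y_true = np.array([3.0, 2.0, 1.0, 0.0])  # relevance labels
y_pred = np.array([0.8, 0.6, 0.4, 0.2])  # predicted scores

order = np.argsort(-y_true)  # sort items by descending label
s = y_pred[order]

# loss = -sum_i log(exp(s_i) / sum_{j >= i} exp(s_j))
log_normalizers = np.log(np.cumsum(np.exp(s)[::-1])[::-1])
loss = -np.sum(s - log_normalizers)
print(loss)  # ~2.6213, the negative log-likelihood for this single list
```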
keras_rs/src/losses/pairwise_hinge_loss.py (new file, +90):

```diff
@@ -0,0 +1,90 @@
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+from keras_rs.src.losses.pairwise_loss import PairwiseLoss
+from keras_rs.src.losses.pairwise_loss import pairwise_loss_subclass_doc_string
+
+
+@keras_rs_export("keras_rs.losses.PairwiseHingeLoss")
+class PairwiseHingeLoss(PairwiseLoss):
+    def pairwise_loss(self, pairwise_logits: types.Tensor) -> types.Tensor:
+        return ops.relu(ops.subtract(ops.array(1), pairwise_logits))
+
+
+formula = "loss = sum_{i} sum_{j} I(y_i > y_j) * max(0, 1 - (s_i - s_j))"
+explanation = """
+    - `max(0, 1 - (s_i - s_j))` is the hinge loss, which penalizes cases where
+      the score difference `s_i - s_j` is not sufficiently large when
+      `y_i > y_j`.
+"""
+extra_args = ""
+example = """
+    With `compile()` API:
+
+    ```python
+    model.compile(
+        loss=keras_rs.losses.PairwiseHingeLoss(),
+        ...
+    )
+    ```
+
+    As a standalone function with unbatched inputs:
+
+    >>> y_true = np.array([1.0, 0.0, 1.0, 3.0, 2.0])
+    >>> y_pred = np.array([1.0, 3.0, 2.0, 4.0, 0.8])
+    >>> pairwise_hinge_loss = keras_rs.losses.PairwiseHingeLoss()
+    >>> pairwise_hinge_loss(y_true=y_true, y_pred=y_pred)
+    2.32000
+
+    With batched inputs using default 'auto'/'sum_over_batch_size' reduction:
+
+    >>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
+    >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
+    >>> pairwise_hinge_loss = keras_rs.losses.PairwiseHingeLoss()
+    >>> pairwise_hinge_loss(y_true=y_true, y_pred=y_pred)
+    0.75
+
+    With masked inputs (useful for ragged inputs):
+
+    >>> y_true = {
+    ...     "labels": np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]]),
+    ...     "mask": np.array(
+    ...         [[True, True, True, True], [True, True, False, False]]
+    ...     ),
+    ... }
+    >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
+    >>> pairwise_hinge_loss(y_true=y_true, y_pred=y_pred)
+    0.64999
+
+    With `sample_weight`:
+
+    >>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
+    >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
+    >>> sample_weight = np.array(
+    ...     [[2.0, 3.0, 1.0, 1.0], [2.0, 1.0, 0.0, 0.0]]
+    ... )
+    >>> pairwise_hinge_loss = keras_rs.losses.PairwiseHingeLoss()
+    >>> pairwise_hinge_loss(
+    ...     y_true=y_true, y_pred=y_pred, sample_weight=sample_weight
+    ... )
+    1.02499
+
+    Using `'none'` reduction:
+
+    >>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
+    >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
+    >>> pairwise_hinge_loss = keras_rs.losses.PairwiseHingeLoss(
+    ...     reduction="none"
+    ... )
+    >>> pairwise_hinge_loss(y_true=y_true, y_pred=y_pred)
+    [[3. , 0. , 2. , 0.], [0., 0.20000005, 0.79999995, 0.]]
+"""
+
+PairwiseHingeLoss.__doc__ = pairwise_loss_subclass_doc_string.format(
+    loss_name="hinge loss",
+    formula=formula,
+    explanation=explanation,
+    extra_args=extra_args,
+    example=example,
+)
```
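A hand-check (not part of the diff) of the unbatched example above: each pair with `y_i > y_j` contributes `max(0, 1 - (s_i - s_j))`, the per-item sums add up to 11.6, and the default `"sum_over_batch_size"` reduction divides by the list size of 5, giving 2.32:

```python
import numpy as np

y = np.array([1.0, 0.0, 1.0, 3.0, 2.0])
s = np.array([1.0, 3.0, 2.0, 4.0, 0.8])

diff = s[:, None] - s[None, :]       # s_i - s_j for every pair (i, j)
indicator = y[:, None] > y[None, :]  # I(y_i > y_j)
per_pair = np.maximum(0.0, 1.0 - diff) * indicator
print(per_pair.sum() / len(y))       # 11.6 / 5 = 2.32
```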
keras_rs/src/losses/pairwise_logistic_loss.py (new file, +99):

```diff
@@ -0,0 +1,99 @@
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+from keras_rs.src.losses.pairwise_loss import PairwiseLoss
+from keras_rs.src.losses.pairwise_loss import pairwise_loss_subclass_doc_string
+
+
+@keras_rs_export("keras_rs.losses.PairwiseLogisticLoss")
+class PairwiseLogisticLoss(PairwiseLoss):
+    def pairwise_loss(self, pairwise_logits: types.Tensor) -> types.Tensor:
+        return ops.add(
+            ops.relu(ops.negative(pairwise_logits)),
+            ops.log(
+                ops.add(
+                    ops.array(1),
+                    ops.exp(ops.negative(ops.abs(pairwise_logits))),
+                )
+            ),
+        )
+
+
+formula = "loss = sum_{i} sum_{j} I(y_i > y_j) * log(1 + exp(-(s_i - s_j)))"
+explanation = """
+    - `log(1 + exp(-(s_i - s_j)))` is the logistic loss, which penalizes
+      cases where the score difference `s_i - s_j` is not sufficiently large
+      when `y_i > y_j`. This function provides a smooth approximation of the
+      ideal step function, making it suitable for gradient-based optimization.
+"""
+extra_args = ""
+example = """
+    With `compile()` API:
+
+    ```python
+    model.compile(
+        loss=keras_rs.losses.PairwiseLogisticLoss(),
+        ...
+    )
+    ```
+
+    As a standalone function with unbatched inputs:
+
+    >>> y_true = np.array([1.0, 0.0, 1.0, 3.0, 2.0])
+    >>> y_pred = np.array([1.0, 3.0, 2.0, 4.0, 0.8])
+    >>> pairwise_logistic_loss = keras_rs.losses.PairwiseLogisticLoss()
+    >>> pairwise_logistic_loss(y_true=y_true, y_pred=y_pred)
+    1.70708
+
+    With batched inputs using default 'auto'/'sum_over_batch_size' reduction:
+
+    >>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
+    >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
+    >>> pairwise_logistic_loss = keras_rs.losses.PairwiseLogisticLoss()
+    >>> pairwise_logistic_loss(y_true=y_true, y_pred=y_pred)
+    0.73936
+
+    With masked inputs (useful for ragged inputs):
+
+    >>> y_true = {
+    ...     "labels": np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]]),
+    ...     "mask": np.array(
+    ...         [[True, True, True, True], [True, True, False, False]]
+    ...     ),
+    ... }
+    >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
+    >>> pairwise_logistic_loss(y_true=y_true, y_pred=y_pred)
+    0.53751
+
+    With `sample_weight`:
+
+    >>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
+    >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
+    >>> sample_weight = np.array(
+    ...     [[2.0, 3.0, 1.0, 1.0], [2.0, 1.0, 0.0, 0.0]]
+    ... )
+    >>> pairwise_logistic_loss = keras_rs.losses.PairwiseLogisticLoss()
+    >>> pairwise_logistic_loss(
+    ...     y_true=y_true, y_pred=y_pred, sample_weight=sample_weight
+    ... )
+    0.80337
+
+    Using `'none'` reduction:
+
+    >>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
+    >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
+    >>> pairwise_logistic_loss = keras_rs.losses.PairwiseLogisticLoss(
+    ...     reduction="none"
+    ... )
+    >>> pairwise_logistic_loss(y_true=y_true, y_pred=y_pred)
+    [[2.126928, 0., 1.3132616, 0.48877698], [0., 0.371101, 0.911401, 0.703472]]
+"""
+
+PairwiseLogisticLoss.__doc__ = pairwise_loss_subclass_doc_string.format(
+    loss_name="logistic loss",
+    formula=formula,
+    explanation=explanation,
+    extra_args=extra_args,
+    example=example,
+)
```
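A note on why `pairwise_loss` above is written with `relu` and `abs` rather than a literal `log(1 + exp(-x))` (sketch, not part of the diff): the naive form overflows for large negative `x`, while the algebraically identical `relu(-x) + log(1 + exp(-|x|))` stays finite everywhere:

```python
import numpy as np

x = np.array([-800.0, -1.0, 0.0, 1.0, 800.0])
# Naive np.log1p(np.exp(-x)) returns inf at x = -800 because exp(800)
# overflows float64. The rearranged form is exact and stable:
stable = np.maximum(-x, 0.0) + np.log1p(np.exp(-np.abs(x)))
print(stable)  # [800.0, 1.3133, 0.6931, 0.3133, 0.0]
```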
keras_rs/src/losses/pairwise_loss.py (new file, +165):

```diff
@@ -0,0 +1,165 @@
+import abc
+from typing import Any
+
+import keras
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.losses.pairwise_loss_utils import pairwise_comparison
+from keras_rs.src.metrics.utils import standardize_call_inputs_ranks
+
+
+class PairwiseLoss(keras.losses.Loss, abc.ABC):
+    """Base class for pairwise ranking losses.
+
+    Pairwise loss functions are designed for ranking tasks, where the goal is to
+    correctly order items within each list. Any pairwise loss function computes
+    the loss value by comparing pairs of items within each list, penalizing
+    cases where an item with a higher true label has a lower predicted score
+    than an item with a lower true label.
+
+    In order to implement any kind of pairwise loss, override the
+    `pairwise_loss` method.
+    """
+
+    def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+
+        if temperature <= 0.0:
+            raise ValueError(
+                f"`temperature` should be a positive float. Received: "
+                f"`temperature` = {temperature}."
+            )
+
+        self.temperature = temperature
+
+        # TODO(abheesht): Add `lambda_weights`.
+
+    @abc.abstractmethod
+    def pairwise_loss(self, pairwise_logits: types.Tensor) -> types.Tensor:
+        pass
+
+    def compute_unreduced_loss(
+        self,
+        labels: types.Tensor,
+        logits: types.Tensor,
+        mask: types.Tensor | None = None,
+    ) -> tuple[types.Tensor, types.Tensor]:
+        # Mask all values less than 0 (since less than 0 implies invalid
+        # labels).
+        valid_mask = ops.greater_equal(labels, ops.cast(0.0, labels.dtype))
+
+        if mask is not None:
+            valid_mask = ops.logical_and(valid_mask, mask)
+
+        pairwise_labels, pairwise_logits = pairwise_comparison(
+            labels=labels,
+            logits=logits,
+            mask=valid_mask,
+            logits_op=ops.subtract,
+        )
+        pairwise_logits = ops.divide(
+            pairwise_logits,
+            ops.cast(self.temperature, dtype=pairwise_logits.dtype),
+        )
+
+        return self.pairwise_loss(pairwise_logits), pairwise_labels
+
+    def call(
+        self,
+        y_true: types.Tensor,
+        y_pred: types.Tensor,
+    ) -> types.Tensor:
+        """Compute the pairwise loss.
+
+        Args:
+            y_true: tensor or dict. Ground truth values. If tensor, of shape
+                `(list_size)` for unbatched inputs or `(batch_size, list_size)`
+                for batched inputs. If an item has a label of -1, it is ignored
+                in loss computation. If it is a dictionary, it should have two
+                keys: `"labels"` and `"mask"`. `"mask"` can be used to ignore
+                elements in loss computation, i.e., pairs will not be formed
+                with those items. Note that the final mask is an `and` of the
+                passed mask, and `labels >= 0`.
+            y_pred: tensor. The predicted values, of shape `(list_size)` for
+                unbatched inputs or `(batch_size, list_size)` for batched
+                inputs. Should be of the same shape as `y_true`.
+
+        Returns:
+            The loss.
+        """
+        mask = None
+        if isinstance(y_true, dict):
+            if "labels" not in y_true:
+                raise ValueError(
+                    '`"labels"` should be present in `y_true`. Received: '
+                    f"`y_true` = {y_true}"
+                )
+
+            mask = y_true.get("mask", None)
+            y_true = y_true["labels"]
+
+        y_true = ops.convert_to_tensor(y_true)
+        y_pred = ops.convert_to_tensor(y_pred)
+        if mask is not None:
+            mask = ops.convert_to_tensor(mask)
+
+        y_true, y_pred, mask, _ = standardize_call_inputs_ranks(
+            y_true, y_pred, mask
+        )
+
+        losses, weights = self.compute_unreduced_loss(
+            labels=y_true, logits=y_pred, mask=mask
+        )
+        losses = ops.multiply(losses, weights)
+        losses = ops.sum(losses, axis=-1)
+        return losses
+
+    def get_config(self) -> dict[str, Any]:
+        config: dict[str, Any] = super().get_config()
+        config.update({"temperature": self.temperature})
+        return config
+
+
+pairwise_loss_subclass_doc_string = (
+    "Computes pairwise {loss_name} between true labels and predicted scores."
+    """
+    This loss function is designed for ranking tasks, where the goal is to
+    correctly order items within each list. It computes the loss by comparing
+    pairs of items within each list, penalizing cases where an item with a
+    higher true label has a lower predicted score than an item with a lower
+    true label.
+
+    For each list of predicted scores `s` in `y_pred` and the corresponding list
+    of true labels `y` in `y_true`, the loss is computed as follows:
+
+    ```
+    {formula}
+    ```
+
+    where:
+
+    - `y_i` and `y_j` are the true labels of items `i` and `j`, respectively.
+    - `s_i` and `s_j` are the predicted scores of items `i` and `j`,
+      respectively.
+    - `I(y_i > y_j)` is an indicator function that equals 1 if `y_i > y_j`,
+      and 0 otherwise.{explanation}
+    Args:{extra_args}
+        reduction: Type of reduction to apply to the loss. In almost all cases
+            this should be `"sum_over_batch_size"`. Supported options are
+            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
+            `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
+            `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
+            sample size, and `"mean_with_sample_weight"` sums the loss and
+            divides by the sum of the sample weights. `"none"` and `None`
+            perform no aggregation. Defaults to `"sum_over_batch_size"`.
+        name: Optional name for the loss instance.
+        dtype: The dtype of the loss's computations. Defaults to `None`, which
+            means using `keras.backend.floatx()`. `keras.backend.floatx()` is
+            `"float32"` unless set to a different value
+            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
+            provided, then the `compute_dtype` will be utilized.
+
+    Examples:
+    {example}"""
+)
```
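The base class asks subclasses to override only `pairwise_loss`. A minimal sketch of that contract (not part of the diff; `PairwiseSquaredHingeLoss` is a hypothetical name, not a keras-rs API):

```python
from keras import ops

from keras_rs.src import types
from keras_rs.src.losses.pairwise_loss import PairwiseLoss


class PairwiseSquaredHingeLoss(PairwiseLoss):
    """Hypothetical subclass: squares the hinge to penalize violations more."""

    def pairwise_loss(self, pairwise_logits: types.Tensor) -> types.Tensor:
        # pairwise_logits[:, i, j] holds (s_i - s_j) / temperature; masking
        # and reduction are handled by the PairwiseLoss base class.
        return ops.square(ops.relu(ops.subtract(ops.array(1), pairwise_logits)))
```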
keras_rs/src/losses/pairwise_loss_utils.py (new file, +39):

```diff
@@ -0,0 +1,39 @@
+from typing import Callable
+
+from keras import ops
+
+from keras_rs.src import types
+
+
+def apply_pairwise_op(
+    x: types.Tensor, op: Callable[[types.Tensor, types.Tensor], types.Tensor]
+) -> types.Tensor:
+    return op(
+        ops.expand_dims(x, axis=-1),
+        ops.expand_dims(x, axis=-2),
+    )
+
+
+def pairwise_comparison(
+    labels: types.Tensor,
+    logits: types.Tensor,
+    mask: types.Tensor,
+    logits_op: Callable[[types.Tensor, types.Tensor], types.Tensor],
+) -> tuple[types.Tensor, types.Tensor]:
+    # Compute the difference for all pairs in a list. The output is a tensor
+    # with shape `(batch_size, list_size, list_size)`, where `[:, i, j]` stores
+    # information for pair `(i, j)`.
+    pairwise_labels_diff = apply_pairwise_op(labels, ops.subtract)
+    pairwise_logits = apply_pairwise_op(logits, logits_op)
+
+    # Keep only those cases where `l_i > l_j`.
+    pairwise_labels = ops.cast(
+        ops.greater(pairwise_labels_diff, 0), dtype=labels.dtype
+    )
+    if mask is not None:
+        valid_pairs = apply_pairwise_op(mask, ops.logical_and)
+        pairwise_labels = ops.multiply(
+            pairwise_labels, ops.cast(valid_pairs, dtype=pairwise_labels.dtype)
+        )
+
+    return pairwise_labels, pairwise_logits
```
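A quick shape check (not part of the diff) of what these helpers return: `apply_pairwise_op` broadcasts a `(batch_size, list_size)` tensor against itself, so `pairwise_comparison` yields `(batch_size, list_size, list_size)` tensors whose entry `[:, i, j]` describes the pair `(i, j)`:

```python
import numpy as np
from keras import ops

from keras_rs.src.losses.pairwise_loss_utils import pairwise_comparison

labels = np.array([[1.0, 0.0, 2.0]])
logits = np.array([[0.2, 0.7, 0.1]])
mask = np.array([[True, True, True]])

pairwise_labels, pairwise_logits = pairwise_comparison(
    labels=labels, logits=logits, mask=mask, logits_op=ops.subtract
)
print(ops.shape(pairwise_labels))  # (1, 3, 3); [0, i, j] = I(y_i > y_j)
print(ops.shape(pairwise_logits))  # (1, 3, 3); [0, i, j] = s_i - s_j
```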