keras-rs-nightly 0.0.1.dev2025021903__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. keras_rs/__init__.py +9 -28
  2. keras_rs/layers/__init__.py +37 -0
  3. keras_rs/losses/__init__.py +19 -0
  4. keras_rs/metrics/__init__.py +16 -0
  5. keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
  6. keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
  7. keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
  8. keras_rs/src/layers/embedding/embed_reduce.py +309 -0
  9. keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  10. keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
  11. keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
  12. keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
  13. keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
  14. keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
  15. keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  16. keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
  17. keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
  18. keras_rs/src/layers/feature_interaction/__init__.py +0 -0
  19. keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
  20. keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
  21. keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
  22. keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
  23. keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
  24. keras_rs/src/layers/retrieval/retrieval.py +127 -0
  25. keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
  26. keras_rs/src/losses/__init__.py +0 -0
  27. keras_rs/src/losses/list_mle_loss.py +212 -0
  28. keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
  29. keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
  30. keras_rs/src/losses/pairwise_loss.py +165 -0
  31. keras_rs/src/losses/pairwise_loss_utils.py +39 -0
  32. keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
  33. keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
  34. keras_rs/src/metrics/__init__.py +0 -0
  35. keras_rs/src/metrics/dcg.py +161 -0
  36. keras_rs/src/metrics/mean_average_precision.py +130 -0
  37. keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
  38. keras_rs/src/metrics/ndcg.py +197 -0
  39. keras_rs/src/metrics/precision_at_k.py +117 -0
  40. keras_rs/src/metrics/ranking_metric.py +260 -0
  41. keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
  42. keras_rs/src/metrics/recall_at_k.py +108 -0
  43. keras_rs/src/metrics/utils.py +70 -0
  44. keras_rs/src/types.py +43 -14
  45. keras_rs/src/utils/doc_string_utils.py +53 -0
  46. keras_rs/src/utils/keras_utils.py +52 -3
  47. keras_rs/src/utils/tpu_test_utils.py +120 -0
  48. keras_rs/src/version.py +1 -1
  49. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
  50. keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
  51. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
  52. keras_rs/api/__init__.py +0 -9
  53. keras_rs/api/layers/__init__.py +0 -11
  54. keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
  55. /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
  56. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,212 @@
1
+ from typing import Any
2
+
3
+ import keras
4
+ from keras import ops
5
+
6
+ from keras_rs.src import types
7
+ from keras_rs.src.api_export import keras_rs_export
8
+ from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
9
+ from keras_rs.src.metrics.utils import standardize_call_inputs_ranks
10
+
11
+
12
@keras_rs_export("keras_rs.losses.ListMLELoss")
class ListMLELoss(keras.losses.Loss):
    """Listwise ranking loss based on maximum likelihood estimation (ListMLE).

    For every list, the items are ordered by their labels and the loss is the
    negative log-likelihood of that ordering under the predicted scores:

    ```
    loss = -sum(log(exp(s_i) / sum(exp(s_j) for j >= i)))
    ```

    Here `s_i` is the predicted score of the item at position `i` of the
    label-sorted list. Minimizing the loss maximizes the likelihood of the
    ground truth ranking.

    Args:
        temperature: Positive float by which the logits are divided before
            the likelihood is computed. Higher values make the probability
            distribution more uniform. Defaults to 1.0.
        reduction: Type of reduction to apply to the loss. In almost all cases
            this should be `"sum_over_batch_size"`. Supported options are
            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. Defaults to
            `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`.

    Examples:
    ```python
    # Basic usage
    loss_fn = ListMLELoss()

    # With temperature scaling
    loss_fn = ListMLELoss(temperature=0.5)

    # Example with synthetic data
    y_true = [[3, 2, 1, 0]]  # Relevance scores
    y_pred = [[0.8, 0.6, 0.4, 0.2]]  # Predicted scores
    loss = loss_fn(y_true, y_pred)
    ```
    """

    def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None:
        """Validate and store `temperature`; `kwargs` go to the base loss.

        Raises:
            ValueError: if `temperature` is not strictly positive.
        """
        super().__init__(**kwargs)

        if temperature <= 0.0:
            raise ValueError(
                f"`temperature` should be a positive float. Received: "
                f"`temperature` = {temperature}."
            )

        self.temperature = temperature
        # Added inside the log of the normalizer so it never sees exactly 0.
        self._epsilon = 1e-10

    def compute_unreduced_loss(
        self,
        labels: types.Tensor,
        logits: types.Tensor,
        mask: types.Tensor | None = None,
    ) -> tuple[types.Tensor, types.Tensor]:
        """Compute the per-list ListMLE loss before any reduction.

        Args:
            labels: Ground truth relevance scores of shape
                [batch_size, list_size]. Negative labels are treated as
                invalid items.
            logits: Predicted scores of shape [batch_size, list_size].
            mask: Optional mask of shape [batch_size, list_size]. It is cast
                to bool and combined with `labels >= 0`.

        Returns:
            Tuple of (losses, weights) where losses has shape [batch_size, 1]
            and weights is a tensor of ones with the same shape.
        """
        # An item takes part in the loss only if its label is non-negative
        # and, when a mask is supplied, the mask is truthy for it.
        is_valid = ops.greater_equal(labels, ops.cast(0.0, labels.dtype))
        if mask is not None:
            is_valid = ops.logical_and(is_valid, ops.cast(mask, dtype="bool"))

        valid_count = ops.sum(
            ops.cast(is_valid, dtype=labels.dtype), axis=1, keepdims=True
        )
        list_is_nonempty = ops.greater(valid_count, 0.0)

        # Invalid entries get a very negative sort key and logit.
        # NOTE(review): presumably `sort_by_scores` orders by descending
        # score, which would push these entries to the end of each list --
        # confirm in `ranking_metrics_utils`.
        sort_keys = ops.where(is_valid, labels, ops.full_like(labels, -1e9))
        masked_logits = ops.where(
            is_valid, logits, ops.full_like(logits, -1e9)
        )

        ranked_logits, ranked_valid = sort_by_scores(
            tensors_to_sort=[masked_logits, is_valid],
            scores=sort_keys,
            mask=None,
            shuffle_ties=False,
            seed=None,
        )
        ranked_logits = ops.divide(
            ranked_logits,
            ops.cast(self.temperature, dtype=ranked_logits.dtype),
        )

        # Subtract the per-list maximum over valid entries before `exp`.
        # Lists without any valid entry use 0 as the shift instead.
        row_max = ops.max(
            ops.where(
                ranked_valid, ranked_logits, ops.full_like(ranked_logits, -1e9)
            ),
            axis=1,
            keepdims=True,
        )
        row_max = ops.where(list_is_nonempty, row_max, ops.zeros_like(row_max))
        ranked_logits = ops.subtract(ranked_logits, row_max)

        # Set invalid positions to very negative BEFORE exp so that they
        # contribute (numerically) nothing to the normalizers.
        ranked_logits = ops.where(
            ranked_valid, ranked_logits, ops.full_like(ranked_logits, -1e9)
        )
        exp_terms = ops.exp(ranked_logits)

        # Suffix sums: position i holds sum(exp_terms[i:]) along the list
        # axis, obtained as flip -> cumsum -> flip.
        suffix_sums = ops.flip(
            ops.cumsum(ops.flip(exp_terms, axis=1), axis=1), axis=1
        )

        log_probs = ops.subtract(
            ranked_logits, ops.log(suffix_sums + self._epsilon)
        )
        # Invalid positions must not contribute to the summed likelihood.
        log_probs = ops.where(
            ranked_valid, log_probs, ops.zeros_like(log_probs)
        )

        nll = ops.negative(ops.sum(log_probs, axis=1, keepdims=True))
        # Lists with no valid item yield a loss of exactly zero.
        nll = ops.where(list_is_nonempty, nll, ops.zeros_like(nll))

        return nll, ops.ones_like(nll)

    def call(
        self,
        y_true: types.Tensor,
        y_pred: types.Tensor,
    ) -> types.Tensor:
        """Compute the ListMLE loss.

        Args:
            y_true: tensor or dict. Ground truth values. If tensor, of shape
                `(list_size)` for unbatched inputs or `(batch_size, list_size)`
                for batched inputs. If an item has a label of -1, it is ignored
                in loss computation. If it is a dictionary, it should have two
                keys: `"labels"` and `"mask"`. `"mask"` can be used to ignore
                elements in loss computation.
            y_pred: tensor. The predicted values, of shape `(list_size)` for
                unbatched inputs or `(batch_size, list_size)` for batched
                inputs. Should be of the same shape as `y_true`.

        Returns:
            The loss tensor of shape [batch_size].

        Raises:
            ValueError: if `y_true` is a dict without a `"labels"` key.
        """
        if isinstance(y_true, dict):
            if "labels" not in y_true:
                raise ValueError(
                    '`"labels"` should be present in `y_true`. Received: '
                    f"`y_true` = {y_true}"
                )
            mask = y_true.get("mask", None)
            labels = y_true["labels"]
        else:
            mask = None
            labels = y_true

        labels = ops.convert_to_tensor(labels)
        scores = ops.convert_to_tensor(y_pred)
        if mask is not None:
            mask = ops.convert_to_tensor(mask)

        labels, scores, mask, _ = standardize_call_inputs_ranks(
            labels, scores, mask
        )

        per_list_loss, per_list_weight = self.compute_unreduced_loss(
            labels=labels, logits=scores, mask=mask
        )
        # Drop the trailing singleton axis kept by `compute_unreduced_loss`.
        return ops.squeeze(
            ops.multiply(per_list_loss, per_list_weight), axis=-1
        )

    def get_config(self) -> dict[str, Any]:
        """Return the base loss config extended with `temperature`."""
        config: dict[str, Any] = super().get_config()
        config["temperature"] = self.temperature
        return config
@@ -0,0 +1,90 @@
1
+ from keras import ops
2
+
3
+ from keras_rs.src import types
4
+ from keras_rs.src.api_export import keras_rs_export
5
+ from keras_rs.src.losses.pairwise_loss import PairwiseLoss
6
+ from keras_rs.src.losses.pairwise_loss import pairwise_loss_subclass_doc_string
7
+
8
+
9
@keras_rs_export("keras_rs.losses.PairwiseHingeLoss")
class PairwiseHingeLoss(PairwiseLoss):
    # The class docstring is attached further down in this module through
    # `PairwiseHingeLoss.__doc__`.
    def pairwise_loss(self, pairwise_logits: types.Tensor) -> types.Tensor:
        """Return the element-wise hinge `max(0, 1 - pairwise_logits)`."""
        margin_shortfall = ops.subtract(ops.array(1), pairwise_logits)
        return ops.relu(margin_shortfall)
13
+
14
+
15
# The strings below are the loss-specific fragments that get substituted into
# `pairwise_loss_subclass_doc_string` at the bottom of this module to build
# the docstring of `PairwiseHingeLoss`.

# One-line formula shown in the docstring's code fence.
formula = "loss = sum_{i} sum_{j} I(y_i > y_j) * max(0, 1 - (s_i - s_j))"
# Extra bullet appended to the "where:" list of the shared template.
explanation = """
      - `max(0, 1 - (s_i - s_j))` is the hinge loss, which penalizes cases where
        the score difference `s_i - s_j` is not sufficiently large when
        `y_i > y_j`.
"""
# This loss documents no constructor arguments beyond the shared ones.
extra_args = ""
# Usage examples; expected outputs follow each `>>>` call on their own line.
example = """
    With `compile()` API:

    ```python
    model.compile(
        loss=keras_rs.losses.PairwiseHingeLoss(),
        ...
    )
    ```

    As a standalone function with unbatched inputs:

    >>> y_true = np.array([1.0, 0.0, 1.0, 3.0, 2.0])
    >>> y_pred = np.array([1.0, 3.0, 2.0, 4.0, 0.8])
    >>> pairwise_hinge_loss = keras_rs.losses.PairwiseHingeLoss()
    >>> pairwise_hinge_loss(y_true=y_true, y_pred=y_pred)
    2.32000

    With batched inputs using default 'auto'/'sum_over_batch_size' reduction:

    >>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
    >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
    >>> pairwise_hinge_loss = keras_rs.losses.PairwiseHingeLoss()
    >>> pairwise_hinge_loss(y_true=y_true, y_pred=y_pred)
    0.75

    With masked inputs (useful for ragged inputs):

    >>> y_true = {
    ... "labels": np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]]),
    ... "mask": np.array(
    ... [[True, True, True, True], [True, True, False, False]]
    ... ),
    ... }
    >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
    >>> pairwise_hinge_loss(y_true=y_true, y_pred=y_pred)
    0.64999

    With `sample_weight`:

    >>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
    >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
    >>> sample_weight = np.array(
    ... [[2.0, 3.0, 1.0, 1.0], [2.0, 1.0, 0.0, 0.0]]
    ... )
    >>> pairwise_hinge_loss = keras_rs.losses.PairwiseHingeLoss()
    >>> pairwise_hinge_loss(
    ... y_true=y_true, y_pred=y_pred, sample_weight=sample_weight
    ... )
    1.02499

    Using `'none'` reduction:

    >>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
    >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
    >>> pairwise_hinge_loss = keras_rs.losses.PairwiseHingeLoss(
    ... reduction="none"
    ... )
    >>> pairwise_hinge_loss(y_true=y_true, y_pred=y_pred)
    [[3. , 0. , 2. , 0.], [0., 0.20000005, 0.79999995, 0.]]
"""

# Assemble the final class docstring from the shared template.
PairwiseHingeLoss.__doc__ = pairwise_loss_subclass_doc_string.format(
    loss_name="hinge loss",
    formula=formula,
    explanation=explanation,
    extra_args=extra_args,
    example=example,
)
@@ -0,0 +1,99 @@
1
+ from keras import ops
2
+
3
+ from keras_rs.src import types
4
+ from keras_rs.src.api_export import keras_rs_export
5
+ from keras_rs.src.losses.pairwise_loss import PairwiseLoss
6
+ from keras_rs.src.losses.pairwise_loss import pairwise_loss_subclass_doc_string
7
+
8
+
9
@keras_rs_export("keras_rs.losses.PairwiseLogisticLoss")
class PairwiseLogisticLoss(PairwiseLoss):
    # The class docstring is attached further down in this module through
    # `PairwiseLogisticLoss.__doc__`.
    def pairwise_loss(self, pairwise_logits: types.Tensor) -> types.Tensor:
        """Return the element-wise logistic loss `log(1 + exp(-x))`.

        The value is evaluated as `relu(-x) + log(1 + exp(-|x|))`, which is
        the same quantity but never exponentiates a positive number.
        """
        linear_part = ops.relu(ops.negative(pairwise_logits))
        decaying_exp = ops.exp(ops.negative(ops.abs(pairwise_logits)))
        smooth_part = ops.log(ops.add(ops.array(1), decaying_exp))
        return ops.add(linear_part, smooth_part)
21
+
22
+
23
# The strings below are the loss-specific fragments that get substituted into
# `pairwise_loss_subclass_doc_string` at the bottom of this module to build
# the docstring of `PairwiseLogisticLoss`.

# One-line formula shown in the docstring's code fence.
formula = "loss = sum_{i} sum_{j} I(y_i > y_j) * log(1 + exp(-(s_i - s_j)))"
# Extra bullet appended to the "where:" list of the shared template.
explanation = """
      - `log(1 + exp(-(s_i - s_j)))` is the logistic loss, which penalizes
        cases where the score difference `s_i - s_j` is not sufficiently large
        when `y_i > y_j`. This function provides a smooth approximation of the
        ideal step function, making it suitable for gradient-based optimization.
"""
# This loss documents no constructor arguments beyond the shared ones.
extra_args = ""
# Usage examples. Expected outputs are written WITHOUT a `>>>` prompt so that
# they read (and parse) as results rather than as further statements. The
# second row of the `'none'` example holds the logistic values for the second
# list; they sum with the first row to the batched mean shown above it
# ((3.928967 + 1.985974) / 8 = 0.73936).
example = """
    With `compile()` API:

    ```python
    model.compile(
        loss=keras_rs.losses.PairwiseLogisticLoss(),
        ...
    )
    ```

    As a standalone function with unbatched inputs:

    >>> y_true = np.array([1.0, 0.0, 1.0, 3.0, 2.0])
    >>> y_pred = np.array([1.0, 3.0, 2.0, 4.0, 0.8])
    >>> pairwise_logistic_loss = keras_rs.losses.PairwiseLogisticLoss()
    >>> pairwise_logistic_loss(y_true=y_true, y_pred=y_pred)
    1.70708

    With batched inputs using default 'auto'/'sum_over_batch_size' reduction:

    >>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
    >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
    >>> pairwise_logistic_loss = keras_rs.losses.PairwiseLogisticLoss()
    >>> pairwise_logistic_loss(y_true=y_true, y_pred=y_pred)
    0.73936

    With masked inputs (useful for ragged inputs):

    >>> y_true = {
    ... "labels": np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]]),
    ... "mask": np.array(
    ... [[True, True, True, True], [True, True, False, False]]
    ... ),
    ... }
    >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
    >>> pairwise_logistic_loss(y_true=y_true, y_pred=y_pred)
    0.53751

    With `sample_weight`:

    >>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
    >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
    >>> sample_weight = np.array(
    ... [[2.0, 3.0, 1.0, 1.0], [2.0, 1.0, 0.0, 0.0]]
    ... )
    >>> pairwise_logistic_loss = keras_rs.losses.PairwiseLogisticLoss()
    >>> pairwise_logistic_loss(
    ... y_true=y_true, y_pred=y_pred, sample_weight=sample_weight
    ... )
    0.80337

    Using `'none'` reduction:

    >>> y_true = np.array([[1.0, 0.0, 1.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
    >>> y_pred = np.array([[1.0, 3.0, 2.0, 4.0], [1.0, 1.8, 2.0, 3.0]])
    >>> pairwise_logistic_loss = keras_rs.losses.PairwiseLogisticLoss(
    ... reduction="none"
    ... )
    >>> pairwise_logistic_loss(y_true=y_true, y_pred=y_pred)
    [[2.126928, 0., 1.3132616, 0.48877698], [0., 0.371101, 0.911401, 0.703472]]
"""

# Assemble the final class docstring from the shared template.
PairwiseLogisticLoss.__doc__ = pairwise_loss_subclass_doc_string.format(
    loss_name="logistic loss",
    formula=formula,
    explanation=explanation,
    extra_args=extra_args,
    example=example,
)
@@ -0,0 +1,165 @@
1
+ import abc
2
+ from typing import Any
3
+
4
+ import keras
5
+ from keras import ops
6
+
7
+ from keras_rs.src import types
8
+ from keras_rs.src.losses.pairwise_loss_utils import pairwise_comparison
9
+ from keras_rs.src.metrics.utils import standardize_call_inputs_ranks
10
+
11
+
12
class PairwiseLoss(keras.losses.Loss, abc.ABC):
    """Abstract base for ranking losses defined over pairs of list items.

    A pairwise loss looks at every ordered pair of items `(i, j)` within a
    list and penalizes the pairs in which the item with the higher true label
    does not receive a sufficiently higher predicted score than the item with
    the lower true label.

    Subclasses only need to override `pairwise_loss`, which maps the
    (temperature-scaled) pairwise score differences to loss values.
    """

    def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None:
        """Validate and store `temperature`; `kwargs` go to the base loss.

        Raises:
            ValueError: if `temperature` is not strictly positive.
        """
        super().__init__(**kwargs)

        if temperature <= 0.0:
            raise ValueError(
                f"`temperature` should be a positive float. Received: "
                f"`temperature` = {temperature}."
            )

        self.temperature = temperature

        # TODO(abheesht): Add `lambda_weights`.

    @abc.abstractmethod
    def pairwise_loss(self, pairwise_logits: types.Tensor) -> types.Tensor:
        """Map pairwise logits to element-wise loss values (to override)."""

    def compute_unreduced_loss(
        self,
        labels: types.Tensor,
        logits: types.Tensor,
        mask: types.Tensor | None = None,
    ) -> tuple[types.Tensor, types.Tensor]:
        """Return `(pairwise losses, pairwise weights)` before reduction.

        Args:
            labels: Ground truth labels. Entries below 0 are treated as
                invalid.
            logits: Predicted scores, same shape as `labels`.
            mask: Optional mask combined (logical and) with `labels >= 0`.

        Returns:
            Tuple of the output of `pairwise_loss` applied to the scaled
            pairwise logit differences, and the pair indicator tensor
            produced by `pairwise_comparison`.
        """
        # Mask all values less than 0 (since less than 0 implies invalid
        # labels).
        is_valid = ops.greater_equal(labels, ops.cast(0.0, labels.dtype))
        if mask is not None:
            is_valid = ops.logical_and(is_valid, mask)

        pair_indicator, pair_logits = pairwise_comparison(
            labels=labels,
            logits=logits,
            mask=is_valid,
            logits_op=ops.subtract,
        )
        temperature = ops.cast(self.temperature, dtype=pair_logits.dtype)
        scaled_logits = ops.divide(pair_logits, temperature)

        return self.pairwise_loss(scaled_logits), pair_indicator

    def call(
        self,
        y_true: types.Tensor,
        y_pred: types.Tensor,
    ) -> types.Tensor:
        """Compute the pairwise loss.

        Args:
            y_true: tensor or dict. Ground truth values. If tensor, of shape
                `(list_size)` for unbatched inputs or `(batch_size, list_size)`
                for batched inputs. If an item has a label of -1, it is ignored
                in loss computation. If it is a dictionary, it should have two
                keys: `"labels"` and `"mask"`. `"mask"` can be used to ignore
                elements in loss computation, i.e., pairs will not be formed
                with those items. Note that the final mask is an `and` of the
                passed mask, and `labels >= 0`.
            y_pred: tensor. The predicted values, of shape `(list_size)` for
                unbatched inputs or `(batch_size, list_size)` for batched
                inputs. Should be of the same shape as `y_true`.

        Returns:
            The loss, with the weighted pairwise losses summed over the last
            axis.

        Raises:
            ValueError: if `y_true` is a dict without a `"labels"` key.
        """
        if isinstance(y_true, dict):
            if "labels" not in y_true:
                raise ValueError(
                    '`"labels"` should be present in `y_true`. Received: '
                    f"`y_true` = {y_true}"
                )
            mask = y_true.get("mask", None)
            labels = y_true["labels"]
        else:
            mask = None
            labels = y_true

        labels = ops.convert_to_tensor(labels)
        scores = ops.convert_to_tensor(y_pred)
        if mask is not None:
            mask = ops.convert_to_tensor(mask)

        labels, scores, mask, _ = standardize_call_inputs_ranks(
            labels, scores, mask
        )

        pair_losses, pair_weights = self.compute_unreduced_loss(
            labels=labels, logits=scores, mask=mask
        )
        # Zero out pairs that are not selected, then collapse the last axis.
        return ops.sum(ops.multiply(pair_losses, pair_weights), axis=-1)

    def get_config(self) -> dict[str, Any]:
        """Return the base loss config extended with `temperature`."""
        config: dict[str, Any] = super().get_config()
        config["temperature"] = self.temperature
        return config
122
+
123
+
124
# Docstring template shared by the concrete pairwise losses. It is a
# `str.format` template with the placeholders `{loss_name}`, `{formula}`,
# `{explanation}`, `{extra_args}` and `{example}`; each subclass module fills
# them in and assigns the result to the subclass' `__doc__`. Because the text
# goes through `str.format`, any literal brace added here must be doubled.
pairwise_loss_subclass_doc_string = (
    # Summary line and body are adjacent literals, joined at compile time.
    "Computes pairwise {loss_name} between true labels and predicted scores."
    """
    This loss function is designed for ranking tasks, where the goal is to
    correctly order items within each list. It computes the loss by comparing
    pairs of items within each list, penalizing cases where an item with a
    higher true label has a lower predicted score than an item with a lower
    true label.

    For each list of predicted scores `s` in `y_pred` and the corresponding list
    of true labels `y` in `y_true`, the loss is computed as follows:

    ```
    {formula}
    ```

    where:

      - `y_i` and `y_j` are the true labels of items `i` and `j`, respectively.
      - `s_i` and `s_j` are the predicted scores of items `i` and `j`,
        respectively.
      - `I(y_i > y_j)` is an indicator function that equals 1 if `y_i > y_j`,
        and 0 otherwise.{explanation}
    Args:{extra_args}
        reduction: Type of reduction to apply to the loss. In almost all cases
            this should be `"sum_over_batch_size"`. Supported options are
            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
            `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
            sample size, and `"mean_with_sample_weight"` sums the loss and
            divides by the sum of the sample weights. `"none"` and `None`
            perform no aggregation. Defaults to `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`, which
            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
            `"float32"` unless set to different value
            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
            provided, then the `compute_dtype` will be utilized.

    Examples:
    {example}"""
)
@@ -0,0 +1,39 @@
1
+ from typing import Callable
2
+
3
+ from keras import ops
4
+
5
+ from keras_rs.src import types
6
+
7
+
8
def apply_pairwise_op(
    x: types.Tensor, op: Callable[[types.Tensor, types.Tensor], types.Tensor]
) -> types.Tensor:
    """Apply the binary `op` to every ordered pair of entries of `x`.

    `x` is expanded once with a new trailing axis and once with a new
    second-to-last axis, so that entry `[..., i, j]` of the result is
    `op(x[..., i], x[..., j])` after broadcasting.

    Args:
        x: Input tensor; pairs are formed along its last axis.
        op: Broadcasting binary operation, called as `op(left, right)`.

    Returns:
        Whatever `op` returns for the two expanded views of `x`.
    """
    as_column = ops.expand_dims(x, axis=-1)
    as_row = ops.expand_dims(x, axis=-2)
    return op(as_column, as_row)
15
+
16
+
17
def pairwise_comparison(
    labels: types.Tensor,
    logits: types.Tensor,
    mask: types.Tensor | None,
    logits_op: Callable[[types.Tensor, types.Tensor], types.Tensor],
) -> tuple[types.Tensor, types.Tensor]:
    """Build the pair indicator and the pairwise logits for every item pair.

    Args:
        labels: Ground truth labels; pairs are formed along the last axis.
        logits: Predicted scores, same shape as `labels`.
        mask: Boolean validity mask of the same shape as `labels`, or `None`
            to consider every item valid. A pair is kept only if both of its
            items are valid.
        logits_op: Binary operation used to combine the logits of a pair,
            called as `logits_op(logits_i, logits_j)`.

    Returns:
        Tuple `(pairwise_labels, pairwise_logits)`. `pairwise_labels` has the
        dtype of `labels` and is 1 at `[..., i, j]` when `labels_i > labels_j`
        and both items are valid, 0 otherwise. `pairwise_logits` is the result
        of `logits_op` for the same pair.
    """
    # Compute the difference for all pairs in a list. For inputs of shape
    # `(batch_size, list_size)` the output is a tensor with shape
    # `(batch_size, list_size, list_size)`, where `[:, i, j]` stores
    # information for pair `(i, j)`.
    pairwise_labels_diff = apply_pairwise_op(labels, ops.subtract)
    pairwise_logits = apply_pairwise_op(logits, logits_op)

    # Keep only those pairs where `l_i > l_j`, i.e. where the first item of
    # the pair is the more relevant one (`l_i - l_j > 0`).
    pairwise_labels = ops.cast(
        ops.greater(pairwise_labels_diff, 0), dtype=labels.dtype
    )
    if mask is not None:
        # Drop pairs in which at least one of the two items is masked out.
        valid_pairs = apply_pairwise_op(mask, ops.logical_and)
        pairwise_labels = ops.multiply(
            pairwise_labels, ops.cast(valid_pairs, dtype=pairwise_labels.dtype)
        )

    return pairwise_labels, pairwise_logits