keras-rs-nightly 0.0.1.dev2025042103__tar.gz → 0.0.1.dev2025042503__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of keras-rs-nightly might be problematic. Click here for more details.

Files changed (51)
  1. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/PKG-INFO +4 -3
  2. keras_rs_nightly-0.0.1.dev2025042503/keras_rs/api/__init__.py +11 -0
  3. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs/api/layers/__init__.py +9 -7
  4. keras_rs_nightly-0.0.1.dev2025042503/keras_rs/api/losses/__init__.py +18 -0
  5. keras_rs_nightly-0.0.1.dev2025042503/keras_rs/api/metrics/__init__.py +16 -0
  6. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs/src/losses/pairwise_hinge_loss.py +1 -0
  7. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs/src/losses/pairwise_logistic_loss.py +1 -0
  8. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs/src/losses/pairwise_loss.py +36 -12
  9. keras_rs_nightly-0.0.1.dev2025042503/keras_rs/src/losses/pairwise_loss_utils.py +39 -0
  10. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs/src/losses/pairwise_mean_squared_error.py +2 -1
  11. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs/src/losses/pairwise_soft_zero_one_loss.py +1 -0
  12. keras_rs_nightly-0.0.1.dev2025042503/keras_rs/src/metrics/dcg.py +140 -0
  13. keras_rs_nightly-0.0.1.dev2025042503/keras_rs/src/metrics/mean_average_precision.py +112 -0
  14. keras_rs_nightly-0.0.1.dev2025042503/keras_rs/src/metrics/mean_reciprocal_rank.py +98 -0
  15. keras_rs_nightly-0.0.1.dev2025042503/keras_rs/src/metrics/ndcg.py +184 -0
  16. keras_rs_nightly-0.0.1.dev2025042503/keras_rs/src/metrics/precision_at_k.py +94 -0
  17. keras_rs_nightly-0.0.1.dev2025042503/keras_rs/src/metrics/ranking_metric.py +252 -0
  18. keras_rs_nightly-0.0.1.dev2025042503/keras_rs/src/metrics/ranking_metrics_utils.py +238 -0
  19. keras_rs_nightly-0.0.1.dev2025042503/keras_rs/src/metrics/recall_at_k.py +85 -0
  20. keras_rs_nightly-0.0.1.dev2025042503/keras_rs/src/metrics/utils.py +72 -0
  21. keras_rs_nightly-0.0.1.dev2025042503/keras_rs/src/utils/__init__.py +0 -0
  22. keras_rs_nightly-0.0.1.dev2025042503/keras_rs/src/utils/doc_string_utils.py +48 -0
  23. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs/src/utils/keras_utils.py +12 -0
  24. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs/src/version.py +1 -1
  25. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs_nightly.egg-info/PKG-INFO +4 -3
  26. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs_nightly.egg-info/SOURCES.txt +13 -2
  27. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/pyproject.toml +12 -5
  28. keras_rs_nightly-0.0.1.dev2025042103/keras_rs/__init__.py +0 -30
  29. keras_rs_nightly-0.0.1.dev2025042103/keras_rs/api/__init__.py +0 -10
  30. keras_rs_nightly-0.0.1.dev2025042103/keras_rs/api/losses/__init__.py +0 -14
  31. keras_rs_nightly-0.0.1.dev2025042103/keras_rs/src/utils/pairwise_loss_utils.py +0 -102
  32. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/README.md +0 -0
  33. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs/src/__init__.py +0 -0
  34. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs/src/api_export.py +0 -0
  35. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs/src/layers/__init__.py +0 -0
  36. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs/src/layers/feature_interaction/__init__.py +0 -0
  37. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs/src/layers/feature_interaction/dot_interaction.py +0 -0
  38. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs/src/layers/feature_interaction/feature_cross.py +0 -0
  39. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs/src/layers/retrieval/__init__.py +0 -0
  40. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs/src/layers/retrieval/brute_force_retrieval.py +0 -0
  41. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs/src/layers/retrieval/hard_negative_mining.py +0 -0
  42. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs/src/layers/retrieval/remove_accidental_hits.py +0 -0
  43. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs/src/layers/retrieval/retrieval.py +0 -0
  44. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs/src/layers/retrieval/sampling_probability_correction.py +0 -0
  45. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs/src/losses/__init__.py +0 -0
  46. {keras_rs_nightly-0.0.1.dev2025042103/keras_rs/src/utils → keras_rs_nightly-0.0.1.dev2025042503/keras_rs/src/metrics}/__init__.py +0 -0
  47. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs/src/types.py +0 -0
  48. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs_nightly.egg-info/dependency_links.txt +0 -0
  49. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs_nightly.egg-info/requires.txt +0 -0
  50. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/keras_rs_nightly.egg-info/top_level.txt +0 -0
  51. {keras_rs_nightly-0.0.1.dev2025042103 → keras_rs_nightly-0.0.1.dev2025042503}/setup.cfg +0 -0
@@ -1,10 +1,10 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: keras-rs-nightly
3
- Version: 0.0.1.dev2025042103
3
+ Version: 0.0.1.dev2025042503
4
4
  Summary: Multi-backend recommender systems with Keras 3.
5
- Author-email: Keras RS team <keras-rs@google.com>
5
+ Author-email: Keras team <keras-users@googlegroups.com>
6
6
  License: Apache License 2.0
7
- Project-URL: Home, https://keras.io/
7
+ Project-URL: Home, https://keras.io/keras_rs
8
8
  Project-URL: Repository, https://github.com/keras-team/keras-rs
9
9
  Classifier: Development Status :: 3 - Alpha
10
10
  Classifier: Programming Language :: Python :: 3
@@ -13,6 +13,7 @@ Classifier: Programming Language :: Python :: 3.10
13
13
  Classifier: Programming Language :: Python :: 3.11
14
14
  Classifier: Programming Language :: Python :: 3 :: Only
15
15
  Classifier: Operating System :: Unix
16
+ Classifier: Operating System :: Microsoft :: Windows
16
17
  Classifier: Operating System :: MacOS
17
18
  Classifier: Intended Audience :: Science/Research
18
19
  Classifier: Topic :: Scientific/Engineering
@@ -0,0 +1,11 @@
1
+ """DO NOT EDIT.
2
+
3
+ This file was autogenerated. Do not edit it by hand,
4
+ since your modifications would be overwritten.
5
+ """
6
+
7
+ from keras_rs import layers as layers
8
+ from keras_rs import losses as losses
9
+ from keras_rs import metrics as metrics
10
+ from keras_rs.src.version import __version__ as __version__
11
+ from keras_rs.src.version import version as version
@@ -5,19 +5,21 @@ since your modifications would be overwritten.
5
5
  """
6
6
 
7
7
  from keras_rs.src.layers.feature_interaction.dot_interaction import (
8
- DotInteraction,
8
+ DotInteraction as DotInteraction,
9
+ )
10
+ from keras_rs.src.layers.feature_interaction.feature_cross import (
11
+ FeatureCross as FeatureCross,
9
12
  )
10
- from keras_rs.src.layers.feature_interaction.feature_cross import FeatureCross
11
13
  from keras_rs.src.layers.retrieval.brute_force_retrieval import (
12
- BruteForceRetrieval,
14
+ BruteForceRetrieval as BruteForceRetrieval,
13
15
  )
14
16
  from keras_rs.src.layers.retrieval.hard_negative_mining import (
15
- HardNegativeMining,
17
+ HardNegativeMining as HardNegativeMining,
16
18
  )
17
19
  from keras_rs.src.layers.retrieval.remove_accidental_hits import (
18
- RemoveAccidentalHits,
20
+ RemoveAccidentalHits as RemoveAccidentalHits,
19
21
  )
20
- from keras_rs.src.layers.retrieval.retrieval import Retrieval
22
+ from keras_rs.src.layers.retrieval.retrieval import Retrieval as Retrieval
21
23
  from keras_rs.src.layers.retrieval.sampling_probability_correction import (
22
- SamplingProbabilityCorrection,
24
+ SamplingProbabilityCorrection as SamplingProbabilityCorrection,
23
25
  )
@@ -0,0 +1,18 @@
1
+ """DO NOT EDIT.
2
+
3
+ This file was autogenerated. Do not edit it by hand,
4
+ since your modifications would be overwritten.
5
+ """
6
+
7
+ from keras_rs.src.losses.pairwise_hinge_loss import (
8
+ PairwiseHingeLoss as PairwiseHingeLoss,
9
+ )
10
+ from keras_rs.src.losses.pairwise_logistic_loss import (
11
+ PairwiseLogisticLoss as PairwiseLogisticLoss,
12
+ )
13
+ from keras_rs.src.losses.pairwise_mean_squared_error import (
14
+ PairwiseMeanSquaredError as PairwiseMeanSquaredError,
15
+ )
16
+ from keras_rs.src.losses.pairwise_soft_zero_one_loss import (
17
+ PairwiseSoftZeroOneLoss as PairwiseSoftZeroOneLoss,
18
+ )
@@ -0,0 +1,16 @@
1
+ """DO NOT EDIT.
2
+
3
+ This file was autogenerated. Do not edit it by hand,
4
+ since your modifications would be overwritten.
5
+ """
6
+
7
+ from keras_rs.src.metrics.dcg import DCG as DCG
8
+ from keras_rs.src.metrics.mean_average_precision import (
9
+ MeanAveragePrecision as MeanAveragePrecision,
10
+ )
11
+ from keras_rs.src.metrics.mean_reciprocal_rank import (
12
+ MeanReciprocalRank as MeanReciprocalRank,
13
+ )
14
+ from keras_rs.src.metrics.ndcg import NDCG as NDCG
15
+ from keras_rs.src.metrics.precision_at_k import PrecisionAtK as PrecisionAtK
16
+ from keras_rs.src.metrics.recall_at_k import RecallAtK as RecallAtK
@@ -20,6 +20,7 @@ explanation = """
20
20
  """
21
21
  extra_args = ""
22
22
  PairwiseHingeLoss.__doc__ = pairwise_loss_subclass_doc_string.format(
23
+ loss_name="hinge loss",
23
24
  formula=formula,
24
25
  explanation=explanation,
25
26
  extra_args=extra_args,
@@ -29,6 +29,7 @@ explanation = """
29
29
  """
30
30
  extra_args = ""
31
31
  PairwiseLogisticLoss.__doc__ = pairwise_loss_subclass_doc_string.format(
32
+ loss_name="logistic loss",
32
33
  formula=formula,
33
34
  explanation=explanation,
34
35
  extra_args=extra_args,
@@ -1,12 +1,12 @@
1
1
  import abc
2
- from typing import Optional
2
+ from typing import Any, Optional
3
3
 
4
4
  import keras
5
5
  from keras import ops
6
6
 
7
7
  from keras_rs.src import types
8
- from keras_rs.src.utils.pairwise_loss_utils import pairwise_comparison
9
- from keras_rs.src.utils.pairwise_loss_utils import process_loss_call_inputs
8
+ from keras_rs.src.losses.pairwise_loss_utils import pairwise_comparison
9
+ from keras_rs.src.metrics.utils import standardize_call_inputs_ranks
10
10
 
11
11
 
12
12
  class PairwiseLoss(keras.losses.Loss, abc.ABC):
@@ -22,14 +22,22 @@ class PairwiseLoss(keras.losses.Loss, abc.ABC):
22
22
  `pairwise_loss` method.
23
23
  """
24
24
 
25
- # TODO: Add `temperature`, `lambda_weights`.
25
+ def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None:
26
+ super().__init__(**kwargs)
27
+
28
+ if temperature <= 0.0:
29
+ raise ValueError(
30
+ f"`temperature` should be a positive float. Received: "
31
+ f"`temperature` = {temperature}."
32
+ )
33
+
34
+ self.temperature = temperature
35
+
36
+ # TODO(abheesht): Add `lambda_weights`.
26
37
 
27
38
  @abc.abstractmethod
28
39
  def pairwise_loss(self, pairwise_logits: types.Tensor) -> types.Tensor:
29
- raise NotImplementedError(
30
- "All subclasses of `keras_rs.losses.pairwise_loss.PairwiseLoss`"
31
- "must implement the `pairwise_loss()` method."
32
- )
40
+ pass
33
41
 
34
42
  def compute_unreduced_loss(
35
43
  self,
@@ -50,6 +58,10 @@ class PairwiseLoss(keras.losses.Loss, abc.ABC):
50
58
  mask=valid_mask,
51
59
  logits_op=ops.subtract,
52
60
  )
61
+ pairwise_logits = ops.divide(
62
+ pairwise_logits,
63
+ ops.cast(self.temperature, dtype=pairwise_logits.dtype),
64
+ )
53
65
 
54
66
  return self.pairwise_loss(pairwise_logits), pairwise_labels
55
67
 
@@ -66,8 +78,8 @@ class PairwiseLoss(keras.losses.Loss, abc.ABC):
66
78
  in loss computation. If it is a dictionary, it should have two
67
79
  keys: `"labels"` and `"mask"`. `"mask"` can be used to ignore
68
80
  elements in loss computation, i.e., pairs will not be formed
69
- with those items. Note that the final mask is an and of the
70
- passed mask, and `labels == -1`.
81
+ with those items. Note that the final mask is an `and` of the
82
+ passed mask, and `labels >= 0`.
71
83
  y_pred: tensor. The predicted values, of shape `(list_size)` for
72
84
  unbatched inputs or `(batch_size, list_size)` for batched
73
85
  inputs. Should be of the same shape as `y_true`.
@@ -83,7 +95,14 @@ class PairwiseLoss(keras.losses.Loss, abc.ABC):
83
95
  mask = y_true.get("mask", None)
84
96
  y_true = y_true["labels"]
85
97
 
86
- y_true, y_pred, mask = process_loss_call_inputs(y_true, y_pred, mask)
98
+ y_true = ops.convert_to_tensor(y_true)
99
+ y_pred = ops.convert_to_tensor(y_pred)
100
+ if mask is not None:
101
+ mask = ops.convert_to_tensor(mask)
102
+
103
+ y_true, y_pred, mask, _ = standardize_call_inputs_ranks(
104
+ y_true, y_pred, mask
105
+ )
87
106
 
88
107
  losses, weights = self.compute_unreduced_loss(
89
108
  labels=y_true, logits=y_pred, mask=mask
@@ -92,9 +111,14 @@ class PairwiseLoss(keras.losses.Loss, abc.ABC):
92
111
  losses = ops.sum(losses, axis=-1)
93
112
  return losses
94
113
 
114
+ def get_config(self) -> dict[str, Any]:
115
+ config: dict[str, Any] = super().get_config()
116
+ config.update({"temperature": self.temperature})
117
+ return config
118
+
95
119
 
96
120
  pairwise_loss_subclass_doc_string = (
97
- "Computes pairwise hinge loss between true labels and predicted scores."
121
+ "Computes pairwise {loss_name} between true labels and predicted scores."
98
122
  """
99
123
  This loss function is designed for ranking tasks, where the goal is to
100
124
  correctly order items within each list. It computes the loss by comparing
@@ -0,0 +1,39 @@
1
from typing import Callable, Optional
2
+
3
+ from keras import ops
4
+
5
+ from keras_rs.src import types
6
+
7
+
8
def apply_pairwise_op(
    x: types.Tensor, op: Callable[[types.Tensor, types.Tensor], types.Tensor]
) -> types.Tensor:
    """Apply a binary `op` to every ordered pair of entries on the last axis.

    `x` is expanded into a column view `[..., n, 1]` and a row view
    `[..., 1, n]`, and `op` is called on the two views. For broadcasting ops
    such as `ops.subtract` or `ops.logical_and`, entry `[..., i, j]` of the
    result is `op(x[..., i], x[..., j])`.

    Args:
        x: tensor whose last axis holds the items to pair up.
        op: binary tensor op applied to the column and row views.

    Returns:
        The result of `op(column_view, row_view)`.
    """
    column_view = ops.expand_dims(x, axis=-1)
    row_view = ops.expand_dims(x, axis=-2)
    return op(column_view, row_view)
15
+
16
+
17
def pairwise_comparison(
    labels: types.Tensor,
    logits: types.Tensor,
    mask: Optional[types.Tensor],
    logits_op: Callable[[types.Tensor, types.Tensor], types.Tensor],
) -> tuple[types.Tensor, types.Tensor]:
    """Build pairwise label indicators and pairwise logits for ranking losses.

    Args:
        labels: per-item relevance labels; pairs are formed along the last
            axis.
        logits: per-item predicted scores, same layout as `labels`.
        mask: optional per-item validity mask (combined with
            `ops.logical_and`, so presumably boolean). When `None`, every pair
            is kept.
        logits_op: binary op used to combine the logits of items `i` and `j`
            (e.g. `ops.subtract`).

    Returns:
        A `(pairwise_labels, pairwise_logits)` tuple.
        - `pairwise_labels[..., i, j]` is 1 (in `labels.dtype`) when
          `labels[..., i] > labels[..., j]` and, if `mask` is given, both
          items are valid; otherwise it is 0.
        - `pairwise_logits[..., i, j]` is
          `logits_op(logits[..., i], logits[..., j])`.
    """
    # Compute the difference for all pairs in a list. The output is a tensor
    # with shape `(batch_size, list_size, list_size)`, where `[:, i, j]` stores
    # `l_i - l_j` for pair `(i, j)`.
    pairwise_labels_diff = apply_pairwise_op(labels, ops.subtract)
    pairwise_logits = apply_pairwise_op(logits, logits_op)

    # Keep only those pairs where `l_i > l_j`, i.e. item `i` is more relevant
    # than item `j` and should be ranked above it.
    pairwise_labels = ops.cast(
        ops.greater(pairwise_labels_diff, 0), dtype=labels.dtype
    )
    if mask is not None:
        # A pair is valid only when both of its items are valid.
        valid_pairs = apply_pairwise_op(mask, ops.logical_and)
        pairwise_labels = ops.multiply(
            pairwise_labels, ops.cast(valid_pairs, dtype=pairwise_labels.dtype)
        )

    return pairwise_labels, pairwise_logits
@@ -6,7 +6,7 @@ from keras_rs.src import types
6
6
  from keras_rs.src.api_export import keras_rs_export
7
7
  from keras_rs.src.losses.pairwise_loss import PairwiseLoss
8
8
  from keras_rs.src.losses.pairwise_loss import pairwise_loss_subclass_doc_string
9
- from keras_rs.src.utils.pairwise_loss_utils import apply_pairwise_op
9
+ from keras_rs.src.losses.pairwise_loss_utils import apply_pairwise_op
10
10
 
11
11
 
12
12
  @keras_rs_export("keras_rs.losses.PairwiseMeanSquaredError")
@@ -65,6 +65,7 @@ explanation = """
65
65
  """
66
66
  extra_args = ""
67
67
  PairwiseMeanSquaredError.__doc__ = pairwise_loss_subclass_doc_string.format(
68
+ loss_name="mean squared error",
68
69
  formula=formula,
69
70
  explanation=explanation,
70
71
  extra_args=extra_args,
@@ -25,6 +25,7 @@ explanation = """
25
25
  """
26
26
  extra_args = ""
27
27
  PairwiseSoftZeroOneLoss.__doc__ = pairwise_loss_subclass_doc_string.format(
28
+ loss_name="soft zero-one loss",
28
29
  formula=formula,
29
30
  explanation=explanation,
30
31
  extra_args=extra_args,
@@ -0,0 +1,140 @@
1
+ from typing import Any, Callable, Optional
2
+
3
+ from keras import ops
4
+ from keras.saving import deserialize_keras_object
5
+ from keras.saving import serialize_keras_object
6
+
7
+ from keras_rs.src import types
8
+ from keras_rs.src.api_export import keras_rs_export
9
+ from keras_rs.src.metrics.ranking_metric import RankingMetric
10
+ from keras_rs.src.metrics.ranking_metric import (
11
+ ranking_metric_subclass_doc_string,
12
+ )
13
+ from keras_rs.src.metrics.ranking_metric import (
14
+ ranking_metric_subclass_doc_string_args,
15
+ )
16
+ from keras_rs.src.metrics.ranking_metrics_utils import compute_dcg
17
+ from keras_rs.src.metrics.ranking_metrics_utils import default_gain_fn
18
+ from keras_rs.src.metrics.ranking_metrics_utils import default_rank_discount_fn
19
+ from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
20
+ from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
21
+ from keras_rs.src.utils.doc_string_utils import format_docstring
22
+
23
+
24
@keras_rs_export("keras_rs.metrics.DCG")
class DCG(RankingMetric):
    def __init__(
        self,
        k: Optional[int] = None,
        gain_fn: Callable[[types.Tensor], types.Tensor] = default_gain_fn,
        rank_discount_fn: Callable[
            [types.Tensor], types.Tensor
        ] = default_rank_discount_fn,
        **kwargs: Any,
    ) -> None:
        super().__init__(k=k, **kwargs)

        # Used by `compute_metric()` and serialized by `get_config()`.
        self.gain_fn = gain_fn
        self.rank_discount_fn = rank_discount_fn

    def compute_metric(
        self,
        y_true: types.Tensor,
        y_pred: types.Tensor,
        mask: types.Tensor,
        sample_weight: types.Tensor,
    ) -> types.Tensor:
        """Compute the per-list DCG and the per-list weights.

        Returns:
            A `(per_list_dcg, per_list_weights)` pair. `per_list_dcg` is the
            list's DCG divided by its list weight (see the comment below).
        """
        # Reorder labels and weights by `y_pred`, honouring `k`, `mask` and
        # the tie-shuffling configuration of the base metric.
        sorted_y_true, sorted_weights = sort_by_scores(
            tensors_to_sort=[y_true, sample_weight],
            scores=y_pred,
            k=self.k,
            mask=mask,
            shuffle_ties=self.shuffle_ties,
            seed=self.seed_generator,
        )

        dcg = compute_dcg(
            y_true=sorted_y_true,
            sample_weight=sorted_weights,
            gain_fn=self.gain_fn,
            rank_discount_fn=self.rank_discount_fn,
        )

        per_list_weights = get_list_weights(
            weights=sample_weight, relevance=self.gain_fn(y_true)
        )
        # Since we have already multiplied with `sample_weight`, we need to
        # divide by `per_list_weights` so as to nullify the multiplication
        # which `keras.metrics.Mean` will do.
        per_list_dcg = ops.divide_no_nan(dcg, per_list_weights)

        return per_list_dcg, per_list_weights

    def get_config(self) -> dict[str, Any]:
        """Return the base config plus the serialized gain/discount fns."""
        config: dict[str, Any] = super().get_config()
        config.update(
            {
                "gain_fn": serialize_keras_object(self.gain_fn),
                "rank_discount_fn": serialize_keras_object(
                    self.rank_discount_fn
                ),
            }
        )
        return config

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "DCG":
        """Rebuild a `DCG` from `get_config()` output.

        The caller's `config` is not modified.
        """
        # Work on a shallow copy: writing the deserialized callables back into
        # the caller's dict would corrupt it for any later reuse (e.g. a
        # second `from_config` call on the same config).
        config = dict(config)
        config["gain_fn"] = deserialize_keras_object(config["gain_fn"])
        config["rank_discount_fn"] = deserialize_keras_object(
            config["rank_discount_fn"]
        )
        return cls(**config)
92
+
93
+
94
concept_sentence = (
    "It computes the sum of the gains of the items' graded relevance scores, "
    "applying a configurable discount based on position"
)
relevance_type = (
    "graded relevance scores (non-negative numbers where higher values "
    "indicate greater relevance)"
)
score_range_interpretation = (
    "Scores are non-negative, with higher values indicating better ranking "
    "quality (highly relevant items are ranked higher). The score for a single "
    "list is not bounded or normalized, i.e., it does not lie in a range"
)

# NOTE(review): the documented default discount is `1 / log2(rank + 1)`, a
# multiplicative factor, so the formula multiplies gains by the discount.
# Dividing by it (as previously written) would contradict that default.
# Confirm against `compute_dcg`.
formula = """
```
DCG@k(y', w') = sum_{i=1}^{k} (gain_fn(y'_i) * rank_discount_fn(i))
```

where:
- `y'_i` is the true relevance score of the item ranked at position `i`
  (obtained by sorting `y_true` according to `y_pred`).
- `gain_fn` is the user-provided function mapping relevance `y'_i` to a
  gain value. The default function (`default_gain_fn`) is typically
  equivalent to `lambda y: 2**y - 1`.
- `rank_discount_fn` is the user-provided function mapping rank `i`
  to a discount value. The default function (`default_rank_discount_fn`)
  is typically equivalent to `lambda rank: 1 / log2(rank + 1)`.
- The final result aggregates these per-list scores.
"""
extra_args = """
        gain_fn: callable. Maps relevance scores (`y_true`) to gain values. The
            default implements `2**y - 1`.
        rank_discount_fn: callable. Maps rank positions to discount
            values. The default (`default_rank_discount_fn`) implements
            `1 / log2(rank + 1)`."""

DCG.__doc__ = format_docstring(
    ranking_metric_subclass_doc_string,
    width=80,
    metric_name="Discounted Cumulative Gain",
    metric_abbreviation="DCG",
    concept_sentence=concept_sentence,
    relevance_type=relevance_type,
    score_range_interpretation=score_range_interpretation,
    formula=formula,
) + ranking_metric_subclass_doc_string_args.format(extra_args=extra_args)
@@ -0,0 +1,112 @@
1
+ from keras import ops
2
+
3
+ from keras_rs.src import types
4
+ from keras_rs.src.api_export import keras_rs_export
5
+ from keras_rs.src.metrics.ranking_metric import RankingMetric
6
+ from keras_rs.src.metrics.ranking_metric import (
7
+ ranking_metric_subclass_doc_string,
8
+ )
9
+ from keras_rs.src.metrics.ranking_metric import (
10
+ ranking_metric_subclass_doc_string_args,
11
+ )
12
+ from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
13
+ from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
14
+ from keras_rs.src.utils.doc_string_utils import format_docstring
15
+
16
+
17
@keras_rs_export("keras_rs.metrics.MeanAveragePrecision")
class MeanAveragePrecision(RankingMetric):
    def compute_metric(
        self,
        y_true: types.Tensor,
        y_pred: types.Tensor,
        mask: types.Tensor,
        sample_weight: types.Tensor,
    ) -> types.Tensor:
        """Compute per-list average precision and per-list weights.

        Returns:
            A `(per_list_map, per_list_weights)` pair.
        """
        # Binarize the labels: anything >= 1 counts as relevant.
        one = ops.cast(1, dtype=y_true.dtype)
        is_relevant = ops.cast(ops.greater_equal(y_true, one), dtype="float32")

        ranked_relevance, ranked_weights = sort_by_scores(
            tensors_to_sort=[is_relevant, sample_weight],
            scores=y_pred,
            mask=mask,
            k=self.k,
            shuffle_ties=self.shuffle_ties,
            seed=self.seed_generator,
        )

        # Precision at every cutoff j: relevant items seen so far divided by j.
        hits_so_far = ops.cumsum(ranked_relevance, axis=1)
        cutoff_positions = ops.cumsum(ops.ones_like(ranked_relevance), axis=1)
        precision_at_cutoff = ops.divide_no_nan(hits_so_far, cutoff_positions)

        # Only cutoffs that land on a relevant item contribute, each scaled by
        # that item's weight.
        weighted_hits = ops.multiply(ranked_weights, ranked_relevance)
        precision_sum = ops.sum(
            ops.multiply(precision_at_cutoff, weighted_hits),
            axis=1,
            keepdims=True,
        )

        # Normalizer: weighted count of relevant items in the original
        # (unsorted, not top-k truncated) list.
        relevant_total = ops.sum(
            ops.multiply(sample_weight, is_relevant), axis=1, keepdims=True
        )

        average_precision = ops.divide_no_nan(precision_sum, relevant_total)
        list_weights = get_list_weights(sample_weight, is_relevant)
        return average_precision, list_weights
63
+
64
+
65
concept_sentence = (
    "It calculates the average of precision values computed after each "
    "relevant item present in the ranked list"
)
relevance_type = "binary indicators (0 or 1) of relevance"
score_range_interpretation = (
    "Scores range from 0 to 1, with higher values indicating that relevant "
    "items are generally positioned higher in the ranking"
)

formula = """
The formula for average precision is defined below. MAP is the mean over average
precision computed for each list.

```
AP(y, s) = sum_j (P@j(y, s) * rel(j)) / sum_i y_i
rel(j) = y_i if rank(s_i) = j
```

where:
- `j` represents the rank position (starting from 1).
- `sum_j` indicates a summation over all ranks `j` from 1 up to the list
  size (or `k`).
- `P@j(y, s)` denotes the Precision at rank `j`, calculated as the
  number of relevant items found within the top `j` positions divided by
  `j`.
- `rel(j)` represents the relevance of the item specifically at rank
  `j`. `rel(j)` is 1 if the item at rank `j` is relevant, and 0
  otherwise.
- `y_i` is the true relevance label of the original item `i` before
  ranking.
- `rank(s_i)` is the rank position assigned to item `i` based on its
  score `s_i`.
- `sum_i y_i` calculates the total number of relevant items in the
  original list `y`.

Labels are binarized before the formula is applied: any label greater than or
equal to 1 is treated as relevant (1), and every other label as not relevant
(0).
"""
extra_args = ""

MeanAveragePrecision.__doc__ = format_docstring(
    ranking_metric_subclass_doc_string,
    width=80,
    metric_name="Mean Average Precision",
    metric_abbreviation="MAP",
    concept_sentence=concept_sentence,
    relevance_type=relevance_type,
    score_range_interpretation=score_range_interpretation,
    formula=formula,
) + ranking_metric_subclass_doc_string_args.format(extra_args=extra_args)
@@ -0,0 +1,98 @@
1
+ from keras import ops
2
+
3
+ from keras_rs.src import types
4
+ from keras_rs.src.api_export import keras_rs_export
5
+ from keras_rs.src.metrics.ranking_metric import RankingMetric
6
+ from keras_rs.src.metrics.ranking_metric import (
7
+ ranking_metric_subclass_doc_string,
8
+ )
9
+ from keras_rs.src.metrics.ranking_metric import (
10
+ ranking_metric_subclass_doc_string_args,
11
+ )
12
+ from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
13
+ from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
14
+ from keras_rs.src.utils.doc_string_utils import format_docstring
15
+
16
+
17
@keras_rs_export("keras_rs.metrics.MeanReciprocalRank")
class MeanReciprocalRank(RankingMetric):
    def compute_metric(
        self,
        y_true: types.Tensor,
        y_pred: types.Tensor,
        mask: types.Tensor,
        sample_weight: types.Tensor,
    ) -> types.Tensor:
        """Compute the reciprocal rank of the first relevant item per list.

        Returns:
            A `(mrr, per_list_weights)` pair; `mrr` has shape
            `(batch_size, 1)`.
        """
        # Example: `y_true = [0, 0, 1]`, `y_pred = [0.1, 0.9, 0.2]` gives
        # `ranked_labels = [0, 1, 0]` (sorted in descending order of score).
        (ranked_labels,) = sort_by_scores(
            tensors_to_sort=[y_true],
            scores=y_pred,
            mask=mask,
            k=self.k,
            shuffle_ties=self.shuffle_ties,
            seed=self.seed_generator,
        )

        # Number of ranked positions. This depends on `k`, so it is not
        # necessarily `len(y_true)`.
        num_positions = ops.shape(ranked_labels)[1]

        # Binary relevance only: any label >= 1 counts as 1.
        # For the example above: `[0., 1., 0.]`.
        ranked_hits = ops.cast(
            ops.greater_equal(
                ranked_labels, ops.cast(1, dtype=ranked_labels.dtype)
            ),
            dtype="float32",
        )

        # 1 / rank for ranks 1..num_positions: `[1., 0.5, 0.33]`.
        numerator = ops.cast(1, dtype="float32")
        ranks = ops.arange(1, num_positions + 1, dtype="float32")
        inverse_ranks = ops.divide(numerator, ranks)

        # Largest reciprocal rank among relevant positions, shape
        # `(batch_size, 1)`: `amax([0., 0.5, 0.]) = 0.5`.
        hit_scores = ops.multiply(ranked_hits, inverse_ranks)
        mrr = ops.amax(hit_scores, axis=1, keepdims=True)

        # List weights are derived from relevance over the full, unsorted
        # labels.
        all_hits = ops.cast(
            ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
            dtype="float32",
        )
        per_list_weights = get_list_weights(
            weights=sample_weight, relevance=all_hits
        )

        return mrr, per_list_weights
74
+
75
+
76
concept_sentence = (
    "It focuses on the rank position of the single highest-scoring relevant "
    "item"
)
relevance_type = "binary indicators (0 or 1) of relevance"
score_range_interpretation = (
    "Scores range from 0 to 1, with 1 indicating the first relevant item was "
    "always ranked first"
)
formula = """```
MRR(y, s) = max_{i} rel(y_i) / rank(s_i)
```

where `rel(y_i)` is 1 if `y_i >= 1` and 0 otherwise (labels are binarized),
and `rank(s_i)` is the 1-based position of item `i` after sorting the list by
descending score `s`."""
extra_args = ""
MeanReciprocalRank.__doc__ = format_docstring(
    ranking_metric_subclass_doc_string,
    width=80,
    metric_name="Mean Reciprocal Rank",
    metric_abbreviation="MRR",
    concept_sentence=concept_sentence,
    relevance_type=relevance_type,
    score_range_interpretation=score_range_interpretation,
    formula=formula,
) + ranking_metric_subclass_doc_string_args.format(extra_args=extra_args)