keras-rs-nightly 0.0.1.dev2025042003__py3-none-any.whl → 0.0.1.dev2025042203__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of keras-rs-nightly might be problematic.

keras_rs/api/__init__.py CHANGED
@@ -6,5 +6,6 @@ since your modifications would be overwritten.
 
 from keras_rs.api import layers
 from keras_rs.api import losses
+from keras_rs.api import metrics
 from keras_rs.src.version import __version__
 from keras_rs.src.version import version
@@ -0,0 +1,12 @@
+"""DO NOT EDIT.
+
+This file was autogenerated. Do not edit it by hand,
+since your modifications would be overwritten.
+"""
+
+from keras_rs.src.metrics.dcg import DCG
+from keras_rs.src.metrics.mean_average_precision import MeanAveragePrecision
+from keras_rs.src.metrics.mean_reciprocal_rank import MeanReciprocalRank
+from keras_rs.src.metrics.ndcg import NDCG
+from keras_rs.src.metrics.precision_at_k import PrecisionAtK
+from keras_rs.src.metrics.recall_at_k import RecallAtK
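This new autogenerated module exposes six ranking metrics under the public `keras_rs.metrics` namespace. A minimal usage sketch, assuming the standard `update_state`/`result` interface these classes inherit from `keras.metrics.Metric` (the sample values and the `k=2` truncation are illustrative):

```
import keras_rs

# `y_true` holds relevance labels and `y_pred` the predicted scores, both
# of shape `(batch_size, list_size)`; `k` truncates to the top-k items.
metric = keras_rs.metrics.NDCG(k=2)
metric.update_state(y_true=[[1.0, 0.0, 2.0]], y_pred=[[0.3, 0.9, 0.6]])
print(metric.result())
```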
@@ -20,6 +20,7 @@ explanation = """
 """
 extra_args = ""
 PairwiseHingeLoss.__doc__ = pairwise_loss_subclass_doc_string.format(
+    loss_name="hinge loss",
     formula=formula,
     explanation=explanation,
     extra_args=extra_args,
@@ -29,6 +29,7 @@ explanation = """
 """
 extra_args = ""
 PairwiseLogisticLoss.__doc__ = pairwise_loss_subclass_doc_string.format(
+    loss_name="logistic loss",
     formula=formula,
     explanation=explanation,
     extra_args=extra_args,
@@ -1,12 +1,12 @@
 import abc
-from typing import Optional
+from typing import Any, Optional
 
 import keras
 from keras import ops
 
 from keras_rs.src import types
-from keras_rs.src.utils.pairwise_loss_utils import pairwise_comparison
-from keras_rs.src.utils.pairwise_loss_utils import process_loss_call_inputs
+from keras_rs.src.losses.pairwise_loss_utils import pairwise_comparison
+from keras_rs.src.metrics.utils import standardize_call_inputs_ranks
 
 
 class PairwiseLoss(keras.losses.Loss, abc.ABC):
@@ -22,14 +22,22 @@ class PairwiseLoss(keras.losses.Loss, abc.ABC):
     `pairwise_loss` method.
     """
 
-    # TODO: Add `temperature`, `lambda_weights`.
+    def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+
+        if temperature <= 0.0:
+            raise ValueError(
+                f"`temperature` should be a positive float. Received: "
+                f"`temperature` = {temperature}."
+            )
+
+        self.temperature = temperature
+
+    # TODO(abheesht): Add `lambda_weights`.
 
     @abc.abstractmethod
     def pairwise_loss(self, pairwise_logits: types.Tensor) -> types.Tensor:
-        raise NotImplementedError(
-            "All subclasses of `keras_rs.losses.pairwise_loss.PairwiseLoss`"
-            "must implement the `pairwise_loss()` method."
-        )
+        pass
 
     def compute_unreduced_loss(
         self,
@@ -50,6 +58,10 @@ class PairwiseLoss(keras.losses.Loss, abc.ABC):
             mask=valid_mask,
             logits_op=ops.subtract,
         )
+        pairwise_logits = ops.divide(
+            pairwise_logits,
+            ops.cast(self.temperature, dtype=pairwise_logits.dtype),
+        )
 
         return self.pairwise_loss(pairwise_logits), pairwise_labels
 
@@ -66,8 +78,8 @@ class PairwiseLoss(keras.losses.Loss, abc.ABC):
                 in loss computation. If it is a dictionary, it should have two
                 keys: `"labels"` and `"mask"`. `"mask"` can be used to ignore
                 elements in loss computation, i.e., pairs will not be formed
-                with those items. Note that the final mask is an and of the
-                passed mask, and `labels == -1`.
+                with those items. Note that the final mask is an `and` of the
+                passed mask, and `labels >= 0`.
             y_pred: tensor. The predicted values, of shape `(list_size)` for
                 unbatched inputs or `(batch_size, list_size)` for batched
                 inputs. Should be of the same shape as `y_true`.
@@ -83,7 +95,14 @@ class PairwiseLoss(keras.losses.Loss, abc.ABC):
             mask = y_true.get("mask", None)
             y_true = y_true["labels"]
 
-        y_true, y_pred, mask = process_loss_call_inputs(y_true, y_pred, mask)
+        y_true = ops.convert_to_tensor(y_true)
+        y_pred = ops.convert_to_tensor(y_pred)
+        if mask is not None:
+            mask = ops.convert_to_tensor(mask)
+
+        y_true, y_pred, mask, _ = standardize_call_inputs_ranks(
+            y_true, y_pred, mask
+        )
 
         losses, weights = self.compute_unreduced_loss(
             labels=y_true, logits=y_pred, mask=mask
@@ -92,9 +111,14 @@ class PairwiseLoss(keras.losses.Loss, abc.ABC):
         losses = ops.sum(losses, axis=-1)
         return losses
 
+    def get_config(self) -> dict[str, Any]:
+        config: dict[str, Any] = super().get_config()
+        config.update({"temperature": self.temperature})
+        return config
+
 
 pairwise_loss_subclass_doc_string = (
-    "Computes pairwise hinge loss between true labels and predicted scores."
+    "Computes pairwise {loss_name} between true labels and predicted scores."
     """
     This loss function is designed for ranking tasks, where the goal is to
     correctly order items within each list. It computes the loss by comparing
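With `temperature` now recorded in `get_config()`, pairwise losses round-trip through the standard Keras config mechanism. A small sketch, relying on the default `keras.losses.Loss.from_config`:

```
import keras_rs

# Serialize and restore a loss; `temperature` survives the round trip
# because `get_config()` now records it.
loss = keras_rs.losses.PairwiseLogisticLoss(temperature=2.0)
restored = keras_rs.losses.PairwiseLogisticLoss.from_config(loss.get_config())
assert restored.temperature == 2.0
```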
@@ -0,0 +1,39 @@
+from typing import Callable
+
+from keras import ops
+
+from keras_rs.src import types
+
+
+def apply_pairwise_op(
+    x: types.Tensor, op: Callable[[types.Tensor, types.Tensor], types.Tensor]
+) -> types.Tensor:
+    return op(
+        ops.expand_dims(x, axis=-1),
+        ops.expand_dims(x, axis=-2),
+    )
+
+
+def pairwise_comparison(
+    labels: types.Tensor,
+    logits: types.Tensor,
+    mask: types.Tensor,
+    logits_op: Callable[[types.Tensor, types.Tensor], types.Tensor],
+) -> tuple[types.Tensor, types.Tensor]:
+    # Compute the difference for all pairs in a list. The output is a tensor
+    # with shape `(batch_size, list_size, list_size)`, where `[:, i, j]` stores
+    # information for pair `(i, j)`.
+    pairwise_labels_diff = apply_pairwise_op(labels, ops.subtract)
+    pairwise_logits = apply_pairwise_op(logits, logits_op)
+
+    # Keep only those cases where `l_i < l_j`.
+    pairwise_labels = ops.cast(
+        ops.greater(pairwise_labels_diff, 0), dtype=labels.dtype
+    )
+    if mask is not None:
+        valid_pairs = apply_pairwise_op(mask, ops.logical_and)
+        pairwise_labels = ops.multiply(
+            pairwise_labels, ops.cast(valid_pairs, dtype=pairwise_labels.dtype)
+        )
+
+    return pairwise_labels, pairwise_logits
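The broadcast trick inside `apply_pairwise_op` is worth spelling out: expanding the same vector on axes `-1` and `-2` produces every pairwise combination in a single op. A standalone illustration (not part of the diff), using only the `keras.ops` calls the file itself uses:

```
from keras import ops

x = ops.convert_to_tensor([3.0, 1.0, 2.0])
# `diff[i, j] == x[i] - x[j]`, shape `(3, 3)`:
# [[ 0.,  2.,  1.],
#  [-2.,  0., -1.],
#  [-1.,  1.,  0.]]
diff = ops.subtract(
    ops.expand_dims(x, axis=-1),  # shape (3, 1)
    ops.expand_dims(x, axis=-2),  # shape (1, 3)
)
print(diff)
```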
@@ -6,7 +6,7 @@ from keras_rs.src import types
 from keras_rs.src.api_export import keras_rs_export
 from keras_rs.src.losses.pairwise_loss import PairwiseLoss
 from keras_rs.src.losses.pairwise_loss import pairwise_loss_subclass_doc_string
-from keras_rs.src.utils.pairwise_loss_utils import apply_pairwise_op
+from keras_rs.src.losses.pairwise_loss_utils import apply_pairwise_op
 
 
 @keras_rs_export("keras_rs.losses.PairwiseMeanSquaredError")
@@ -65,6 +65,7 @@ explanation = """
 """
 extra_args = ""
 PairwiseMeanSquaredError.__doc__ = pairwise_loss_subclass_doc_string.format(
+    loss_name="mean squared error",
     formula=formula,
     explanation=explanation,
     extra_args=extra_args,
@@ -25,6 +25,7 @@ explanation = """
 """
 extra_args = ""
 PairwiseSoftZeroOneLoss.__doc__ = pairwise_loss_subclass_doc_string.format(
+    loss_name="soft zero-one loss",
    formula=formula,
    explanation=explanation,
    extra_args=extra_args,
@@ -0,0 +1,140 @@
+from typing import Any, Callable, Optional
+
+from keras import ops
+from keras.saving import deserialize_keras_object
+from keras.saving import serialize_keras_object
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+from keras_rs.src.metrics.ranking_metric import RankingMetric
+from keras_rs.src.metrics.ranking_metric import (
+    ranking_metric_subclass_doc_string,
+)
+from keras_rs.src.metrics.ranking_metric import (
+    ranking_metric_subclass_doc_string_args,
+)
+from keras_rs.src.metrics.ranking_metrics_utils import compute_dcg
+from keras_rs.src.metrics.ranking_metrics_utils import default_gain_fn
+from keras_rs.src.metrics.ranking_metrics_utils import default_rank_discount_fn
+from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
+from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
+from keras_rs.src.utils.doc_string_utils import format_docstring
+
+
+@keras_rs_export("keras_rs.metrics.DCG")
+class DCG(RankingMetric):
+    def __init__(
+        self,
+        k: Optional[int] = None,
+        gain_fn: Callable[[types.Tensor], types.Tensor] = default_gain_fn,
+        rank_discount_fn: Callable[
+            [types.Tensor], types.Tensor
+        ] = default_rank_discount_fn,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(k=k, **kwargs)
+
+        self.gain_fn = gain_fn
+        self.rank_discount_fn = rank_discount_fn
+
+    def compute_metric(
+        self,
+        y_true: types.Tensor,
+        y_pred: types.Tensor,
+        mask: types.Tensor,
+        sample_weight: types.Tensor,
+    ) -> types.Tensor:
+        sorted_y_true, sorted_weights = sort_by_scores(
+            tensors_to_sort=[y_true, sample_weight],
+            scores=y_pred,
+            k=self.k,
+            mask=mask,
+            shuffle_ties=self.shuffle_ties,
+            seed=self.seed_generator,
+        )
+
+        dcg = compute_dcg(
+            y_true=sorted_y_true,
+            sample_weight=sorted_weights,
+            gain_fn=self.gain_fn,
+            rank_discount_fn=self.rank_discount_fn,
+        )
+
+        per_list_weights = get_list_weights(
+            weights=sample_weight, relevance=self.gain_fn(y_true)
+        )
+        # Since we have already multiplied with `sample_weight`, we need to
+        # divide by `per_list_weights` so as to nullify the multiplication
+        # which `keras.metrics.Mean` will do.
+        per_list_dcg = ops.divide_no_nan(dcg, per_list_weights)
+
+        return per_list_dcg, per_list_weights
+
+    def get_config(self) -> dict[str, Any]:
+        config: dict[str, Any] = super().get_config()
+        config.update(
+            {
+                "gain_fn": serialize_keras_object(self.gain_fn),
+                "rank_discount_fn": serialize_keras_object(
+                    self.rank_discount_fn
+                ),
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "DCG":
+        config["gain_fn"] = deserialize_keras_object(config["gain_fn"])
+        config["rank_discount_fn"] = deserialize_keras_object(
+            config["rank_discount_fn"]
+        )
+        return cls(**config)
+
+
+concept_sentence = (
+    "It computes the sum of the graded relevance scores of items, applying a "
+    "configurable discount based on position"
+)
+relevance_type = (
+    "graded relevance scores (non-negative numbers where higher values "
+    "indicate greater relevance)"
+)
+score_range_interpretation = (
+    "Scores are non-negative, with higher values indicating better ranking "
+    "quality (highly relevant items are ranked higher). The score for a single "
+    "list is not bounded or normalized, i.e., it does not lie in a range"
+)
+
+formula = """
+```
+DCG@k(y', w') = sum_{i=1}^{k} (gain_fn(y'_i) / rank_discount_fn(i))
+```
+
+where:
+  - `y'_i` is the true relevance score of the item ranked at position `i`
+    (obtained by sorting `y_true` according to `y_pred`).
+  - `gain_fn` is the user-provided function mapping relevance `y'_i` to a
+    gain value. The default function (`default_gain_fn`) is typically
+    equivalent to `lambda y: 2**y - 1`.
+  - `rank_discount_fn` is the user-provided function mapping rank `i`
+    to a discount value. The default function (`default_rank_discount_fn`)
+    is typically equivalent to `lambda rank: 1 / log2(rank + 1)`.
+  - The final result aggregates these per-list scores.
+"""
+extra_args = """
+    gain_fn: callable. Maps relevance scores (`y_true`) to gain values. The
+        default implements `2**y - 1`.
+    rank_discount_fn: function. Maps rank positions to discount
+        values. The default (`default_rank_discount_fn`) implements
+        `1 / log2(rank + 1)`."""
+
+DCG.__doc__ = format_docstring(
+    ranking_metric_subclass_doc_string,
+    width=80,
+    metric_name="Discounted Cumulative Gain",
+    metric_abbreviation="DCG",
+    concept_sentence=concept_sentence,
+    relevance_type=relevance_type,
+    score_range_interpretation=score_range_interpretation,
+    formula=formula,
+) + ranking_metric_subclass_doc_string_args.format(extra_args=extra_args)
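A worked example of the defaults documented above, gain `2**y - 1` and discount `1 / log2(rank + 1)` (values are illustrative, and assume the standard Keras metric interface):

```
import keras_rs

# Scores [0.5, 0.2, 0.8] order the labels [3., 1., 2.] as [2., 3., 1.], so
# DCG = 3/log2(2) + 7/log2(3) + 1/log2(4) = 3.0 + 4.416 + 0.5 ≈ 7.92.
metric = keras_rs.metrics.DCG()
metric.update_state(y_true=[[3.0, 1.0, 2.0]], y_pred=[[0.5, 0.2, 0.8]])
print(metric.result())
```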
@@ -0,0 +1,112 @@
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+from keras_rs.src.metrics.ranking_metric import RankingMetric
+from keras_rs.src.metrics.ranking_metric import (
+    ranking_metric_subclass_doc_string,
+)
+from keras_rs.src.metrics.ranking_metric import (
+    ranking_metric_subclass_doc_string_args,
+)
+from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
+from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
+from keras_rs.src.utils.doc_string_utils import format_docstring
+
+
+@keras_rs_export("keras_rs.metrics.MeanAveragePrecision")
+class MeanAveragePrecision(RankingMetric):
+    def compute_metric(
+        self,
+        y_true: types.Tensor,
+        y_pred: types.Tensor,
+        mask: types.Tensor,
+        sample_weight: types.Tensor,
+    ) -> types.Tensor:
+        relevance = ops.cast(
+            ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
+            dtype="float32",
+        )
+        sorted_relevance, sorted_weights = sort_by_scores(
+            tensors_to_sort=[relevance, sample_weight],
+            scores=y_pred,
+            mask=mask,
+            k=self.k,
+            shuffle_ties=self.shuffle_ties,
+            seed=self.seed_generator,
+        )
+        per_list_relevant_counts = ops.cumsum(sorted_relevance, axis=1)
+        per_list_cutoffs = ops.cumsum(ops.ones_like(sorted_relevance), axis=1)
+        per_list_precisions = ops.divide_no_nan(
+            per_list_relevant_counts, per_list_cutoffs
+        )
+
+        total_precision = ops.sum(
+            ops.multiply(
+                per_list_precisions,
+                ops.multiply(sorted_weights, sorted_relevance),
+            ),
+            axis=1,
+            keepdims=True,
+        )
+
+        # Compute the total relevance.
+        total_relevance = ops.sum(
+            ops.multiply(sample_weight, relevance), axis=1, keepdims=True
+        )
+
+        per_list_map = ops.divide_no_nan(total_precision, total_relevance)
+
+        per_list_weights = get_list_weights(sample_weight, relevance)
+
+        return per_list_map, per_list_weights
+
+
+concept_sentence = (
+    "It calculates the average of precision values computed after each "
+    "relevant item present in the ranked list"
+)
+relevance_type = "binary indicators (0 or 1) of relevance"
+score_range_interpretation = (
+    "Scores range from 0 to 1, with higher values indicating that relevant "
+    "items are generally positioned higher in the ranking"
+)
+
+formula = """
+The formula for average precision is defined below. MAP is the mean over average
+precision computed for each list.
+
+```
+AP(y, s) = sum_j (P@j(y, s) * rel(j)) / sum_i y_i
+rel(j) = y_i if rank(s_i) = j
+```
+
+where:
+  - `j` represents the rank position (starting from 1).
+  - `sum_j` indicates a summation over all ranks `j` from 1 up to the list
+    size (or `k`).
+  - `P@j(y, s)` denotes the Precision at rank `j`, calculated as the
+    number of relevant items found within the top `j` positions divided by
+    `j`.
+  - `rel(j)` represents the relevance of the item specifically at rank
+    `j`. `rel(j)` is 1 if the item at rank `j` is relevant, and 0
+    otherwise.
+  - `y_i` is the true relevance label of the original item `i` before
+    ranking.
+  - `rank(s_i)` is the rank position assigned to item `i` based on its
+    score `s_i`.
+  - `sum_i y_i` calculates the total number of relevant items in the
+    original list `y`.
+"""
+extra_args = ""
+
+MeanAveragePrecision.__doc__ = format_docstring(
+    ranking_metric_subclass_doc_string,
+    width=80,
+    metric_name="Mean Average Precision",
+    metric_abbreviation="MAP",
+    concept_sentence=concept_sentence,
+    relevance_type=relevance_type,
+    score_range_interpretation=score_range_interpretation,
+    formula=formula,
+) + ranking_metric_subclass_doc_string_args.format(extra_args=extra_args)
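A worked instance of the AP formula above (illustrative values, assuming the inherited `update_state`/`result` interface):

```
import keras_rs

# With binary labels [0, 1, 1] and scores [0.9, 0.7, 0.8], the ranked
# relevance is [0, 1, 1]; precision at the relevant ranks is P@2 = 1/2 and
# P@3 = 2/3, so AP = (1/2 + 2/3) / 2 = 7/12 ≈ 0.583.
metric = keras_rs.metrics.MeanAveragePrecision()
metric.update_state(y_true=[[0, 1, 1]], y_pred=[[0.9, 0.7, 0.8]])
print(metric.result())
```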
@@ -0,0 +1,98 @@
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+from keras_rs.src.metrics.ranking_metric import RankingMetric
+from keras_rs.src.metrics.ranking_metric import (
+    ranking_metric_subclass_doc_string,
+)
+from keras_rs.src.metrics.ranking_metric import (
+    ranking_metric_subclass_doc_string_args,
+)
+from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
+from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
+from keras_rs.src.utils.doc_string_utils import format_docstring
+
+
+@keras_rs_export("keras_rs.metrics.MeanReciprocalRank")
+class MeanReciprocalRank(RankingMetric):
+    def compute_metric(
+        self,
+        y_true: types.Tensor,
+        y_pred: types.Tensor,
+        mask: types.Tensor,
+        sample_weight: types.Tensor,
+    ) -> types.Tensor:
+        # Assume: `y_true = [0, 0, 1]`, `y_pred = [0.1, 0.9, 0.2]`.
+        # `sorted_y_true = [0, 1, 0]` (sorted in descending order).
+        (sorted_y_true,) = sort_by_scores(
+            tensors_to_sort=[y_true],
+            scores=y_pred,
+            mask=mask,
+            k=self.k,
+            shuffle_ties=self.shuffle_ties,
+            seed=self.seed_generator,
+        )
+
+        # This will depend on `k`, i.e., it will not always be the same as
+        # `len(y_true)`.
+        list_length = ops.shape(sorted_y_true)[1]
+
+        # We consider only binary relevance here, anything above 1 is treated
+        # as 1. `relevance = [0., 1., 0.]`.
+        relevance = ops.cast(
+            ops.greater_equal(
+                sorted_y_true, ops.cast(1, dtype=sorted_y_true.dtype)
+            ),
+            dtype="float32",
+        )
+
+        # `reciprocal_rank = [1, 0.5, 0.33]`
+        reciprocal_rank = ops.divide(
+            ops.cast(1, dtype="float32"),
+            ops.arange(1, list_length + 1, dtype="float32"),
+        )
+
+        # `mrr` should be of shape `(batch_size, 1)`.
+        # `mrr = amax([0., 0.5, 0.]) = 0.5`
+        mrr = ops.amax(
+            ops.multiply(relevance, reciprocal_rank),
+            axis=1,
+            keepdims=True,
+        )
+
+        # Get weights.
+        overall_relevance = ops.cast(
+            ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
+            dtype="float32",
+        )
+        per_list_weights = get_list_weights(
+            weights=sample_weight, relevance=overall_relevance
+        )
+
+        return mrr, per_list_weights
+
+
+concept_sentence = (
+    "It focuses on the rank position of the single highest-scoring relevant "
+    "item"
+)
+relevance_type = "binary indicators (0 or 1) of relevance"
+score_range_interpretation = (
+    "Scores range from 0 to 1, with 1 indicating the first relevant item was "
+    "always ranked first"
+)
+formula = """```
+MRR(y, s) = max_{i} y_{i} / rank(s_{i})
+```"""
+extra_args = ""
+MeanReciprocalRank.__doc__ = format_docstring(
+    ranking_metric_subclass_doc_string,
+    width=80,
+    metric_name="Mean Reciprocal Rank",
+    metric_abbreviation="MRR",
+    concept_sentence=concept_sentence,
+    relevance_type=relevance_type,
+    score_range_interpretation=score_range_interpretation,
+    formula=formula,
+) + ranking_metric_subclass_doc_string_args.format(extra_args=extra_args)
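A usage sketch mirroring the example walked through in the `compute_metric` comments above (same values; assumes the inherited `update_state`/`result` interface):

```
import keras_rs

# The only relevant item lands at rank 2 after sorting by score,
# so MRR = 1/2.
metric = keras_rs.metrics.MeanReciprocalRank()
metric.update_state(y_true=[[0, 0, 1]], y_pred=[[0.1, 0.9, 0.2]])
print(metric.result())  # 0.5
```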