keras-rs-nightly 0.0.1.dev2025042103__py3-none-any.whl → 0.0.1.dev2025042503__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release: this version of keras-rs-nightly has been flagged as potentially problematic.

Files changed (31)
  1. keras_rs/__init__.py +9 -28
  2. keras_rs/{api/layers → layers}/__init__.py +9 -7
  3. keras_rs/losses/__init__.py +18 -0
  4. keras_rs/metrics/__init__.py +16 -0
  5. keras_rs/src/losses/pairwise_hinge_loss.py +1 -0
  6. keras_rs/src/losses/pairwise_logistic_loss.py +1 -0
  7. keras_rs/src/losses/pairwise_loss.py +36 -12
  8. keras_rs/src/losses/pairwise_loss_utils.py +39 -0
  9. keras_rs/src/losses/pairwise_mean_squared_error.py +2 -1
  10. keras_rs/src/losses/pairwise_soft_zero_one_loss.py +1 -0
  11. keras_rs/src/metrics/__init__.py +0 -0
  12. keras_rs/src/metrics/dcg.py +140 -0
  13. keras_rs/src/metrics/mean_average_precision.py +112 -0
  14. keras_rs/src/metrics/mean_reciprocal_rank.py +98 -0
  15. keras_rs/src/metrics/ndcg.py +184 -0
  16. keras_rs/src/metrics/precision_at_k.py +94 -0
  17. keras_rs/src/metrics/ranking_metric.py +252 -0
  18. keras_rs/src/metrics/ranking_metrics_utils.py +238 -0
  19. keras_rs/src/metrics/recall_at_k.py +85 -0
  20. keras_rs/src/metrics/utils.py +72 -0
  21. keras_rs/src/utils/doc_string_utils.py +48 -0
  22. keras_rs/src/utils/keras_utils.py +12 -0
  23. keras_rs/src/version.py +1 -1
  24. {keras_rs_nightly-0.0.1.dev2025042103.dist-info → keras_rs_nightly-0.0.1.dev2025042503.dist-info}/METADATA +4 -3
  25. keras_rs_nightly-0.0.1.dev2025042503.dist-info/RECORD +42 -0
  26. {keras_rs_nightly-0.0.1.dev2025042103.dist-info → keras_rs_nightly-0.0.1.dev2025042503.dist-info}/WHEEL +1 -1
  27. keras_rs/api/__init__.py +0 -10
  28. keras_rs/api/losses/__init__.py +0 -14
  29. keras_rs/src/utils/pairwise_loss_utils.py +0 -102
  30. keras_rs_nightly-0.0.1.dev2025042103.dist-info/RECORD +0 -31
  31. {keras_rs_nightly-0.0.1.dev2025042103.dist-info → keras_rs_nightly-0.0.1.dev2025042503.dist-info}/top_level.txt +0 -0
keras_rs/src/metrics/ranking_metrics_utils.py ADDED
@@ -0,0 +1,238 @@
+ from typing import Callable, Optional, Union
+
+ import keras
+ from keras import ops
+
+ from keras_rs.src import types
+
+
+ def get_shuffled_indices(
+     shape: types.TensorShape,
+     mask: Optional[types.Tensor] = None,
+     shuffle_ties: bool = True,
+     seed: Optional[Union[int, keras.random.SeedGenerator]] = None,
+ ) -> types.Tensor:
+     """Utility function for getting shuffled indices, with masked indices
+     pushed to the end.
+
+     Args:
+         shape: tuple. The shape of the tensor for which to generate
+             shuffled indices.
+         mask: An optional boolean tensor with the same shape as `shape`.
+             If provided, elements where `mask` is `False` will be placed
+             at the end of the sorted indices. Defaults to `None` (no masking).
+         shuffle_ties: Boolean indicating how to handle ties if multiple elements
+             have the same sorting value (randomly when `shuffle_ties` is True
+             otherwise, order is preserved).
+         seed: Optional integer seed for the random number generator used when
+             `shuffle_ties` is True. Ensures reproducibility. Defaults to None.
+
+     Returns:
+         A tensor of shape `shape` containing shuffled indices.
+     """
+     # If `shuffle_ties` is True, generate random values. Otherwise, generate
+     # zeros so that we get `[0, 1, 2, ...]` as indices on doing `argsort`.
+     if shuffle_ties:
+         shuffle_values = keras.random.uniform(shape, seed=seed, dtype="float32")
+     else:
+         shuffle_values = ops.zeros(shape, dtype="float32")
+
+     # When `mask = False`, increase value by 1 so that those indices are placed
+     # at the end. Note that `shuffle_values` lies in the range `[0, 1)`, so
+     # adding by 1 works out.
+     if mask is not None:
+         shuffle_values = ops.where(
+             mask,
+             shuffle_values,
+             ops.add(shuffle_values, ops.cast(1, dtype="float32")),
+         )
+
+     shuffled_indices = ops.argsort(shuffle_values)
+     return shuffled_indices
+
+
+ def sort_by_scores(
+     tensors_to_sort: list[types.Tensor],
+     scores: types.Tensor,
+     mask: Optional[types.Tensor] = None,
+     k: Optional[int] = None,
+     shuffle_ties: bool = True,
+     seed: Optional[Union[int, keras.random.SeedGenerator]] = None,
+ ) -> types.Tensor:
+     """
+     Utility function for sorting tensors by scores.
+
+     Args:
+         tensors_to_sort. list of tensors. All tensors are of shape
+             `(batch_size, list_size)`. These tensors are sorted based on
+             `scores`.
+         scores: tensor. Of shape `(batch_size, list_size)`. The scores to sort
+             by.
+         k: int. The number of top-ranked items to consider (the 'k' in 'top-k').
+             If `None`, `list_size` is used.
+         shuffle_ties: bool. Whether to randomly shuffle scores before sorting.
+             This is done to break ties.
+         seed: int. Seed for shuffling.
+
+     Returns:
+         List of sorted tensors (`tensors_to_sort`), sorted using `scores`.
+     """
+     max_possible_k = ops.shape(scores)[1]
+     if k is None:
+         k = max_possible_k
+     elif isinstance(max_possible_k, int):
+         k = min(k, max_possible_k)
+     else:
+         k = ops.minimum(k, max_possible_k)
+
+     # Shuffle ties randomly, and push masked values to the beginning.
+     shuffled_indices = None
+     if shuffle_ties or mask is not None:
+         shuffled_indices = get_shuffled_indices(
+             ops.shape(scores),
+             mask=mask,
+             shuffle_ties=True,
+             seed=seed,
+         )
+         scores = ops.take_along_axis(scores, shuffled_indices, axis=1)
+
+     # Get top-k indices.
+     _, indices = ops.top_k(scores, k=k, sorted=True)
+
+     # If we shuffled our `scores` tensor, we need to get the correct indices
+     # by indexing into `shuffled_indices`.
+     if shuffled_indices is not None:
+         indices = ops.take_along_axis(shuffled_indices, indices, axis=1)
+
+     return [
+         ops.take_along_axis(tensor_to_sort, indices, axis=1)
+         for tensor_to_sort in tensors_to_sort
+     ]
+
+
+ def get_list_weights(
+     weights: types.Tensor, relevance: types.Tensor
+ ) -> types.Tensor:
+     """Computes per-list weights from provided sample weights.
+
+     Per-list weights are calculated as follows:
+     ```
+     per_list_weights = sum(weights * relevance) / sum(relevance).
+     ```
+
+     For lists where the sum of relevance is 0, a default weight is assigned:
+     ```
+     sum(per_list_weights) / num(sum(relevance) != 0 AND sum(weights) != 0)
+     ```
+
+     If all lists have a sum of relevance equal to 0, the default weight is 1.0.
+
+     As a result of the above computation, this function takes care of the
+     following cases:
+     - **Uniform Weights:** When all input weights are 1.0, all per-list weights
+       will be 1.0, even for lists with no relevant examples. This aligns with
+       standard ranking metrics.
+     - **Non-zero Weights per List:** If every list has at least one non-zero
+       weight, the default weight mechanism is not utilized, which is suitable
+       for unbiased metrics.
+     - **Mixed Scenarios:** For cases with a mix of lists having zero and
+       non-zero relevance and weights, the weights for lists with non-zero
+       relevance and weights are proportional to:
+
+       ```
+       per_list_weights / sum(per_list_weights) *
+       num(sum(relevance) != 0) / num(lists)
+       ```
+
+       The rest have weights `1.0 / num(lists)`.
+
+     Args:
+         weights: tensor. Weights tensor of shape `(batch_size, list_size)`.
+         relevance: tensor. The relevance `Tensor` of shape
+             `(batch_size, list_size)`.
+
+     Returns:
+         A tensor of shape `(batch_size, 1)`, containing the per-list weights.
+     """
+     # Calculate if the sum of weights per list is greater than 0.0.
+     nonzero_weights = ops.greater(ops.sum(weights, axis=1, keepdims=True), 0.0)
+     # Calculate the sum of relevance per list
+     per_list_relevance = ops.sum(relevance, axis=1, keepdims=True)
+     # Calculate if the sum of relevance per list is greater than 0.0
+     nonzero_relevance_condition = ops.greater(per_list_relevance, 0.0)
+     # Identify lists where both weights and relevance sums are non-zero.
+     nonzero_relevance = ops.cast(
+         ops.logical_and(nonzero_weights, nonzero_relevance_condition),
+         dtype="float32",
+     )
+     # Count the number of lists with non-zero relevance and non-zero weights.
+     nonzero_relevance_count = ops.sum(nonzero_relevance, axis=0, keepdims=True)
+
+     # Calculate the per-list weights using the core formula.
+     # Numerator: `sum(weights * relevance)` per list
+     numerator = ops.sum(ops.multiply(weights, relevance), axis=1, keepdims=True)
+     # Denominator: per_list_relevance = sum(relevance) per list
+     per_list_weights = ops.divide_no_nan(numerator, per_list_relevance)
+
+     # Calculate the sum of the computed per-list weights.
+     sum_weights = ops.sum(per_list_weights, axis=0, keepdims=True)
+
+     # Calculate the average weight to use as default for lists with zero
+     # relevance but non-zero weights. If no lists have non-zero relevance,
+     # default to 1.0.
+     avg_weight = ops.where(
+         ops.greater(nonzero_relevance_count, 0.0),
+         ops.divide(sum_weights, nonzero_relevance_count),
+         ops.cast(1, dtype=sum_weights.dtype),
+     )
+
+     # Final assignment of weights based on conditions:
+     # 1. If sum(weights) == 0 for a list, the final weight is 0.
+     # 2. If sum(weights) > 0 AND sum(relevance) > 0, use the calculated
+     #    `per_list_weights`.
+     # 3. If `sum(weights) > 0` AND `sum(relevance) == 0`, use the calculated
+     #    `avg_weight`.
+     final_weights = ops.where(
+         nonzero_weights,
+         ops.where(
+             nonzero_relevance_condition,
+             per_list_weights,
+             avg_weight,
+         ),
+         ops.cast(0, dtype=per_list_weights.dtype),
+     )
+
+     return final_weights
+
+
+ @keras.saving.register_keras_serializable()  # type: ignore[misc]
+ def default_gain_fn(label: types.Tensor) -> types.Tensor:
+     return ops.subtract(ops.power(2.0, label), 1.0)
+
+
+ @keras.saving.register_keras_serializable()  # type: ignore[misc]
+ def default_rank_discount_fn(rank: types.Tensor) -> types.Tensor:
+     return ops.divide(
+         ops.cast(1, dtype=rank.dtype),
+         ops.log2(ops.add(ops.cast(1, dtype=rank.dtype), rank)),
+     )
+
+
+ def compute_dcg(
+     y_true: types.Tensor,
+     sample_weight: types.Tensor,
+     gain_fn: Callable[[types.Tensor], types.Tensor] = default_gain_fn,
+     rank_discount_fn: Callable[
+         [types.Tensor], types.Tensor
+     ] = default_rank_discount_fn,
+ ) -> types.Tensor:
+     list_size = ops.shape(y_true)[1]
+     positions = ops.arange(1, list_size + 1, dtype="float32")
+     gain = gain_fn(y_true)
+     discount = rank_discount_fn(positions)
+
+     return ops.sum(
+         ops.multiply(sample_weight, ops.multiply(gain, discount)),
+         axis=1,
+         keepdims=True,
+     )
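
The new `compute_dcg` helper combines the default gain, `2^label - 1`, with the default logarithmic rank discount, `1 / log2(rank + 1)`. The following is a minimal sketch, not part of the diff, that reproduces the same computation on a toy list using plain Keras ops; the labels and weights are invented for illustration.

```python
from keras import ops

# One toy list of three graded relevance labels (illustrative only).
y_true = ops.convert_to_tensor([[3.0, 1.0, 0.0]])
sample_weight = ops.ones_like(y_true)

positions = ops.arange(1.0, 4.0, dtype="float32")              # ranks 1, 2, 3
gain = ops.subtract(ops.power(2.0, y_true), 1.0)               # default_gain_fn
discount = ops.divide(1.0, ops.log2(ops.add(positions, 1.0)))  # default_rank_discount_fn

dcg = ops.sum(sample_weight * gain * discount, axis=1, keepdims=True)
# DCG = 7/log2(2) + 1/log2(3) + 0/log2(4) ≈ 7.63
```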
keras_rs/src/metrics/recall_at_k.py ADDED
@@ -0,0 +1,85 @@
+ from keras import ops
+
+ from keras_rs.src import types
+ from keras_rs.src.api_export import keras_rs_export
+ from keras_rs.src.metrics.ranking_metric import RankingMetric
+ from keras_rs.src.metrics.ranking_metric import (
+     ranking_metric_subclass_doc_string,
+ )
+ from keras_rs.src.metrics.ranking_metric import (
+     ranking_metric_subclass_doc_string_args,
+ )
+ from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
+ from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
+ from keras_rs.src.utils.doc_string_utils import format_docstring
+
+
+ @keras_rs_export("keras_rs.metrics.RecallAtK")
+ class RecallAtK(RankingMetric):
+     def compute_metric(
+         self,
+         y_true: types.Tensor,
+         y_pred: types.Tensor,
+         mask: types.Tensor,
+         sample_weight: types.Tensor,
+     ) -> types.Tensor:
+         # Assume: `y_true = [0, 0, 1]`, `y_pred = [0.1, 0.9, 0.2]`.
+         # `sorted_y_true = [0, 1, 0]` (sorted in descending order).
+         (sorted_y_true,) = sort_by_scores(
+             tensors_to_sort=[y_true],
+             scores=y_pred,
+             mask=mask,
+             k=self.k,
+             shuffle_ties=self.shuffle_ties,
+             seed=self.seed_generator,
+         )
+
+         relevance = ops.cast(
+             ops.greater_equal(
+                 sorted_y_true, ops.cast(1, dtype=sorted_y_true.dtype)
+             ),
+             dtype="float32",
+         )
+         overall_relevance = ops.cast(
+             ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
+             dtype="float32",
+         )
+         per_list_recall = ops.divide_no_nan(
+             ops.sum(relevance, axis=1, keepdims=True),
+             ops.sum(overall_relevance, axis=1, keepdims=True),
+         )
+
+         # Get weights.
+         per_list_weights = get_list_weights(
+             weights=sample_weight, relevance=overall_relevance
+         )
+
+         return per_list_recall, per_list_weights
+
+
+ concept_sentence = (
+     "It measures the proportion of relevant items found in the top-k "
+     "recommendations out of the total number of relevant items for a user"
+ )
+ relevance_type = "binary indicators (0 or 1) of relevance"
+ score_range_interpretation = (
+     "Scores range from 0 to 1, with 1 indicating that all relevant items "
+     "for the user were found within the top-k recommendations"
+ )
+ formula = """```
+ R@k(y, s) = sum_i I[rank(s_i) < k] y_i / sum_j y_j
+ ```
+
+ where `y_i` is the relevance label (0/1) of the item ranked at position
+ `i`, `I[condition]` is 1 if the condition is met, otherwise 0."""
+ extra_args = ""
+ RecallAtK.__doc__ = format_docstring(
+     ranking_metric_subclass_doc_string,
+     width=80,
+     metric_name="Recall@k",
+     metric_abbreviation="R@k",
+     concept_sentence=concept_sentence,
+     relevance_type=relevance_type,
+     score_range_interpretation=score_range_interpretation,
+     formula=formula,
+ ) + ranking_metric_subclass_doc_string_args.format(extra_args=extra_args)
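
For reference, a plain-Python sketch of the R@k formula quoted in the docstring above, not part of the diff; the labels, scores, and k are invented for the example.

```python
# Recall@2 for a single toy list, following the formula above (illustration only).
y_true = [0, 0, 1, 1]          # binary relevance labels
y_pred = [0.1, 0.9, 0.8, 0.2]  # predicted scores
k = 2

# Rank items by descending score, then count relevant items in the top-k.
ranked = sorted(range(len(y_pred)), key=lambda i: y_pred[i], reverse=True)
relevant_in_top_k = sum(y_true[i] for i in ranked[:k])

recall_at_k = relevant_in_top_k / sum(y_true)
print(recall_at_k)  # 1 relevant item in the top-2, out of 2 relevant items -> 0.5
```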
keras_rs/src/metrics/utils.py ADDED
@@ -0,0 +1,72 @@
+ from typing import Optional
+
+ from keras import ops
+
+ from keras_rs.src import types
+ from keras_rs.src.utils.keras_utils import check_rank
+ from keras_rs.src.utils.keras_utils import check_shapes_compatible
+
+
+ def standardize_call_inputs_ranks(
+     y_true: types.Tensor,
+     y_pred: types.Tensor,
+     mask: Optional[types.Tensor] = None,
+     check_y_true_rank: bool = True,
+ ) -> tuple[types.Tensor, types.Tensor, Optional[types.Tensor], bool]:
+     """
+     Utility function for processing inputs for losses and metrics.
+
+     This utility function does three things:
+
+     - Checks that `y_true`, `y_pred` are of rank 1 or 2;
+     - Checks that `y_true`, `y_pred`, `mask` have the same shape;
+     - Adds batch dimension if rank = 1.
+
+     Args:
+         y_true: tensor. Ground truth values.
+         y_pred: tensor. The predicted values.
+         mask: tensor. Boolean mask for `y_true`.
+         check_y_true_rank: bool. Whether to check the rank of `y_true`.
+
+     Returns:
+         Tuple of processed `y_true`, `y_pred`, `mask`, and `batched`. `batched`
+         is a bool indicating if the inputs are batched.
+     """
+
+     y_true_shape = ops.shape(y_true)
+     y_true_rank = len(y_true_shape)
+     y_pred_shape = ops.shape(y_pred)
+     y_pred_rank = len(y_pred_shape)
+     if mask is not None:
+         mask_shape = ops.shape(mask)
+         mask_rank = len(mask_shape)
+
+     if check_y_true_rank:
+         check_rank(y_true_rank, allowed_ranks=(1, 2), tensor_name="y_true")
+     check_rank(y_pred_rank, allowed_ranks=(1, 2), tensor_name="y_pred")
+     if mask is not None:
+         check_rank(mask_rank, allowed_ranks=(1, 2), tensor_name="mask")
+     if not check_shapes_compatible(y_true_shape, y_pred_shape):
+         raise ValueError(
+             "`y_true` and `y_pred` should have the same shape. Received: "
+             f"`y_true.shape` = {y_true_shape}, `y_pred.shape` = {y_pred_shape}."
+         )
+     if mask is not None and not check_shapes_compatible(
+         y_true_shape, mask_shape
+     ):
+         raise ValueError(
+             "`y_true['labels']` and `y_true['mask']` should have the same "
+             f"shape. Received: `y_true['labels'].shape` = {y_true_shape}, "
+             f"`y_true['mask'].shape` = {mask_shape}."
+         )
+
+     batched = True
+     if y_true_rank == 1:
+         batched = False
+
+         y_true = ops.expand_dims(y_true, axis=0)
+         y_pred = ops.expand_dims(y_pred, axis=0)
+         if mask is not None:
+             mask = ops.expand_dims(mask, axis=0)
+
+     return y_true, y_pred, mask, batched
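
A minimal usage sketch, not part of the diff, showing how unbatched rank-1 inputs pass through this helper (it adds a leading batch dimension and reports `batched=False`); the tensors are invented, and the example assumes the new nightly wheel is installed.

```python
from keras import ops

# Assumes keras-rs-nightly 0.0.1.dev2025042503 (or later) is installed.
from keras_rs.src.metrics.utils import standardize_call_inputs_ranks

y_true = ops.convert_to_tensor([1.0, 0.0, 1.0])  # rank 1, no batch dimension
y_pred = ops.convert_to_tensor([0.2, 0.7, 0.1])

y_true, y_pred, mask, batched = standardize_call_inputs_ranks(y_true, y_pred)
print(ops.shape(y_true), batched)  # (1, 3) False -> a batch dimension was added
```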
keras_rs/src/utils/doc_string_utils.py ADDED
@@ -0,0 +1,48 @@
+ import re
+ import textwrap
+ from typing import Any
+
+
+ def format_docstring(template: str, width: int = 80, **kwargs: Any) -> str:
+     """Formats and wraps a docstring using dedent and fill."""
+     base_indent_str = " " * 4
+
+     # Initial format
+     formatted = template.format(**kwargs)
+
+     # Dedent the whole block
+     dedented_all = textwrap.dedent(formatted).strip()
+
+     # Split into logical paragraphs/blocks.
+     blocks = re.split(r"(\n\s*\n)", dedented_all)
+
+     processed_output = []
+
+     for block in blocks:
+         stripped_block = block.strip()
+         if not stripped_block:
+             processed_output.append(block)
+             continue
+
+         if "```" in stripped_block:
+             formula_dedented = textwrap.dedent(stripped_block)
+             processed_output.append(
+                 textwrap.indent(formula_dedented, base_indent_str)
+             )
+         elif "where:" in stripped_block:
+             processed_output.append(
+                 textwrap.indent(stripped_block, base_indent_str)
+             )
+         else:
+             processed_output.append(
+                 textwrap.fill(
+                     stripped_block,
+                     width=width - len(base_indent_str),
+                     initial_indent=base_indent_str,
+                     subsequent_indent=base_indent_str,
+                 )
+             )
+
+     final_string = "".join(processed_output).strip()
+     final_string = base_indent_str + final_string
+     return final_string
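
A quick sketch, not part of the diff, of how this templated-docstring helper can be used; the template text is invented, and the example assumes the new nightly wheel is installed.

```python
# Assumes keras-rs-nightly 0.0.1.dev2025042503 (or later) is installed.
from keras_rs.src.utils.doc_string_utils import format_docstring

# A made-up template: prose paragraphs are wrapped to `width` and indented,
# while blocks containing fenced formulas or a "where:" section keep their layout.
template = (
    "Computes {metric_name} over ranked lists of items, averaging the "
    "per-list values across the batch to produce a single scalar result."
)

print(format_docstring(template, width=60, metric_name="Recall@k"))
```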
keras_rs/src/utils/keras_utils.py CHANGED
@@ -42,3 +42,15 @@ def check_shapes_compatible(
  return False
 
  return True
+
+
+ def check_rank(
+     x_rank: int,
+     allowed_ranks: tuple[int, ...],
+     tensor_name: str,
+ ) -> None:
+     if x_rank not in allowed_ranks:
+         raise ValueError(
+             f"`{tensor_name}` should have a rank from `{allowed_ranks}`."
+             f"Received: `{x_rank}`."
+         )
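
A small sketch, not part of the diff, of the new `check_rank` helper in use; the tensor is invented, and the example assumes the new nightly wheel is installed.

```python
from keras import ops

# Assumes keras-rs-nightly 0.0.1.dev2025042503 (or later) is installed.
from keras_rs.src.utils.keras_utils import check_rank

x = ops.zeros((2, 3, 4))  # rank 3
check_rank(len(ops.shape(x)), allowed_ranks=(1, 2), tensor_name="x")
# Raises ValueError because rank 3 is not in the allowed ranks (1, 2).
```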
keras_rs/src/version.py CHANGED
@@ -1,7 +1,7 @@
  from keras_rs.src.api_export import keras_rs_export
 
  # Unique source of truth for the version number.
- __version__ = "0.0.1.dev2025042103"
+ __version__ = "0.0.1.dev2025042503"
 
 
  @keras_rs_export("keras_rs.version")
{keras_rs_nightly-0.0.1.dev2025042103.dist-info → keras_rs_nightly-0.0.1.dev2025042503.dist-info}/METADATA CHANGED
@@ -1,10 +1,10 @@
  Metadata-Version: 2.4
  Name: keras-rs-nightly
- Version: 0.0.1.dev2025042103
+ Version: 0.0.1.dev2025042503
  Summary: Multi-backend recommender systems with Keras 3.
- Author-email: Keras RS team <keras-rs@google.com>
+ Author-email: Keras team <keras-users@googlegroups.com>
  License: Apache License 2.0
- Project-URL: Home, https://keras.io/
+ Project-URL: Home, https://keras.io/keras_rs
  Project-URL: Repository, https://github.com/keras-team/keras-rs
  Classifier: Development Status :: 3 - Alpha
  Classifier: Programming Language :: Python :: 3
@@ -13,6 +13,7 @@ Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3 :: Only
  Classifier: Operating System :: Unix
+ Classifier: Operating System :: Microsoft :: Windows
  Classifier: Operating System :: MacOS
  Classifier: Intended Audience :: Science/Research
  Classifier: Topic :: Scientific/Engineering
keras_rs_nightly-0.0.1.dev2025042503.dist-info/RECORD ADDED
@@ -0,0 +1,42 @@
+ keras_rs/__init__.py,sha256=8sjHiPN2GhUqAq4V7Vh4FLLqYw20-jgdI26ZKX5sg6M,350
+ keras_rs/layers/__init__.py,sha256=cvrFgFWg0RjI0ExUZOKZRdcN-FwTIkqhT33Vx8wGtjQ,905
+ keras_rs/losses/__init__.py,sha256=m04QOgxIUfJ2MvCUKLgEof-UbSNKgUYLPnY-D9NAclI,573
+ keras_rs/metrics/__init__.py,sha256=Qxpf6OFooIL9TIn2l3WgOea3HFRG0hq02glPAxtMZ9c,580
+ keras_rs/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ keras_rs/src/api_export.py,sha256=RsmG-DvO-cdFeAF9W6LRzms0kvtm-Yp9BAA_d-952zI,510
+ keras_rs/src/types.py,sha256=UyOdgjqrqg_b58opnY8n6gTiDHKVR8z_bmEruehERBk,514
+ keras_rs/src/version.py,sha256=xK7d2N2GcNnMxYdTlwcMzFhxwnexq30m_nMs_uRDSWA,222
+ keras_rs/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ keras_rs/src/layers/feature_interaction/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ keras_rs/src/layers/feature_interaction/dot_interaction.py,sha256=jGHcg0EiWxth6LTxG2yWgHcyx_GXrxvA61uQqpPfnDQ,6900
+ keras_rs/src/layers/feature_interaction/feature_cross.py,sha256=5OCSI0vFYzJNmgkKcuHIbVv8U2q3UvS80-qZjPimDjM,8155
+ keras_rs/src/layers/retrieval/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ keras_rs/src/layers/retrieval/brute_force_retrieval.py,sha256=izdppBXxJH0KqYEg7Zsr-SL-SHgAmnFopXMPalEO3uw,5676
+ keras_rs/src/layers/retrieval/hard_negative_mining.py,sha256=IWFrbw1h9z3AUw4oUBKf5_Aud4MTHO_AKdHfoyFa5As,3031
+ keras_rs/src/layers/retrieval/remove_accidental_hits.py,sha256=Z84z2YgKspKeNdc5id8lf9TAyFsbCCz3acJxiKXYipc,3324
+ keras_rs/src/layers/retrieval/retrieval.py,sha256=hVOBF10SF2q_TgJdVUqztbnw5qQF-cxVRGdJbOKoL9M,4191
+ keras_rs/src/layers/retrieval/sampling_probability_correction.py,sha256=80vgOPfBiF-PC0dSyqS57IcIxOxi_Q_R7eSXHn1G0yI,1437
+ keras_rs/src/losses/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ keras_rs/src/losses/pairwise_hinge_loss.py,sha256=nrIU0d1IcCAGo7RVxNitkldJhY2ZrXxjTV7Po27FXds,950
+ keras_rs/src/losses/pairwise_logistic_loss.py,sha256=2dTtRmrNfvF_lOvHK0UQ518L2d4fkvQZDj30HWB5A2s,1305
+ keras_rs/src/losses/pairwise_loss.py,sha256=rmDr_Qc3yA0CR8rUCCGjOgdbjYfC505BLNuITyb1n8k,6132
+ keras_rs/src/losses/pairwise_loss_utils.py,sha256=xvdGvdKNkvGvIaWYEQziWTFNa5EJz7rdkVGgrsnDHUk,1246
+ keras_rs/src/losses/pairwise_mean_squared_error.py,sha256=KhSRvjg4RpwhASP1Sl7PZoq2488P_uGDr9tZWzZhDVU,2764
+ keras_rs/src/losses/pairwise_soft_zero_one_loss.py,sha256=QdWn-lyWQM-U9ID9xGQ7oK10q9XT6qd1gxVAKy8hZH4,1239
+ keras_rs/src/metrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ keras_rs/src/metrics/dcg.py,sha256=DzSBc9ZbgNavuHRt3wtVzdx4ouAaaqeYhd9NxQLPq0g,5120
+ keras_rs/src/metrics/mean_average_precision.py,sha256=SF5NlhlyVL9L_YVkj_s_135f3-8hILVHRziSGafGyZI,3915
+ keras_rs/src/metrics/mean_reciprocal_rank.py,sha256=4stq0MzyWNokMlol6BESDAMuoUFieDrFFc57ue94h4Y,3240
+ keras_rs/src/metrics/ndcg.py,sha256=G7WNFoUaOhnf4vMF1jgcI4yGxieUfJv5E0upv4Qs1AQ,6545
+ keras_rs/src/metrics/precision_at_k.py,sha256=u-mj49qamt448gxkOI9YIZMMrhgO8QmetRFXGGlWOqY,3247
+ keras_rs/src/metrics/ranking_metric.py,sha256=cdFb4Lg2Z8P-02ImMGUAX4XeOUyzEE8TA6nB4fDgq0U,10411
+ keras_rs/src/metrics/ranking_metrics_utils.py,sha256=989J8pr6FRsA1HwBeF7SA8uQqjZT2XeCxKfRuMysWnQ,8828
+ keras_rs/src/metrics/recall_at_k.py,sha256=hlPnR5AtFjdd5AG0zLkLGVyLO5mWtp2bAu_cSOq9Fws,2919
+ keras_rs/src/metrics/utils.py,sha256=6xanTNdwARn4ugzmb7ko2kwAhNhsnR4NhrpS_qW0IKc,2506
+ keras_rs/src/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ keras_rs/src/utils/doc_string_utils.py,sha256=yVyQ8pYdl4gd4tKRhD8dXmQX1EwZeLiV3cCq3A1tUEk,1466
+ keras_rs/src/utils/keras_utils.py,sha256=d28OdQP4GrJk4NIQS4n0KPtCbgOCxVU_vDnnI7ODpOw,1562
+ keras_rs_nightly-0.0.1.dev2025042503.dist-info/METADATA,sha256=mJ89IDGmATYXeG9OiNQWwvnm1Z6TaXJ3Aedj7nMiixc,3614
+ keras_rs_nightly-0.0.1.dev2025042503.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
+ keras_rs_nightly-0.0.1.dev2025042503.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
+ keras_rs_nightly-0.0.1.dev2025042503.dist-info/RECORD,,
{keras_rs_nightly-0.0.1.dev2025042103.dist-info → keras_rs_nightly-0.0.1.dev2025042503.dist-info}/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (79.0.0)
+ Generator: setuptools (79.0.1)
  Root-Is-Purelib: true
  Tag: py3-none-any
 
keras_rs/api/__init__.py DELETED
@@ -1,10 +0,0 @@
- """DO NOT EDIT.
-
- This file was autogenerated. Do not edit it by hand,
- since your modifications would be overwritten.
- """
-
- from keras_rs.api import layers
- from keras_rs.api import losses
- from keras_rs.src.version import __version__
- from keras_rs.src.version import version
keras_rs/api/losses/__init__.py DELETED
@@ -1,14 +0,0 @@
- """DO NOT EDIT.
-
- This file was autogenerated. Do not edit it by hand,
- since your modifications would be overwritten.
- """
-
- from keras_rs.src.losses.pairwise_hinge_loss import PairwiseHingeLoss
- from keras_rs.src.losses.pairwise_logistic_loss import PairwiseLogisticLoss
- from keras_rs.src.losses.pairwise_mean_squared_error import (
-     PairwiseMeanSquaredError,
- )
- from keras_rs.src.losses.pairwise_soft_zero_one_loss import (
-     PairwiseSoftZeroOneLoss,
- )
keras_rs/src/utils/pairwise_loss_utils.py DELETED
@@ -1,102 +0,0 @@
- from typing import Callable, Optional
-
- from keras import ops
-
- from keras_rs.src import types
- from keras_rs.src.utils.keras_utils import check_shapes_compatible
-
-
- def apply_pairwise_op(
-     x: types.Tensor, op: Callable[[types.Tensor, types.Tensor], types.Tensor]
- ) -> types.Tensor:
-     return op(
-         ops.expand_dims(x, axis=-1),
-         ops.expand_dims(x, axis=-2),
-     )
-
-
- def pairwise_comparison(
-     labels: types.Tensor,
-     logits: types.Tensor,
-     mask: types.Tensor,
-     logits_op: Callable[[types.Tensor, types.Tensor], types.Tensor],
- ) -> tuple[types.Tensor, types.Tensor]:
-     # Compute the difference for all pairs in a list. The output is a tensor
-     # with shape `(batch_size, list_size, list_size)`, where `[:, i, j]` stores
-     # information for pair `(i, j)`.
-     pairwise_labels_diff = apply_pairwise_op(labels, ops.subtract)
-     pairwise_logits = apply_pairwise_op(logits, logits_op)
-
-     # Keep only those cases where `l_i < l_j`.
-     pairwise_labels = ops.cast(
-         ops.greater(pairwise_labels_diff, 0), dtype=labels.dtype
-     )
-     if mask is not None:
-         valid_pairs = apply_pairwise_op(mask, ops.logical_and)
-         pairwise_labels = ops.multiply(
-             pairwise_labels, ops.cast(valid_pairs, dtype=pairwise_labels.dtype)
-         )
-
-     return pairwise_labels, pairwise_logits
-
-
- def process_loss_call_inputs(
-     y_true: types.Tensor,
-     y_pred: types.Tensor,
-     mask: Optional[types.Tensor] = None,
- ) -> tuple[types.Tensor, types.Tensor, Optional[types.Tensor]]:
-     """
-     Utility function for processing inputs for pairwise losses.
-
-     This utility function does three things:
-
-     - Checks that `y_true`, `y_pred` are of rank 1 or 2;
-     - Checks that `y_true`, `y_pred`, `mask` have the same shape;
-     - Adds batch dimension if rank = 1.
-     """
-
-     y_true_shape = ops.shape(y_true)
-     y_true_rank = len(y_true_shape)
-     y_pred_shape = ops.shape(y_pred)
-     y_pred_rank = len(y_pred_shape)
-     if mask is not None:
-         mask_shape = ops.shape(mask)
-         mask_rank = len(mask_shape)
-
-     # Check ranks and shapes.
-     def check_rank(
-         x_rank: int,
-         allowed_ranks: tuple[int, ...] = (1, 2),
-         tensor_name: Optional[str] = None,
-     ) -> None:
-         if x_rank not in allowed_ranks:
-             raise ValueError(
-                 f"`{tensor_name}` should have a rank from `{allowed_ranks}`."
-                 f"Received: `{x_rank}`."
-             )
-
-     check_rank(y_true_rank, tensor_name="y_true")
-     check_rank(y_pred_rank, tensor_name="y_pred")
-     if mask is not None:
-         check_rank(mask_rank, tensor_name="mask")
-     if not check_shapes_compatible(y_true_shape, y_pred_shape):
-         raise ValueError(
-             "`y_true` and `y_pred` should have the same shape. Received: "
-             f"`y_true.shape` = {y_true_shape}, `y_pred.shape` = {y_pred_shape}."
-         )
-     if mask is not None and not check_shapes_compatible(
-         y_true_shape, mask_shape
-     ):
-         raise ValueError(
-             "`y_true['labels']` and `y_true['mask']` should have the same "
-             f"shape. Received: `y_true['labels'].shape` = {y_true_shape}, "
-             f"`y_true['mask'].shape` = {mask_shape}."
-         )
-
-     if y_true_rank == 1:
-         y_true = ops.expand_dims(y_true, axis=0)
-         y_pred = ops.expand_dims(y_pred, axis=0)
-         if mask is not None:
-             mask = ops.expand_dims(mask, axis=0)
-
-     return y_true, y_pred, mask
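
A pairwise utilities module of the same name was added under keras_rs/src/losses/ in this release (see the file list above), so these helpers appear to have moved rather than disappeared. As a minimal sketch, not part of the diff, of the broadcasting trick behind `apply_pairwise_op`, with invented inputs:

```python
from keras import ops

# Expanding on the last two axes turns a (batch_size, list_size) tensor into a
# (batch_size, list_size, list_size) tensor where entry [:, i, j] compares
# items i and j of the same list.
labels = ops.convert_to_tensor([[1.0, 0.0, 2.0]])

pairwise_diff = ops.subtract(
    ops.expand_dims(labels, axis=-1),  # shape (1, 3, 1)
    ops.expand_dims(labels, axis=-2),  # shape (1, 1, 3)
)
print(ops.shape(pairwise_diff))  # (1, 3, 3); [:, i, j] == labels[:, i] - labels[:, j]
```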