keras-rs-nightly 0.0.1.dev2025042003__py3-none-any.whl → 0.0.1.dev2025042203__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of keras-rs-nightly might be problematic. Click here for more details.

@@ -0,0 +1,238 @@
1
+ from typing import Callable, Optional, Union
2
+
3
+ import keras
4
+ from keras import ops
5
+
6
+ from keras_rs.src import types
7
+
8
+
9
def get_shuffled_indices(
    shape: types.TensorShape,
    mask: Optional[types.Tensor] = None,
    shuffle_ties: bool = True,
    seed: Optional[Union[int, keras.random.SeedGenerator]] = None,
) -> types.Tensor:
    """Returns argsort indices that shuffle entries, with masked-out ones last.

    Args:
        shape: tuple. Shape of the tensor for which indices are generated.
        mask: Optional boolean tensor of shape `shape`. Entries where `mask`
            is `False` are ordered after every entry where it is `True`.
            Defaults to `None` (no masking).
        shuffle_ties: bool. If `True`, random sort keys are drawn so entries
            end up in random order within the valid / masked-out groups. If
            `False`, all keys are equal and the indices come out as
            `[0, 1, 2, ...]` within each group (this relies on
            `ops.argsort` keeping equal keys in order — TODO confirm for
            every backend).
        seed: Optional `int` or `keras.random.SeedGenerator`, forwarded to
            `keras.random.uniform` when `shuffle_ties` is `True`.

    Returns:
        A tensor of shape `shape` with the indices returned by `ops.argsort`.
    """
    if not shuffle_ties:
        # Identical keys: argsort yields the unshuffled order.
        sort_keys = ops.zeros(shape, dtype="float32")
    else:
        sort_keys = keras.random.uniform(shape, seed=seed, dtype="float32")

    if mask is not None:
        # Keys lie in [0, 1), so shifting masked-out keys by exactly one puts
        # all of them after all valid keys.
        shifted_keys = ops.add(sort_keys, ops.cast(1, dtype="float32"))
        sort_keys = ops.where(mask, sort_keys, shifted_keys)

    return ops.argsort(sort_keys)
52
+
53
+
54
def sort_by_scores(
    tensors_to_sort: list[types.Tensor],
    scores: types.Tensor,
    mask: Optional[types.Tensor] = None,
    k: Optional[int] = None,
    shuffle_ties: bool = True,
    seed: Optional[Union[int, keras.random.SeedGenerator]] = None,
) -> list[types.Tensor]:
    """Utility function for sorting tensors by scores.

    Each tensor in `tensors_to_sort` is reordered along axis 1 by descending
    `scores`, and only the top `k` entries are kept.

    Args:
        tensors_to_sort: list of tensors. All tensors are of shape
            `(batch_size, list_size)`. These tensors are sorted based on
            `scores`.
        scores: tensor. Of shape `(batch_size, list_size)`. The scores to sort
            by.
        mask: Optional boolean tensor of shape `(batch_size, list_size)`.
            Entries where `mask` is `False` are moved behind valid entries
            before `ops.top_k` runs. They therefore lose ties against valid
            entries with equal scores. Their scores are NOT changed here, so
            callers that need masked-out entries ranked last must lower those
            scores beforehand.
        k: int. The number of top-ranked items to consider (the 'k' in
            'top-k'). If `None`, `list_size` is used. Values larger than
            `list_size` are clipped to `list_size`.
        shuffle_ties: bool. Whether to randomly shuffle scores before sorting,
            so ties are broken randomly. If `False`, no randomness is used and
            tie order follows `ops.top_k` on the (mask-reordered) scores.
        seed: int or `keras.random.SeedGenerator`. Seed for shuffling.

    Returns:
        List of sorted tensors (`tensors_to_sort`), sorted using `scores`,
        each of shape `(batch_size, k)`.
    """
    max_possible_k = ops.shape(scores)[1]
    if k is None:
        k = max_possible_k
    elif isinstance(max_possible_k, int):
        k = min(k, max_possible_k)
    else:
        # `list_size` is dynamic (symbolic), so clip with a tensor op.
        k = ops.minimum(k, max_possible_k)

    # Shuffle ties randomly (if requested), and push masked-out values to the
    # end so that they lose ties in `top_k`.
    shuffled_indices = None
    if shuffle_ties or mask is not None:
        shuffled_indices = get_shuffled_indices(
            ops.shape(scores),
            mask=mask,
            # Honour the caller's choice: with `shuffle_ties=False` only the
            # mask reordering is applied and no random values are drawn.
            shuffle_ties=shuffle_ties,
            seed=seed,
        )
        scores = ops.take_along_axis(scores, shuffled_indices, axis=1)

    # Get top-k indices.
    _, indices = ops.top_k(scores, k=k, sorted=True)

    # If we permuted `scores`, map the top-k positions back to the original
    # positions through `shuffled_indices`.
    if shuffled_indices is not None:
        indices = ops.take_along_axis(shuffled_indices, indices, axis=1)

    return [
        ops.take_along_axis(tensor_to_sort, indices, axis=1)
        for tensor_to_sort in tensors_to_sort
    ]
111
+
112
+
113
def get_list_weights(
    weights: types.Tensor, relevance: types.Tensor
) -> types.Tensor:
    """Computes one weight per list from per-item sample weights.

    For every list (row), with `w = weights` and `r = relevance`:

    ```
    per_list_weight = sum(w * r) / sum(r)    (0 when sum(r) == 0)
    ```

    The returned weight of a list is:

    - `0` if `sum(w) <= 0`;
    - `per_list_weight` if `sum(w) > 0` and `sum(r) > 0`;
    - a default weight if `sum(w) > 0` and `sum(r) <= 0`.

    The default weight is the sum of `per_list_weight` over all lists,
    divided by the number of lists with both `sum(w) > 0` and `sum(r) > 0`.
    If no list satisfies both conditions, the default weight is `1.0`.

    With all weights equal to 1.0 and non-negative relevance, every non-empty
    list therefore receives weight 1.0, including lists without relevant
    items.

    Args:
        weights: tensor. Weights tensor of shape `(batch_size, list_size)`.
        relevance: tensor. The relevance `Tensor` of shape
            `(batch_size, list_size)`.

    Returns:
        A tensor of shape `(batch_size, 1)`, containing the per-list weights.
    """
    relevance_total = ops.sum(relevance, axis=1, keepdims=True)
    has_weight = ops.greater(ops.sum(weights, axis=1, keepdims=True), 0.0)
    has_relevance = ops.greater(relevance_total, 0.0)

    # How many lists have both positive weight and positive relevance.
    qualifying = ops.cast(
        ops.logical_and(has_weight, has_relevance), dtype="float32"
    )
    qualifying_count = ops.sum(qualifying, axis=0, keepdims=True)

    # Relevance-weighted average of the item weights for each list.
    weighted_relevance = ops.sum(
        ops.multiply(weights, relevance), axis=1, keepdims=True
    )
    per_list_weights = ops.divide_no_nan(weighted_relevance, relevance_total)
    weight_total = ops.sum(per_list_weights, axis=0, keepdims=True)

    # Fallback used by lists that have weight but no relevance.
    fallback_weight = ops.where(
        ops.greater(qualifying_count, 0.0),
        ops.divide(weight_total, qualifying_count),
        ops.cast(1, dtype=weight_total.dtype),
    )

    weight_if_weighted = ops.where(
        has_relevance, per_list_weights, fallback_weight
    )
    return ops.where(
        has_weight,
        weight_if_weighted,
        ops.cast(0, dtype=per_list_weights.dtype),
    )
206
+
207
+
208
@keras.saving.register_keras_serializable()  # type: ignore[misc]
def default_gain_fn(label: types.Tensor) -> types.Tensor:
    """Returns the exponential gain `2 ** label - 1` for each label."""
    exponential = ops.power(2.0, label)
    return ops.subtract(exponential, 1.0)
211
+
212
+
213
@keras.saving.register_keras_serializable()  # type: ignore[misc]
def default_rank_discount_fn(rank: types.Tensor) -> types.Tensor:
    """Returns the logarithmic discount `1 / log2(1 + rank)` for each rank."""
    one = ops.cast(1, dtype=rank.dtype)
    return ops.divide(one, ops.log2(ops.add(one, rank)))
219
+
220
+
221
def compute_dcg(
    y_true: types.Tensor,
    sample_weight: types.Tensor,
    gain_fn: Callable[[types.Tensor], types.Tensor] = default_gain_fn,
    rank_discount_fn: Callable[
        [types.Tensor], types.Tensor
    ] = default_rank_discount_fn,
) -> types.Tensor:
    """Computes the discounted cumulative gain of each list.

    Entry `j` (0-based) along axis 1 of `y_true` is given rank `j + 1`, so
    `y_true` is expected to already be ordered by rank (presumably the output
    of `sort_by_scores` — TODO confirm at call sites).

    Args:
        y_true: tensor. Labels; axis 1 is the position within a list.
        sample_weight: tensor. Multiplied element-wise with the discounted
            gains before summing.
        gain_fn: Maps labels to gains. Defaults to `2 ** label - 1`.
        rank_discount_fn: Maps 1-based float32 ranks to discounts. Defaults
            to `1 / log2(1 + rank)`.

    Returns:
        `sum(sample_weight * gain_fn(y_true) * rank_discount_fn(rank))` over
        axis 1, with that axis kept (size 1).
    """
    num_positions = ops.shape(y_true)[1]
    ranks = ops.arange(1, num_positions + 1, dtype="float32")
    discounted_gain = ops.multiply(gain_fn(y_true), rank_discount_fn(ranks))
    weighted_gain = ops.multiply(sample_weight, discounted_gain)
    return ops.sum(weighted_gain, axis=1, keepdims=True)
@@ -0,0 +1,85 @@
1
+ from keras import ops
2
+
3
+ from keras_rs.src import types
4
+ from keras_rs.src.api_export import keras_rs_export
5
+ from keras_rs.src.metrics.ranking_metric import RankingMetric
6
+ from keras_rs.src.metrics.ranking_metric import (
7
+ ranking_metric_subclass_doc_string,
8
+ )
9
+ from keras_rs.src.metrics.ranking_metric import (
10
+ ranking_metric_subclass_doc_string_args,
11
+ )
12
+ from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
13
+ from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
14
+ from keras_rs.src.utils.doc_string_utils import format_docstring
15
+
16
+
17
@keras_rs_export("keras_rs.metrics.RecallAtK")
class RecallAtK(RankingMetric):
    def compute_metric(
        self,
        y_true: types.Tensor,
        y_pred: types.Tensor,
        mask: types.Tensor,
        sample_weight: types.Tensor,
    ) -> types.Tensor:
        """Computes per-list Recall@k and the matching per-list weights.

        Labels `>= 1` count as relevant. For each list, recall is the number
        of relevant items among the top `self.k` items (ranked by `y_pred`),
        divided by the number of relevant items in the whole list. Lists
        without relevant items get 0 (`ops.divide_no_nan`).

        Example: `y_true = [0, 0, 1]`, `y_pred = [0.1, 0.9, 0.2]` gives
        `y_true` sorted by descending score `= [0, 1, 0]`.

        Args:
            y_true: tensor. Ground-truth labels, `(batch_size, list_size)`.
            y_pred: tensor. Predicted scores, same shape as `y_true`.
            mask: tensor. Forwarded to `sort_by_scores`.
            sample_weight: tensor. Forwarded to `get_list_weights`.

        Returns:
            Tuple `(per_list_recall, per_list_weights)`.
            NOTE(review): the annotation says `types.Tensor` but a tuple is
            returned; presumably this matches the `RankingMetric` contract —
            confirm.
        """

        def binarize(labels: types.Tensor) -> types.Tensor:
            # Graded labels: anything >= 1 is relevant (1.0), else 0.0.
            return ops.cast(
                ops.greater_equal(labels, ops.cast(1, dtype=labels.dtype)),
                dtype="float32",
            )

        top_k_labels = sort_by_scores(
            tensors_to_sort=[y_true],
            scores=y_pred,
            mask=mask,
            k=self.k,
            shuffle_ties=self.shuffle_ties,
            seed=self.seed_generator,
        )[0]

        hits_in_top_k = ops.sum(binarize(top_k_labels), axis=1, keepdims=True)
        all_relevant = binarize(y_true)
        relevant_total = ops.sum(all_relevant, axis=1, keepdims=True)
        per_list_recall = ops.divide_no_nan(hits_in_top_k, relevant_total)

        per_list_weights = get_list_weights(
            weights=sample_weight, relevance=all_relevant
        )
        return per_list_recall, per_list_weights
58
+
59
+
60
# Pieces substituted into the shared ranking-metric docstring template.
concept_sentence = (
    "It measures the proportion of relevant items found in the top-k "
    "recommendations out of the total number of relevant items for a user"
)
relevance_type = "binary indicators (0 or 1) of relevance"
score_range_interpretation = (
    "Scores range from 0 to 1, with 1 indicating that all relevant items "
    "for the user were found within the top-k recommendations"
)
# Ranks are 1-based, so the top-k items are those with `rank(s_i) <= k`.
formula = """```
R@k(y, s) = sum_i I[rank(s_i) <= k] y_i / sum_j y_j
```

where `y_i` is the relevance label (0/1) of item `i`, `rank(s_i)` is the
1-based rank of item `i` when the list is sorted by descending score `s`,
and `I[condition]` is 1 if the condition is met, otherwise 0. Labels greater
than or equal to 1 are treated as relevant (`y_i = 1`)."""
extra_args = ""
RecallAtK.__doc__ = format_docstring(
    ranking_metric_subclass_doc_string,
    width=80,
    metric_name="Recall@k",
    metric_abbreviation="R@k",
    concept_sentence=concept_sentence,
    relevance_type=relevance_type,
    score_range_interpretation=score_range_interpretation,
    formula=formula,
) + ranking_metric_subclass_doc_string_args.format(extra_args=extra_args)
@@ -0,0 +1,72 @@
1
+ from typing import Optional
2
+
3
+ from keras import ops
4
+
5
+ from keras_rs.src import types
6
+ from keras_rs.src.utils.keras_utils import check_rank
7
+ from keras_rs.src.utils.keras_utils import check_shapes_compatible
8
+
9
+
10
def standardize_call_inputs_ranks(
    y_true: types.Tensor,
    y_pred: types.Tensor,
    mask: Optional[types.Tensor] = None,
    check_y_true_rank: bool = True,
) -> tuple[types.Tensor, types.Tensor, Optional[types.Tensor], bool]:
    """Validates loss/metric inputs and adds a batch axis to rank-1 inputs.

    Steps:

    - Checks that `y_pred` has rank 1 or 2. The same check applies to
      `y_true` (unless `check_y_true_rank` is `False`) and to `mask` (if
      given).
    - Checks that `y_pred` and, if given, `mask` have shapes compatible with
      `y_true`, according to `check_shapes_compatible`.
    - If `y_true` has rank 1, prepends an axis of size 1 to `y_true`,
      `y_pred` and `mask`.

    Args:
        y_true: tensor. Ground truth values.
        y_pred: tensor. The predicted values.
        mask: tensor. Boolean mask for `y_true`.
        check_y_true_rank: bool. Whether to check the rank of `y_true`.

    Returns:
        Tuple `(y_true, y_pred, mask, batched)`. `batched` is `False` exactly
        when `y_true` had rank 1 and a batch axis was added.

    Raises:
        ValueError: If a rank is not 1 or 2, or shapes are incompatible.
    """
    y_true_shape = ops.shape(y_true)
    y_pred_shape = ops.shape(y_pred)
    mask_shape = ops.shape(mask) if mask is not None else None

    # Rank validation.
    if check_y_true_rank:
        check_rank(
            len(y_true_shape), allowed_ranks=(1, 2), tensor_name="y_true"
        )
    check_rank(len(y_pred_shape), allowed_ranks=(1, 2), tensor_name="y_pred")
    if mask is not None:
        check_rank(len(mask_shape), allowed_ranks=(1, 2), tensor_name="mask")

    # Shape validation.
    if not check_shapes_compatible(y_true_shape, y_pred_shape):
        raise ValueError(
            "`y_true` and `y_pred` should have the same shape. Received: "
            f"`y_true.shape` = {y_true_shape}, `y_pred.shape` = {y_pred_shape}."
        )
    if mask is not None:
        if not check_shapes_compatible(y_true_shape, mask_shape):
            raise ValueError(
                "`y_true['labels']` and `y_true['mask']` should have the same "
                f"shape. Received: `y_true['labels'].shape` = {y_true_shape}, "
                f"`y_true['mask'].shape` = {mask_shape}."
            )

    batched = len(y_true_shape) != 1
    if not batched:
        # Promote a single (unbatched) list to a batch containing one list.
        y_true = ops.expand_dims(y_true, axis=0)
        y_pred = ops.expand_dims(y_pred, axis=0)
        if mask is not None:
            mask = ops.expand_dims(mask, axis=0)

    return y_true, y_pred, mask, batched
@@ -0,0 +1,48 @@
1
+ import re
2
+ import textwrap
3
+ from typing import Any
4
+
5
+
6
def format_docstring(template: str, width: int = 80, **kwargs: Any) -> str:
    """Formats a docstring template and re-wraps its prose paragraphs.

    `template` is filled in with `str.format(**kwargs)`, dedented, stripped,
    and split into blocks on blank lines. Each block is then handled as
    follows:

    - Blocks containing a ``` fence keep their line breaks; they are dedented
      and indented.
    - Blocks containing "where:" keep their line breaks and are indented.
    - All other blocks are prose. Runs of whitespace (newlines and leftover
      template indentation) are collapsed to single spaces, and the text is
      re-filled.

    Args:
        template: str. A template understood by `str.format`.
        width: int. Wrapping budget. Prose lines, including their four-space
            indent, are at most `width - 4` characters long.
        **kwargs: Values substituted into `template`.

    Returns:
        The formatted docstring; non-blank lines are indented by at least
        four spaces.
    """
    base_indent_str = " " * 4

    # Initial format
    formatted = template.format(**kwargs)

    # Dedent the whole block
    dedented_all = textwrap.dedent(formatted).strip()

    # Split into logical paragraphs/blocks. The capturing group keeps the
    # blank-line separators so paragraph spacing is reproduced on output.
    blocks = re.split(r"(\n\s*\n)", dedented_all)

    processed_output = []

    for block in blocks:
        stripped_block = block.strip()
        if not stripped_block:
            # Separator (or empty) chunk: keep as-is.
            processed_output.append(block)
            continue

        if "```" in stripped_block:
            formula_dedented = textwrap.dedent(stripped_block)
            processed_output.append(
                textwrap.indent(formula_dedented, base_indent_str)
            )
        elif "where:" in stripped_block:
            processed_output.append(
                textwrap.indent(stripped_block, base_indent_str)
            )
        else:
            # A substituted value whose lines start at column 0 (e.g. a
            # formula) stops the global `dedent` above from removing the
            # template's indentation. `textwrap.fill` would keep those runs
            # of spaces mid-line, so collapse all whitespace first.
            paragraph = " ".join(stripped_block.split())
            processed_output.append(
                textwrap.fill(
                    paragraph,
                    # NOTE(review): `fill` already counts the indent towards
                    # `width`, so lines end up <= width - 4 characters;
                    # confirm the extra subtraction is intended.
                    width=width - len(base_indent_str),
                    initial_indent=base_indent_str,
                    subsequent_indent=base_indent_str,
                )
            )

    # `strip()` removes the first block's indent (and trailing whitespace);
    # re-add exactly one indent at the start.
    final_string = "".join(processed_output).strip()
    final_string = base_indent_str + final_string
    return final_string
@@ -42,3 +42,15 @@ def check_shapes_compatible(
42
42
  return False
43
43
 
44
44
  return True
45
+
46
+
47
def check_rank(
    x_rank: int,
    allowed_ranks: tuple[int, ...],
    tensor_name: str,
) -> None:
    """Raises if `x_rank` is not one of `allowed_ranks`.

    Args:
        x_rank: int. The rank to validate.
        allowed_ranks: tuple of ints. Accepted ranks.
        tensor_name: str. Name of the tensor, used in the error message.

    Raises:
        ValueError: If `x_rank` is not in `allowed_ranks`.
    """
    if x_rank not in allowed_ranks:
        raise ValueError(
            f"`{tensor_name}` should have a rank from `{allowed_ranks}`. "
            f"Received: `{x_rank}`."
        )
keras_rs/src/version.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from keras_rs.src.api_export import keras_rs_export
2
2
 
3
3
  # Unique source of truth for the version number.
4
- __version__ = "0.0.1.dev2025042003"
4
+ __version__ = "0.0.1.dev2025042203"
5
5
 
6
6
 
7
7
  @keras_rs_export("keras_rs.version")
@@ -1,10 +1,10 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: keras-rs-nightly
3
- Version: 0.0.1.dev2025042003
3
+ Version: 0.0.1.dev2025042203
4
4
  Summary: Multi-backend recommender systems with Keras 3.
5
5
  Author-email: Keras RS team <keras-rs@google.com>
6
6
  License: Apache License 2.0
7
- Project-URL: Home, https://keras.io/
7
+ Project-URL: Home, https://keras.io/keras_rs
8
8
  Project-URL: Repository, https://github.com/keras-team/keras-rs
9
9
  Classifier: Development Status :: 3 - Alpha
10
10
  Classifier: Programming Language :: Python :: 3
@@ -0,0 +1,43 @@
1
+ keras_rs/__init__.py,sha256=X3VNKb_6VDEs5GqcbEc_l8mAsefWb5UgSu8krnQdFcM,794
2
+ keras_rs/api/__init__.py,sha256=Q3tmPWGmDoqZ_cy_hCFZowdRzAWjWpOvVAuFLHzrmzw,305
3
+ keras_rs/api/layers/__init__.py,sha256=SB7_QOBPizvbbyQAMb8mPl7vAx0gCxJBPm6V7H67SgU,747
4
+ keras_rs/api/losses/__init__.py,sha256=LGW7eHQh8FbQXdMV1s9zJpbloVlz_Zlo51sorWAvFwE,455
5
+ keras_rs/api/metrics/__init__.py,sha256=tKI6Hj8VQIT01xOCsp7hw5_eZ9Tl2HcVfhq7VpsOkOw,472
6
+ keras_rs/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
+ keras_rs/src/api_export.py,sha256=RsmG-DvO-cdFeAF9W6LRzms0kvtm-Yp9BAA_d-952zI,510
8
+ keras_rs/src/types.py,sha256=UyOdgjqrqg_b58opnY8n6gTiDHKVR8z_bmEruehERBk,514
9
+ keras_rs/src/version.py,sha256=FW541YEknvHzS3EpQiUQw7SBE4hNNFoMmsDLUHSBESU,222
10
+ keras_rs/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
+ keras_rs/src/layers/feature_interaction/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
+ keras_rs/src/layers/feature_interaction/dot_interaction.py,sha256=jGHcg0EiWxth6LTxG2yWgHcyx_GXrxvA61uQqpPfnDQ,6900
13
+ keras_rs/src/layers/feature_interaction/feature_cross.py,sha256=5OCSI0vFYzJNmgkKcuHIbVv8U2q3UvS80-qZjPimDjM,8155
14
+ keras_rs/src/layers/retrieval/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
+ keras_rs/src/layers/retrieval/brute_force_retrieval.py,sha256=izdppBXxJH0KqYEg7Zsr-SL-SHgAmnFopXMPalEO3uw,5676
16
+ keras_rs/src/layers/retrieval/hard_negative_mining.py,sha256=IWFrbw1h9z3AUw4oUBKf5_Aud4MTHO_AKdHfoyFa5As,3031
17
+ keras_rs/src/layers/retrieval/remove_accidental_hits.py,sha256=Z84z2YgKspKeNdc5id8lf9TAyFsbCCz3acJxiKXYipc,3324
18
+ keras_rs/src/layers/retrieval/retrieval.py,sha256=hVOBF10SF2q_TgJdVUqztbnw5qQF-cxVRGdJbOKoL9M,4191
19
+ keras_rs/src/layers/retrieval/sampling_probability_correction.py,sha256=80vgOPfBiF-PC0dSyqS57IcIxOxi_Q_R7eSXHn1G0yI,1437
20
+ keras_rs/src/losses/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
21
+ keras_rs/src/losses/pairwise_hinge_loss.py,sha256=nrIU0d1IcCAGo7RVxNitkldJhY2ZrXxjTV7Po27FXds,950
22
+ keras_rs/src/losses/pairwise_logistic_loss.py,sha256=2dTtRmrNfvF_lOvHK0UQ518L2d4fkvQZDj30HWB5A2s,1305
23
+ keras_rs/src/losses/pairwise_loss.py,sha256=rmDr_Qc3yA0CR8rUCCGjOgdbjYfC505BLNuITyb1n8k,6132
24
+ keras_rs/src/losses/pairwise_loss_utils.py,sha256=xvdGvdKNkvGvIaWYEQziWTFNa5EJz7rdkVGgrsnDHUk,1246
25
+ keras_rs/src/losses/pairwise_mean_squared_error.py,sha256=KhSRvjg4RpwhASP1Sl7PZoq2488P_uGDr9tZWzZhDVU,2764
26
+ keras_rs/src/losses/pairwise_soft_zero_one_loss.py,sha256=QdWn-lyWQM-U9ID9xGQ7oK10q9XT6qd1gxVAKy8hZH4,1239
27
+ keras_rs/src/metrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
+ keras_rs/src/metrics/dcg.py,sha256=DzSBc9ZbgNavuHRt3wtVzdx4ouAaaqeYhd9NxQLPq0g,5120
29
+ keras_rs/src/metrics/mean_average_precision.py,sha256=SF5NlhlyVL9L_YVkj_s_135f3-8hILVHRziSGafGyZI,3915
30
+ keras_rs/src/metrics/mean_reciprocal_rank.py,sha256=4stq0MzyWNokMlol6BESDAMuoUFieDrFFc57ue94h4Y,3240
31
+ keras_rs/src/metrics/ndcg.py,sha256=G7WNFoUaOhnf4vMF1jgcI4yGxieUfJv5E0upv4Qs1AQ,6545
32
+ keras_rs/src/metrics/precision_at_k.py,sha256=u-mj49qamt448gxkOI9YIZMMrhgO8QmetRFXGGlWOqY,3247
33
+ keras_rs/src/metrics/ranking_metric.py,sha256=cdFb4Lg2Z8P-02ImMGUAX4XeOUyzEE8TA6nB4fDgq0U,10411
34
+ keras_rs/src/metrics/ranking_metrics_utils.py,sha256=989J8pr6FRsA1HwBeF7SA8uQqjZT2XeCxKfRuMysWnQ,8828
35
+ keras_rs/src/metrics/recall_at_k.py,sha256=hlPnR5AtFjdd5AG0zLkLGVyLO5mWtp2bAu_cSOq9Fws,2919
36
+ keras_rs/src/metrics/utils.py,sha256=6xanTNdwARn4ugzmb7ko2kwAhNhsnR4NhrpS_qW0IKc,2506
37
+ keras_rs/src/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
38
+ keras_rs/src/utils/doc_string_utils.py,sha256=yVyQ8pYdl4gd4tKRhD8dXmQX1EwZeLiV3cCq3A1tUEk,1466
39
+ keras_rs/src/utils/keras_utils.py,sha256=d28OdQP4GrJk4NIQS4n0KPtCbgOCxVU_vDnnI7ODpOw,1562
40
+ keras_rs_nightly-0.0.1.dev2025042203.dist-info/METADATA,sha256=EHmYa_9Q01HrVfvhwfBju0dmxeH25EwW2WCTiKCEKqs,3555
41
+ keras_rs_nightly-0.0.1.dev2025042203.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
42
+ keras_rs_nightly-0.0.1.dev2025042203.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
43
+ keras_rs_nightly-0.0.1.dev2025042203.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.1.1)
2
+ Generator: setuptools (79.0.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,102 +0,0 @@
1
- from typing import Callable, Optional
2
-
3
- from keras import ops
4
-
5
- from keras_rs.src import types
6
- from keras_rs.src.utils.keras_utils import check_shapes_compatible
7
-
8
-
9
def apply_pairwise_op(
    x: types.Tensor, op: Callable[[types.Tensor, types.Tensor], types.Tensor]
) -> types.Tensor:
    """Applies `op` to every pair of entries along the last axis of `x`.

    The inputs are broadcast as `op(x[..., :, None], x[..., None, :])`, so
    entry `[..., i, j]` of the result is `op(x[..., i], x[..., j])`.
    """
    as_column = ops.expand_dims(x, axis=-1)
    as_row = ops.expand_dims(x, axis=-2)
    return op(as_column, as_row)
16
-
17
-
18
def pairwise_comparison(
    labels: types.Tensor,
    logits: types.Tensor,
    mask: Optional[types.Tensor],
    logits_op: Callable[[types.Tensor, types.Tensor], types.Tensor],
) -> tuple[types.Tensor, types.Tensor]:
    """Builds pairwise label indicators and pairwise logits.

    Args:
        labels: tensor. Labels; pairs are formed along the last axis.
        logits: tensor. Predicted scores (assumed same shape as `labels`).
        mask: Optional boolean tensor. When given, pair `(i, j)` is kept only
            if both `mask[..., i]` and `mask[..., j]` are `True`. `None`
            keeps every pair.
        logits_op: Binary op applied to every `(logits_i, logits_j)` pair.

    Returns:
        `(pairwise_labels, pairwise_logits)`. For inputs of shape
        `(batch_size, list_size)`, both have shape
        `(batch_size, list_size, list_size)`. `pairwise_labels[:, i, j]` is 1
        (in `labels.dtype`) when `labels_i > labels_j` and the pair is valid,
        and 0 otherwise.
    """
    # Compute the difference for all pairs in a list. The output is a tensor
    # with shape `(batch_size, list_size, list_size)`, where `[:, i, j]` stores
    # information for pair `(i, j)`, i.e. `labels_i - labels_j`.
    pairwise_labels_diff = apply_pairwise_op(labels, ops.subtract)
    pairwise_logits = apply_pairwise_op(logits, logits_op)

    # Keep only those pairs where `l_i > l_j` (positive label difference).
    pairwise_labels = ops.cast(
        ops.greater(pairwise_labels_diff, 0), dtype=labels.dtype
    )
    if mask is not None:
        # A pair is valid only if both of its items are valid.
        valid_pairs = apply_pairwise_op(mask, ops.logical_and)
        pairwise_labels = ops.multiply(
            pairwise_labels, ops.cast(valid_pairs, dtype=pairwise_labels.dtype)
        )

    return pairwise_labels, pairwise_logits
41
-
42
-
43
def process_loss_call_inputs(
    y_true: types.Tensor,
    y_pred: types.Tensor,
    mask: Optional[types.Tensor] = None,
) -> tuple[types.Tensor, types.Tensor, Optional[types.Tensor]]:
    """
    Utility function for processing inputs for pairwise losses.

    This utility function does three things:

    - Checks that `y_true`, `y_pred` (and `mask`, if given) are of rank 1 or 2;
    - Checks that `y_true`, `y_pred`, `mask` have the same shape;
    - Adds batch dimension if rank = 1.

    Args:
        y_true: tensor. Ground truth values.
        y_pred: tensor. The predicted values.
        mask: tensor. Optional boolean mask for `y_true`.

    Returns:
        Tuple of processed `(y_true, y_pred, mask)`.

    Raises:
        ValueError: If a rank is not 1 or 2, or shapes are incompatible.
    """

    y_true_shape = ops.shape(y_true)
    y_true_rank = len(y_true_shape)
    y_pred_shape = ops.shape(y_pred)
    y_pred_rank = len(y_pred_shape)
    if mask is not None:
        mask_shape = ops.shape(mask)
        mask_rank = len(mask_shape)

    # Check ranks and shapes.
    def check_rank(
        x_rank: int,
        allowed_ranks: tuple[int, ...] = (1, 2),
        tensor_name: Optional[str] = None,
    ) -> None:
        if x_rank not in allowed_ranks:
            raise ValueError(
                f"`{tensor_name}` should have a rank from `{allowed_ranks}`. "
                f"Received: `{x_rank}`."
            )

    check_rank(y_true_rank, tensor_name="y_true")
    check_rank(y_pred_rank, tensor_name="y_pred")
    if mask is not None:
        check_rank(mask_rank, tensor_name="mask")
    if not check_shapes_compatible(y_true_shape, y_pred_shape):
        raise ValueError(
            "`y_true` and `y_pred` should have the same shape. Received: "
            f"`y_true.shape` = {y_true_shape}, `y_pred.shape` = {y_pred_shape}."
        )
    if mask is not None and not check_shapes_compatible(
        y_true_shape, mask_shape
    ):
        raise ValueError(
            "`y_true['labels']` and `y_true['mask']` should have the same "
            f"shape. Received: `y_true['labels'].shape` = {y_true_shape}, "
            f"`y_true['mask'].shape` = {mask_shape}."
        )

    # Promote a single (unbatched) list to a batch containing one list.
    if y_true_rank == 1:
        y_true = ops.expand_dims(y_true, axis=0)
        y_pred = ops.expand_dims(y_pred, axis=0)
        if mask is not None:
            mask = ops.expand_dims(mask, axis=0)

    return y_true, y_pred, mask
@@ -1,31 +0,0 @@
1
- keras_rs/__init__.py,sha256=X3VNKb_6VDEs5GqcbEc_l8mAsefWb5UgSu8krnQdFcM,794
2
- keras_rs/api/__init__.py,sha256=9Xf-uH9j_SBaTc5RU0pkxrOEgHWPwSKjf4_maySH_nU,272
3
- keras_rs/api/layers/__init__.py,sha256=SB7_QOBPizvbbyQAMb8mPl7vAx0gCxJBPm6V7H67SgU,747
4
- keras_rs/api/losses/__init__.py,sha256=LGW7eHQh8FbQXdMV1s9zJpbloVlz_Zlo51sorWAvFwE,455
5
- keras_rs/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
- keras_rs/src/api_export.py,sha256=RsmG-DvO-cdFeAF9W6LRzms0kvtm-Yp9BAA_d-952zI,510
7
- keras_rs/src/types.py,sha256=UyOdgjqrqg_b58opnY8n6gTiDHKVR8z_bmEruehERBk,514
8
- keras_rs/src/version.py,sha256=hch-ZAcVBDQk6_Mes0xphoDOXCeEDb5Yk_hDiswqfpg,222
9
- keras_rs/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
- keras_rs/src/layers/feature_interaction/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- keras_rs/src/layers/feature_interaction/dot_interaction.py,sha256=jGHcg0EiWxth6LTxG2yWgHcyx_GXrxvA61uQqpPfnDQ,6900
12
- keras_rs/src/layers/feature_interaction/feature_cross.py,sha256=5OCSI0vFYzJNmgkKcuHIbVv8U2q3UvS80-qZjPimDjM,8155
13
- keras_rs/src/layers/retrieval/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- keras_rs/src/layers/retrieval/brute_force_retrieval.py,sha256=izdppBXxJH0KqYEg7Zsr-SL-SHgAmnFopXMPalEO3uw,5676
15
- keras_rs/src/layers/retrieval/hard_negative_mining.py,sha256=IWFrbw1h9z3AUw4oUBKf5_Aud4MTHO_AKdHfoyFa5As,3031
16
- keras_rs/src/layers/retrieval/remove_accidental_hits.py,sha256=Z84z2YgKspKeNdc5id8lf9TAyFsbCCz3acJxiKXYipc,3324
17
- keras_rs/src/layers/retrieval/retrieval.py,sha256=hVOBF10SF2q_TgJdVUqztbnw5qQF-cxVRGdJbOKoL9M,4191
18
- keras_rs/src/layers/retrieval/sampling_probability_correction.py,sha256=80vgOPfBiF-PC0dSyqS57IcIxOxi_Q_R7eSXHn1G0yI,1437
19
- keras_rs/src/losses/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
- keras_rs/src/losses/pairwise_hinge_loss.py,sha256=vqDGd-OnZxiqdeE6vuabE8BKDfill3D2GM0lW5JUmsg,922
21
- keras_rs/src/losses/pairwise_logistic_loss.py,sha256=dhq3CVxuLAH17QxkOs3XVLliYoE3zSJME62CU34vX-k,1274
22
- keras_rs/src/losses/pairwise_loss.py,sha256=oQCKSRAbQrajj_fXnno1I8wYxCMezAKXUljag5viqMY,5428
23
- keras_rs/src/losses/pairwise_mean_squared_error.py,sha256=782K4mFji0DB-mTtHc6dvvQW9azzwFZq_BiocCb-gBE,2727
24
- keras_rs/src/losses/pairwise_soft_zero_one_loss.py,sha256=XBej5nybFXEQ-Vp6GLvNmqTqMq-VEDTz83UVQAQsTZ8,1203
25
- keras_rs/src/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
- keras_rs/src/utils/keras_utils.py,sha256=IjWSRieBkv7UX12qgUoI1tcOeISstCLRSTqSHpT06yE,1275
27
- keras_rs/src/utils/pairwise_loss_utils.py,sha256=6eF4CTJubCySO8M5nd3_gdTlJsta_YMnwDCcdqWYGHA,3435
28
- keras_rs_nightly-0.0.1.dev2025042003.dist-info/METADATA,sha256=cKc5u76ElpkkcfUd5doiEvR58f0yx3A9yMC-eKeqf1E,3547
29
- keras_rs_nightly-0.0.1.dev2025042003.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
30
- keras_rs_nightly-0.0.1.dev2025042003.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
31
- keras_rs_nightly-0.0.1.dev2025042003.dist-info/RECORD,,