keras-rs-nightly 0.0.1.dev2025021903__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_rs/__init__.py +9 -28
- keras_rs/layers/__init__.py +37 -0
- keras_rs/losses/__init__.py +19 -0
- keras_rs/metrics/__init__.py +16 -0
- keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
- keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
- keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
- keras_rs/src/layers/embedding/embed_reduce.py +309 -0
- keras_rs/src/layers/embedding/jax/__init__.py +0 -0
- keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
- keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
- keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
- keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
- keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
- keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
- keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
- keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
- keras_rs/src/layers/feature_interaction/__init__.py +0 -0
- keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
- keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
- keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
- keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
- keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
- keras_rs/src/layers/retrieval/retrieval.py +127 -0
- keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
- keras_rs/src/losses/__init__.py +0 -0
- keras_rs/src/losses/list_mle_loss.py +212 -0
- keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
- keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
- keras_rs/src/losses/pairwise_loss.py +165 -0
- keras_rs/src/losses/pairwise_loss_utils.py +39 -0
- keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
- keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
- keras_rs/src/metrics/__init__.py +0 -0
- keras_rs/src/metrics/dcg.py +161 -0
- keras_rs/src/metrics/mean_average_precision.py +130 -0
- keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
- keras_rs/src/metrics/ndcg.py +197 -0
- keras_rs/src/metrics/precision_at_k.py +117 -0
- keras_rs/src/metrics/ranking_metric.py +260 -0
- keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
- keras_rs/src/metrics/recall_at_k.py +108 -0
- keras_rs/src/metrics/utils.py +70 -0
- keras_rs/src/types.py +43 -14
- keras_rs/src/utils/doc_string_utils.py +53 -0
- keras_rs/src/utils/keras_utils.py +52 -3
- keras_rs/src/utils/tpu_test_utils.py +120 -0
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
- keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
- keras_rs/api/__init__.py +0 -9
- keras_rs/api/layers/__init__.py +0 -11
- keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
- /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
keras_rs/src/metrics/ranking_metrics_utils.py ADDED
@@ -0,0 +1,257 @@
+from typing import Callable
+
+import keras
+from keras import ops
+
+from keras_rs.src import types
+
+
+def get_shuffled_indices(
+    shape: types.Shape,
+    mask: types.Tensor | None = None,
+    shuffle_ties: bool = True,
+    seed: int | keras.random.SeedGenerator | None = None,
+) -> types.Tensor:
+    """Utility function for getting shuffled indices, with masked indices
+    pushed to the end.
+
+    Args:
+        shape: tuple. The shape of the tensor for which to generate
+            shuffled indices.
+        mask: An optional boolean tensor with the same shape as `shape`.
+            If provided, elements where `mask` is `False` will be placed
+            at the end of the sorted indices. Defaults to `None` (no masking).
+        shuffle_ties: Boolean indicating how to handle ties if multiple
+            elements have the same sorting value (randomly when
+            `shuffle_ties` is `True`; otherwise, order is preserved).
+        seed: Optional integer seed for the random number generator used when
+            `shuffle_ties` is True. Ensures reproducibility. Defaults to None.
+
+    Returns:
+        A tensor of shape `shape` containing shuffled indices.
+    """
+    # If `shuffle_ties` is True, generate random values. Otherwise, generate
+    # zeros so that we get `[0, 1, 2, ...]` as indices on doing `argsort`.
+    if shuffle_ties:
+        shuffle_values = keras.random.uniform(shape, seed=seed, dtype="float32")
+    else:
+        shuffle_values = ops.zeros(shape, dtype="float32")
+
+    # When `mask = False`, increase the value by 1 so that those indices are
+    # placed at the end. Note that `shuffle_values` lies in the range `[0, 1)`,
+    # so adding 1 works out.
+    if mask is not None:
+        shuffle_values = ops.where(
+            mask,
+            shuffle_values,
+            ops.add(shuffle_values, ops.cast(1, dtype="float32")),
+        )
+
+    shuffled_indices = ops.argsort(shuffle_values)
+    return shuffled_indices
+
+
+def sort_by_scores(
+    tensors_to_sort: list[types.Tensor],
+    scores: types.Tensor,
+    mask: types.Tensor | None = None,
+    k: int | None = None,
+    shuffle_ties: bool = True,
+    seed: int | keras.random.SeedGenerator | None = None,
+) -> types.Tensor:
+    """
+    Utility function for sorting tensors by scores.
+
+    Args:
+        tensors_to_sort: list of tensors. All tensors are of shape
+            `(batch_size, list_size)`. These tensors are sorted based on
+            `scores`.
+        scores: tensor. Of shape `(batch_size, list_size)`. The scores to sort
+            by.
+        k: int. The number of top-ranked items to consider (the 'k' in 'top-k').
+            If `None`, `list_size` is used.
+        shuffle_ties: bool. Whether to randomly shuffle scores before sorting.
+            This is done to break ties.
+        seed: int. Seed for shuffling.
+
+    Returns:
+        List of sorted tensors (`tensors_to_sort`), sorted using `scores`.
+    """
+    max_possible_k = ops.shape(scores)[1]
+    if k is None:
+        k = max_possible_k
+    elif isinstance(max_possible_k, int):
+        k = min(k, max_possible_k)
+    else:
+        k = ops.minimum(k, max_possible_k)
+
+    # --- Workaround for PyTorch instability ---
+    # Torch's `topk` is not stable with `sorted=True`, unlike JAX and TF.
+    # See:
+    # - https://github.com/pytorch/pytorch/issues/27542
+    # - https://github.com/pytorch/pytorch/issues/88227
+    #
+    # This small "stable offset" ensures deterministic tie-breaking for
+    # equal scores. We can remove this workaround once PyTorch adds a
+    # `stable=True` flag for topk.
+
+    if keras.backend.backend() == "torch" and not shuffle_ties:
+        list_size = ops.shape(scores)[1]
+        indices = ops.arange(list_size)
+        indices = ops.expand_dims(indices, axis=0)
+        indices = ops.broadcast_to(indices, ops.shape(scores))
+        stable_offset = ops.cast(indices, scores.dtype) * 1e-6
+        scores = ops.subtract(scores, stable_offset)
+    # --- End of workaround ---
+
+    # Shuffle ties randomly, and push masked values to the end.
+    shuffled_indices = None
+    if shuffle_ties or mask is not None:
+        shuffled_indices = get_shuffled_indices(
+            ops.shape(scores),
+            mask=mask,
+            shuffle_ties=True,
+            seed=seed,
+        )
+        scores = ops.take_along_axis(scores, shuffled_indices, axis=1)
+
+    # Get top-k indices.
+    _, indices = ops.top_k(scores, k=k, sorted=True)
+
+    # If we shuffled our `scores` tensor, we need to get the correct indices
+    # by indexing into `shuffled_indices`.
+    if shuffled_indices is not None:
+        indices = ops.take_along_axis(shuffled_indices, indices, axis=1)
+
+    return [
+        ops.take_along_axis(tensor_to_sort, indices, axis=1)
+        for tensor_to_sort in tensors_to_sort
+    ]
+
+
+def get_list_weights(
+    weights: types.Tensor, relevance: types.Tensor
+) -> types.Tensor:
+    """Computes per-list weights from provided sample weights.
+
+    Per-list weights are calculated as follows:
+    ```
+    per_list_weights = sum(weights * relevance) / sum(relevance)
+    ```
+
+    For lists where the sum of relevance is 0, a default weight is assigned:
+    ```
+    sum(per_list_weights) / num(sum(relevance) != 0 AND sum(weights) != 0)
+    ```
+
+    If all lists have a sum of relevance equal to 0, the default weight is 1.0.
+
+    As a result of the above computation, this function takes care of the
+    following cases:
+    - **Uniform Weights:** When all input weights are 1.0, all per-list weights
+      will be 1.0, even for lists with no relevant examples. This aligns with
+      standard ranking metrics.
+    - **Non-zero Weights per List:** If every list has at least one non-zero
+      weight, the default weight mechanism is not utilized, which is suitable
+      for unbiased metrics.
+    - **Mixed Scenarios:** For cases with a mix of lists having zero and
+      non-zero relevance and weights, the weights for lists with non-zero
+      relevance and weights are proportional to:
+
+    ```
+    per_list_weights / sum(per_list_weights) *
+        num(sum(relevance) != 0) / num(lists)
+    ```
+
+    The rest have weights `1.0 / num(lists)`.
+
+    Args:
+        weights: tensor. Weights tensor of shape `(batch_size, list_size)`.
+        relevance: tensor. The relevance `Tensor` of shape
+            `(batch_size, list_size)`.
+
+    Returns:
+        A tensor of shape `(batch_size, 1)`, containing the per-list weights.
+    """
+    # Calculate if the sum of weights per list is greater than 0.0.
+    nonzero_weights = ops.greater(ops.sum(weights, axis=1, keepdims=True), 0.0)
+    # Calculate the sum of relevance per list.
+    per_list_relevance = ops.sum(relevance, axis=1, keepdims=True)
+    # Calculate if the sum of relevance per list is greater than 0.0.
+    nonzero_relevance_condition = ops.greater(per_list_relevance, 0.0)
+    # Identify lists where both weights and relevance sums are non-zero.
+    nonzero_relevance = ops.cast(
+        ops.logical_and(nonzero_weights, nonzero_relevance_condition),
+        dtype=weights.dtype,
+    )
+    # Count the number of lists with non-zero relevance and non-zero weights.
+    nonzero_relevance_count = ops.sum(nonzero_relevance, axis=0, keepdims=True)
+
+    # Calculate the per-list weights using the core formula.
+    # Numerator: `sum(weights * relevance)` per list.
+    numerator = ops.sum(ops.multiply(weights, relevance), axis=1, keepdims=True)
+    # Denominator: `per_list_relevance = sum(relevance)` per list.
+    per_list_weights = ops.divide_no_nan(numerator, per_list_relevance)
+
+    # Calculate the sum of the computed per-list weights.
+    sum_weights = ops.sum(per_list_weights, axis=0, keepdims=True)
+
+    # Calculate the average weight to use as default for lists with zero
+    # relevance but non-zero weights. If no lists have non-zero relevance,
+    # default to 1.0.
+    avg_weight = ops.where(
+        ops.greater(nonzero_relevance_count, 0.0),
+        ops.divide(sum_weights, nonzero_relevance_count),
+        ops.cast(1, dtype=sum_weights.dtype),
+    )
+
+    # Final assignment of weights based on conditions:
+    # 1. If sum(weights) == 0 for a list, the final weight is 0.
+    # 2. If sum(weights) > 0 AND sum(relevance) > 0, use the calculated
+    #    `per_list_weights`.
+    # 3. If sum(weights) > 0 AND sum(relevance) == 0, use the calculated
+    #    `avg_weight`.
+    final_weights = ops.where(
+        nonzero_weights,
+        ops.where(
+            nonzero_relevance_condition,
+            per_list_weights,
+            avg_weight,
+        ),
+        ops.cast(0, dtype=per_list_weights.dtype),
+    )
+
+    return final_weights
+
+
+@keras.saving.register_keras_serializable()  # type: ignore[untyped-decorator]
+def default_gain_fn(label: types.Tensor) -> types.Tensor:
+    return ops.subtract(ops.power(2.0, label), 1.0)
+
+
+@keras.saving.register_keras_serializable()  # type: ignore[untyped-decorator]
+def default_rank_discount_fn(rank: types.Tensor) -> types.Tensor:
+    return ops.divide(
+        ops.cast(1, dtype=rank.dtype),
+        ops.log2(ops.add(ops.cast(1, dtype=rank.dtype), rank)),
+    )
+
+
+def compute_dcg(
+    y_true: types.Tensor,
+    sample_weight: types.Tensor,
+    gain_fn: Callable[[types.Tensor], types.Tensor] = default_gain_fn,
+    rank_discount_fn: Callable[
+        [types.Tensor], types.Tensor
+    ] = default_rank_discount_fn,
+) -> types.Tensor:
+    list_size = ops.shape(y_true)[1]
+    positions = ops.arange(1, list_size + 1, dtype=y_true.dtype)
+    gain = gain_fn(y_true)
+    discount = rank_discount_fn(positions)
+
+    return ops.sum(
+        ops.multiply(sample_weight, ops.multiply(gain, discount)),
+        axis=1,
+        keepdims=True,
+    )
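
`sort_by_scores` is the workhorse the new ranking metrics build on, and `compute_dcg` applies the default gain and rank discount to an already-sorted list. A minimal usage sketch, assuming this wheel is installed with a working Keras backend (the module path is taken from the diff; the values are illustrative):

```python
import numpy as np
from keras import ops

from keras_rs.src.metrics import ranking_metrics_utils as rmu

scores = np.array([[0.3, 0.9, 0.1, 0.8]], dtype="float32")
labels = np.array([[0.0, 1.0, 0.0, 1.0]], dtype="float32")

# Reorder the labels by descending score (0.9, 0.8, 0.3, 0.1), so the
# two relevant items land at ranks 1 and 2: [[1., 1., 0., 0.]].
(sorted_labels,) = rmu.sort_by_scores(
    tensors_to_sort=[labels], scores=scores, shuffle_ties=False
)

# DCG with the default gain (2^label - 1) and rank discount
# (1 / log2(rank + 1)): 1/log2(2) + 1/log2(3) ≈ 1.63.
dcg = rmu.compute_dcg(
    y_true=sorted_labels, sample_weight=ops.ones_like(sorted_labels)
)
```
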
keras_rs/src/metrics/recall_at_k.py ADDED
@@ -0,0 +1,108 @@
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+from keras_rs.src.metrics.ranking_metric import RankingMetric
+from keras_rs.src.metrics.ranking_metric import (
+    ranking_metric_subclass_doc_string,
+)
+from keras_rs.src.metrics.ranking_metric import (
+    ranking_metric_subclass_doc_string_post_desc,
+)
+from keras_rs.src.metrics.ranking_metrics_utils import get_list_weights
+from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
+from keras_rs.src.utils.doc_string_utils import format_docstring
+
+
+@keras_rs_export("keras_rs.metrics.RecallAtK")
+class RecallAtK(RankingMetric):
+    def compute_metric(
+        self,
+        y_true: types.Tensor,
+        y_pred: types.Tensor,
+        mask: types.Tensor,
+        sample_weight: types.Tensor,
+    ) -> types.Tensor:
+        # Assume: `y_true = [0, 0, 1]`, `y_pred = [0.1, 0.9, 0.2]`.
+        # `sorted_y_true = [0, 1, 0]` (sorted in descending order).
+        (sorted_y_true,) = sort_by_scores(
+            tensors_to_sort=[y_true],
+            scores=y_pred,
+            mask=mask,
+            k=self.k,
+            shuffle_ties=self.shuffle_ties,
+            seed=self.seed_generator,
+        )
+
+        relevance = ops.cast(
+            ops.greater_equal(
+                sorted_y_true, ops.cast(1, dtype=sorted_y_true.dtype)
+            ),
+            dtype=y_pred.dtype,
+        )
+        overall_relevance = ops.cast(
+            ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
+            dtype=y_pred.dtype,
+        )
+        per_list_recall = ops.divide_no_nan(
+            ops.sum(relevance, axis=1, keepdims=True),
+            ops.sum(overall_relevance, axis=1, keepdims=True),
+        )
+
+        # Get weights.
+        per_list_weights = get_list_weights(
+            weights=sample_weight, relevance=overall_relevance
+        )
+
+        return per_list_recall, per_list_weights
+
+
+concept_sentence = (
+    "It measures the proportion of relevant items found in the top-k "
+    "recommendations out of the total number of relevant items for a user"
+)
+relevance_type = "binary indicators (0 or 1) of relevance"
+score_range_interpretation = (
+    "Scores range from 0 to 1, with 1 indicating that all relevant items "
+    "for the user were found within the top-k recommendations"
+)
+formula = """```
+R@k(y, s) = sum_i I[rank(s_i) < k] y_i / sum_j y_j
+```
+
+where `y_i` is the relevance label (0/1) of the item ranked at position
+`i`, `I[condition]` is 1 if the condition is met, otherwise 0."""
+extra_args = ""
+example = """
+    >>> batch_size = 2
+    >>> list_size = 5
+    >>> labels = np.random.randint(0, 2, size=(batch_size, list_size))
+    >>> scores = np.random.random(size=(batch_size, list_size))
+    >>> metric = keras_rs.metrics.RecallAtK()(
+    ...     y_true=labels, y_pred=scores
+    ... )
+
+    Mask certain elements (can be used for uneven inputs):
+
+    >>> batch_size = 2
+    >>> list_size = 5
+    >>> labels = np.random.randint(0, 2, size=(batch_size, list_size))
+    >>> scores = np.random.random(size=(batch_size, list_size))
+    >>> mask = np.random.randint(0, 2, size=(batch_size, list_size), dtype=bool)
+    >>> metric = keras_rs.metrics.RecallAtK()(
+    ...     y_true={"labels": labels, "mask": mask}, y_pred=scores
+    ... )
+    """
+
+RecallAtK.__doc__ = format_docstring(
+    ranking_metric_subclass_doc_string,
+    width=80,
+    metric_name="Recall@k",
+    metric_abbreviation="R@k",
+    concept_sentence=concept_sentence,
+    relevance_type=relevance_type,
+    score_range_interpretation=score_range_interpretation,
+    formula=formula,
+) + ranking_metric_subclass_doc_string_post_desc.format(
+    extra_args=extra_args, example=example
+)
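
For a concrete feel of the metric, here is a small deterministic sketch. It assumes the `RankingMetric` base class accepts `k` and `shuffle_ties` as constructor arguments (the `self.k` and `self.shuffle_ties` attributes above suggest so):

```python
import numpy as np
import keras_rs

labels = np.array([[0, 0, 1, 1], [1, 0, 0, 0]])
scores = np.array(
    [[0.1, 0.4, 0.9, 0.2], [0.8, 0.7, 0.6, 0.5]], dtype="float32"
)

# List 1 has 2 relevant items but only one (score 0.9) in the top 2 -> 0.5.
# List 2 has 1 relevant item, ranked first -> 1.0. Mean recall@2 = 0.75.
metric = keras_rs.metrics.RecallAtK(k=2, shuffle_ties=False)
metric.update_state(y_true=labels, y_pred=scores)
print(float(metric.result()))  # 0.75
```
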
keras_rs/src/metrics/utils.py ADDED
@@ -0,0 +1,70 @@
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.utils.keras_utils import check_rank
+from keras_rs.src.utils.keras_utils import check_shapes_compatible
+
+
+def standardize_call_inputs_ranks(
+    y_true: types.Tensor,
+    y_pred: types.Tensor,
+    mask: types.Tensor | None = None,
+    check_y_true_rank: bool = True,
+) -> tuple[types.Tensor, types.Tensor, types.Tensor | None, bool]:
+    """
+    Utility function for processing inputs for losses and metrics.
+
+    This utility function does three things:
+
+    - Checks that `y_true` and `y_pred` are of rank 1 or 2;
+    - Checks that `y_true`, `y_pred` and `mask` have the same shape;
+    - Adds a batch dimension if the rank is 1.
+
+    Args:
+        y_true: tensor. Ground truth values.
+        y_pred: tensor. The predicted values.
+        mask: tensor. Boolean mask for `y_true`.
+        check_y_true_rank: bool. Whether to check the rank of `y_true`.
+
+    Returns:
+        Tuple of processed `y_true`, `y_pred`, `mask`, and `batched`. `batched`
+        is a bool indicating whether the inputs are batched.
+    """
+
+    y_true_shape = ops.shape(y_true)
+    y_true_rank = len(y_true_shape)
+    y_pred_shape = ops.shape(y_pred)
+    y_pred_rank = len(y_pred_shape)
+    if mask is not None:
+        mask_shape = ops.shape(mask)
+        mask_rank = len(mask_shape)
+
+    if check_y_true_rank:
+        check_rank(y_true_rank, allowed_ranks=(1, 2), tensor_name="y_true")
+    check_rank(y_pred_rank, allowed_ranks=(1, 2), tensor_name="y_pred")
+    if mask is not None:
+        check_rank(mask_rank, allowed_ranks=(1, 2), tensor_name="mask")
+    if not check_shapes_compatible(y_true_shape, y_pred_shape):
+        raise ValueError(
+            "`y_true` and `y_pred` should have the same shape. Received: "
+            f"`y_true.shape` = {y_true_shape}, `y_pred.shape` = {y_pred_shape}."
+        )
+    if mask is not None and not check_shapes_compatible(
+        y_true_shape, mask_shape
+    ):
+        raise ValueError(
+            "`y_true['labels']` and `y_true['mask']` should have the same "
+            f"shape. Received: `y_true['labels'].shape` = {y_true_shape}, "
+            f"`y_true['mask'].shape` = {mask_shape}."
+        )
+
+    batched = True
+    if y_true_rank == 1:
+        batched = False
+
+        y_true = ops.expand_dims(y_true, axis=0)
+        y_pred = ops.expand_dims(y_pred, axis=0)
+        if mask is not None:
+            mask = ops.expand_dims(mask, axis=0)
+
+    return y_true, y_pred, mask, batched
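
The helper's behavior in the two interesting cases, rank-1 inputs and mismatched shapes, can be sketched as follows (assuming the wheel is installed; the module path is as listed in this diff):

```python
import numpy as np
from keras import ops
from keras_rs.src.metrics.utils import standardize_call_inputs_ranks

# Rank-1 inputs gain a batch dimension; `batched=False` tells the caller
# to squeeze the outputs back at the end.
y_true, y_pred, mask, batched = standardize_call_inputs_ranks(
    np.array([0.0, 1.0, 0.0]), np.array([0.2, 0.7, 0.1])
)
print(ops.shape(y_true), batched)  # (1, 3) False

# Incompatible static shapes are rejected up front.
try:
    standardize_call_inputs_ranks(np.zeros((2, 3)), np.zeros((2, 4)))
except ValueError as e:
    print(e)  # `y_true` and `y_pred` should have the same shape. ...
```
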
keras_rs/src/types.py CHANGED
@@ -1,6 +1,8 @@
 """Type definitions."""
 
-from typing import Any
+from typing import Any, Callable, Mapping, Sequence, TypeAlias, TypeVar, Union
+
+import keras
 
 """
 A tensor in any of the backends.
@@ -8,19 +10,46 @@ A tensor in any of the backends.
 We do not define it explicitly to not require all the backends to be installed
 and imported. The explicit definition would be:
 ```
-Union[
-    numpy.ndarray,
-    tensorflow.Tensor,
-    tensorflow.RaggedTensor,
-    tensorflow.SparseTensor,
-    tensorflow.IndexedSlices,
-    jax.Array,
-    jax.experimental.sparse.JAXSparse,
-    torch.Tensor,
-    keras.KerasTensor,
-]
+numpy.ndarray
+| tensorflow.Tensor
+| tensorflow.RaggedTensor
+| tensorflow.SparseTensor
+| tensorflow.IndexedSlices
+| jax.Array
+| jax.experimental.sparse.JAXSparse
+| torch.Tensor
+| keras.KerasTensor
 ```
 """
-Tensor = Any
+Tensor: TypeAlias = Any
+
+Shape: TypeAlias = Sequence[int | None]
+
+DType: TypeAlias = str
+
+ConstraintLike: TypeAlias = (
+    str
+    | keras.constraints.Constraint
+    | type[keras.constraints.Constraint]
+    | Callable[[Tensor], Tensor]
+)
+
+InitializerLike: TypeAlias = (
+    str
+    | keras.initializers.Initializer
+    | type[keras.initializers.Initializer]
+    | Callable[[Shape, DType], Tensor]
+    | Tensor
+)
+
+RegularizerLike: TypeAlias = (
+    str
+    | keras.regularizers.Regularizer
+    | type[keras.regularizers.Regularizer]
+    | Callable[[Tensor], Tensor]
+)
 
-
+T = TypeVar("T")
+Nested: TypeAlias = (
+    T | Sequence[Union[T, "Nested[T]"]] | Mapping[str, Union[T, "Nested[T]"]]
+)
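
The new `Nested` alias is the recursive container type used throughout the embedding layers' signatures. A sketch of what it admits, assuming Python 3.10+ (which the `int | None` syntax above already requires):

```python
from keras_rs.src import types

# All of these satisfy `Nested[int]`: a bare leaf, nested sequences, and
# string-keyed mappings, in any combination.
a: types.Nested[int] = 3
b: types.Nested[int] = [1, 2, (3, 4)]
c: types.Nested[int] = {"user": 1, "movie": {"id": 2}}

# `Shape` permits unknown dimensions, e.g. a dynamic batch size.
batch_shape: types.Shape = (None, 16)
```
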
keras_rs/src/utils/doc_string_utils.py ADDED
@@ -0,0 +1,53 @@
+import re
+import textwrap
+from typing import Any
+
+
+def format_docstring(template: str, width: int = 80, **kwargs: Any) -> str:
+    """Formats and wraps a docstring using dedent and fill."""
+    base_indent_str = " " * 4
+
+    # Initial format.
+    formatted = template.format(**kwargs)
+
+    # Dedent the whole block.
+    dedented_all = textwrap.dedent(formatted).strip()
+
+    # Split into logical paragraphs/blocks.
+    blocks = re.split(r"(\n\s*\n)", dedented_all)
+
+    processed_output = []
+
+    for block in blocks:
+        stripped_block = block.strip()
+        if not stripped_block:
+            processed_output.append(block)
+            continue
+
+        if "```" in stripped_block:
+            formula_dedented = textwrap.dedent(stripped_block)
+            processed_output.append(
+                textwrap.indent(formula_dedented, base_indent_str)
+            )
+        elif "where:" in stripped_block:
+            # Expect this to be already indented.
+            splitted_block = stripped_block.split("\n")
+            processed_output.append(
+                textwrap.indent(
+                    splitted_block[0] + "\n\n" + "\n".join(splitted_block[1:]),
+                    base_indent_str,
+                )
+            )
+        else:
+            processed_output.append(
+                textwrap.fill(
+                    stripped_block,
+                    width=width - len(base_indent_str),
+                    initial_indent=base_indent_str,
+                    subsequent_indent=base_indent_str,
+                )
+            )
+
+    final_string = "".join(processed_output).strip()
+    final_string = base_indent_str + final_string
+    return final_string
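
`format_docstring` is what assembles the per-metric docstrings above (see `RecallAtK.__doc__`): prose paragraphs are re-wrapped to the target width and indented, while blocks containing ``` fences are indented verbatim. A small sketch with a hypothetical template:

```python
from keras_rs.src.utils.doc_string_utils import format_docstring

# A hypothetical template; `{metric_name}` is substituted via `str.format`.
template = (
    "Computes {metric_name} over ranked lists.\n"
    "\n"
    "This sentence is long enough that `textwrap.fill` re-wraps it to the "
    "requested width and indents it by four spaces, so it nests cleanly "
    "under a class docstring."
)

print(format_docstring(template, width=60, metric_name="Recall@k"))
```
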
keras_rs/src/utils/keras_utils.py CHANGED
@@ -1,11 +1,35 @@
-from typing import
+from typing import Any, Callable
 
 import keras
 
+from keras_rs.src import types
+
+
+def no_automatic_dependency_tracking(
+    fn: Callable[..., Any],
+) -> Callable[..., Any]:
+    """Decorator to disable automatic dependency tracking in Keras and TF.
+
+    Args:
+        fn: the function to disable automatic dependency tracking for.
+
+    Returns:
+        A wrapped version of `fn`.
+    """
+    if keras.backend.backend() == "tensorflow":
+        import tensorflow as tf
+
+        fn = tf.__internal__.tracking.no_automatic_dependency_tracking(fn)
+
+    wrapped_fn: Callable[..., Any] = (
+        keras.src.utils.tracking.no_automatic_dependency_tracking(fn)
+    )
+    return wrapped_fn
+
 
 def clone_initializer(
-    initializer:
-) ->
+    initializer: types.InitializerLike,
+) -> types.InitializerLike:
     """Clones an initializer to ensure a new seed.
 
     Args:
@@ -25,3 +49,28 @@ def clone_initializer(
         return initializer_class.from_config(config)
     # If we get a string or dict, just return as we cannot and should not clone.
     return initializer
+
+
+def check_shapes_compatible(shape1: types.Shape, shape2: types.Shape) -> bool:
+    # Check rank first.
+    if len(shape1) != len(shape2):
+        return False
+
+    for d1, d2 in zip(shape1, shape2):
+        if isinstance(d1, int) and isinstance(d2, int):
+            if d1 != d2:
+                return False
+
+    return True
+
+
+def check_rank(
+    x_rank: int,
+    allowed_ranks: tuple[int, ...],
+    tensor_name: str,
+) -> None:
+    if x_rank not in allowed_ranks:
+        raise ValueError(
+            f"`{tensor_name}` should have a rank from `{allowed_ranks}`. "
+            f"Received: `{x_rank}`."
+        )
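
The two new checks are plain-Python helpers over static shapes, so their contract is easy to pin down (a sketch; `None` dimensions act as wildcards because only `int`/`int` pairs are compared):

```python
from keras_rs.src.utils.keras_utils import check_rank, check_shapes_compatible

# Ranks must match exactly; `None` (dynamic) dimensions match anything.
assert check_shapes_compatible((None, 8), (32, 8))
assert not check_shapes_compatible((32, 8), (32, 4))
assert not check_shapes_compatible((32, 8), (32, 8, 1))

check_rank(2, allowed_ranks=(1, 2), tensor_name="y_pred")  # Passes silently.
try:
    check_rank(3, allowed_ranks=(1, 2), tensor_name="y_pred")
except ValueError as e:
    print(e)  # `y_pred` should have a rank from `(1, 2)`. Received: `3`.
```
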