keras-rs-nightly 0.0.1.dev2025021903__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_rs/__init__.py +9 -28
- keras_rs/layers/__init__.py +37 -0
- keras_rs/losses/__init__.py +19 -0
- keras_rs/metrics/__init__.py +16 -0
- keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
- keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
- keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
- keras_rs/src/layers/embedding/embed_reduce.py +309 -0
- keras_rs/src/layers/embedding/jax/__init__.py +0 -0
- keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
- keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
- keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
- keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
- keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
- keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
- keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
- keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
- keras_rs/src/layers/feature_interaction/__init__.py +0 -0
- keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
- keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
- keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
- keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
- keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
- keras_rs/src/layers/retrieval/retrieval.py +127 -0
- keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
- keras_rs/src/losses/__init__.py +0 -0
- keras_rs/src/losses/list_mle_loss.py +212 -0
- keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
- keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
- keras_rs/src/losses/pairwise_loss.py +165 -0
- keras_rs/src/losses/pairwise_loss_utils.py +39 -0
- keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
- keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
- keras_rs/src/metrics/__init__.py +0 -0
- keras_rs/src/metrics/dcg.py +161 -0
- keras_rs/src/metrics/mean_average_precision.py +130 -0
- keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
- keras_rs/src/metrics/ndcg.py +197 -0
- keras_rs/src/metrics/precision_at_k.py +117 -0
- keras_rs/src/metrics/ranking_metric.py +260 -0
- keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
- keras_rs/src/metrics/recall_at_k.py +108 -0
- keras_rs/src/metrics/utils.py +70 -0
- keras_rs/src/types.py +43 -14
- keras_rs/src/utils/doc_string_utils.py +53 -0
- keras_rs/src/utils/keras_utils.py +52 -3
- keras_rs/src/utils/tpu_test_utils.py +120 -0
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
- keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
- keras_rs/api/__init__.py +0 -9
- keras_rs/api/layers/__init__.py +0 -11
- keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
- /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
--- a/keras_rs/src/layers/modeling/feature_cross.py
+++ b/keras_rs/src/layers/feature_interaction/feature_cross.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any
 
 import keras
 from keras import ops
@@ -52,17 +52,37 @@ class FeatureCross(keras.layers.Layer):
             Regularizer to use for the kernel matrix.
         bias_regularizer: string or `keras.regularizer` regularizer.
             Regularizer to use for the bias vector.
+        **kwargs: Args to pass to the base class.
 
     Example:
 
     ```python
-    #
-
-
-
-
+    # 1. Simple forward pass
+    batch_size = 2
+    embedding_dim = 32
+    feature1 = np.random.randn(batch_size, embedding_dim)
+    feature2 = np.random.randn(batch_size, embedding_dim)
+    crossed_features = keras_rs.layers.FeatureCross()(feature1, feature2)
+
+    # 2. After embedding layer in a model
+    vocabulary_size = 32
+    embedding_dim = 6
+
+    # Create a simple model containing the layer.
+    inputs = keras.Input(shape=(), name='indices', dtype="int32")
+    x0 = keras.layers.Embedding(
+        input_dim=vocabulary_size,
+        output_dim=embedding_dim
+    )(inputs)
+    x1 = keras_rs.layers.FeatureCross()(x0, x0)
+    x2 = keras_rs.layers.FeatureCross()(x0, x1)
     logits = keras.layers.Dense(units=10)(x2)
-    model = keras.Model(
+    model = keras.Model(inputs, logits)
+
+    # Call the model on the inputs.
+    batch_size = 2
+    input_data = np.random.randint(0, vocabulary_size, size=(batch_size,))
+    outputs = model(input_data)
     ```
 
     References:
@@ -72,20 +92,18 @@ class FeatureCross(keras.layers.Layer):
 
     def __init__(
         self,
-        projection_dim:
-        diag_scale:
+        projection_dim: int | None = None,
+        diag_scale: float | None = 0.0,
         use_bias: bool = True,
-        pre_activation:
-        kernel_initializer:
-
-
-        bias_initializer:
-        kernel_regularizer:
-
-
-        bias_regularizer:
-            Text, None, keras.regularizers.Regularizer
-        ] = None,
+        pre_activation: str | keras.layers.Activation | None = None,
+        kernel_initializer: (
+            str | keras.initializers.Initializer
+        ) = "glorot_uniform",
+        bias_initializer: str | keras.initializers.Initializer = "zeros",
+        kernel_regularizer: (
+            str | None | keras.regularizers.Regularizer
+        ) = None,
+        bias_regularizer: (str | None | keras.regularizers.Regularizer) = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -109,7 +127,7 @@ class FeatureCross(keras.layers.Layer):
                 f"`diag_scale={self.diag_scale}`"
             )
 
-    def build(self, input_shape: types.
+    def build(self, input_shape: types.Shape) -> None:
         last_dim = input_shape[-1]
 
         if self.projection_dim is not None:
@@ -135,7 +153,7 @@ class FeatureCross(keras.layers.Layer):
         self.built = True
 
     def call(
-        self, x0: types.Tensor, x:
+        self, x0: types.Tensor, x: types.Tensor | None = None
     ) -> types.Tensor:
         """Forward pass of the cross layer.
 
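
Note: taken together, these hunks move `FeatureCross` to PEP 604 union annotations and flesh out the docstring example. A minimal sketch of the updated signature in use — assuming `keras_rs` and NumPy are installed; the values and `projection_dim=16` choice are illustrative, not from the diff:

```python
import numpy as np
import keras_rs

# Two dense features with the same embedding width.
feature1 = np.random.randn(2, 32).astype("float32")
feature2 = np.random.randn(2, 32).astype("float32")

# New-style defaults: `projection_dim: int | None = None` and
# `diag_scale: float | None = 0.0`; a non-None projection_dim selects
# the low-rank variant of the cross kernel.
cross = keras_rs.layers.FeatureCross(projection_dim=16, diag_scale=0.01)

# Per the new `call(x0, x=None)` signature, the second input is optional.
crossed = cross(feature1, feature2)
print(crossed.shape)  # (2, 32) -- output keeps the input feature shape
```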
--- a/keras_rs/src/layers/retrieval/brute_force_retrieval.py
+++ b/keras_rs/src/layers/retrieval/brute_force_retrieval.py
@@ -1,13 +1,14 @@
-from typing import Any
+from typing import Any
 
 import keras
 
 from keras_rs.src import types
 from keras_rs.src.api_export import keras_rs_export
+from keras_rs.src.layers.retrieval.retrieval import Retrieval
 
 
 @keras_rs_export("keras_rs.layers.BruteForceRetrieval")
-class BruteForceRetrieval(
+class BruteForceRetrieval(Retrieval):
     """Brute force top-k retrieval.
 
     This layer maintains a set of candidates and is able to exactly retrieve the
@@ -54,17 +55,19 @@ class BruteForceRetrieval(keras.layers.Layer):
 
     def __init__(
         self,
-        candidate_embeddings:
-        candidate_ids:
+        candidate_embeddings: types.Tensor | None = None,
+        candidate_ids: types.Tensor | None = None,
         k: int = 10,
         return_scores: bool = True,
         **kwargs: Any,
     ) -> None:
-
+        # Keep `k`, `return_scores` as separately passed args instead of keeping
+        # them in `kwargs`. This is to ensure the user does not have to hop
+        # to the base class to check which other args can be passed.
+        super().__init__(k=k, return_scores=return_scores, **kwargs)
+
         self.candidate_embeddings = None
         self.candidate_ids = None
-        self.k = k
-        self.return_scores = return_scores
 
         if candidate_embeddings is None:
             if candidate_ids is not None:
@@ -78,42 +81,18 @@ class BruteForceRetrieval(keras.layers.Layer):
     def update_candidates(
         self,
         candidate_embeddings: types.Tensor,
-        candidate_ids:
+        candidate_ids: types.Tensor | None = None,
     ) -> None:
         """Update the set of candidates and optionally their candidate IDs.
 
         Args:
             candidate_embeddings: The candidate embeddings.
-            candidate_ids: The identifiers for the candidates. If `None
+            candidate_ids: The identifiers for the candidates. If `None`, the
                 indices of the candidates are returned instead.
         """
-
-
-
-        if len(candidate_embeddings.shape) != 2:
-            raise ValueError(
-                "`candidate_embeddings` must be a tensor of rank 2 "
-                "(num_candidates, embedding_size), received "
-                "`candidate_embeddings` with shape "
-                f"{candidate_embeddings.shape}"
-            )
-
-        if candidate_embeddings.shape[0] < self.k:
-            raise ValueError(
-                "The number of candidates provided "
-                f"({candidate_embeddings.shape[0]}) is less than the number of "
-                f"candidates to retrieve (k={self.k})."
-            )
-
-        if (
-            candidate_ids is not None
-            and candidate_ids.shape[0] != candidate_embeddings.shape[0]
-        ):
-            raise ValueError(
-                "The `candidate_embeddings` and `candidate_is` tensors must "
-                "have the same number of rows, got tensors of shape "
-                f"{candidate_embeddings.shape} and {candidate_ids.shape}."
-            )
+        self._validate_candidate_embeddings_and_ids(
+            candidate_embeddings, candidate_ids
+        )
 
         if self.candidate_embeddings is not None:
             # Update of existing variables.
@@ -146,7 +125,7 @@ class BruteForceRetrieval(keras.layers.Layer):
 
     def call(
         self, inputs: types.Tensor
-    ) ->
+    ) -> types.Tensor | tuple[types.Tensor, types.Tensor]:
         """Returns the top candidates for the query passed as input.
 
         Args:
@@ -167,31 +146,3 @@ class BruteForceRetrieval(keras.layers.Layer):
             return top_scores, top_ids
         else:
             return top_ids
-
-    def compute_score(
-        self, query_embedding: types.Tensor, candidate_embedding: types.Tensor
-    ) -> types.Tensor:
-        """Computes the standard dot product score from queries and candidates.
-
-        Args:
-            query_embedding: Tensor of query embedding corresponding to the
-                queries for which to retrieve top candidates.
-            candidate_embedding: Tensor of candidate embeddings.
-
-        Returns:
-            The dot product of queries and candidates.
-        """
-
-        return keras.ops.matmul(
-            query_embedding, keras.ops.transpose(candidate_embedding)
-        )
-
-    def get_config(self) -> dict[str, Any]:
-        config: dict[str, Any] = super().get_config()
-        config.update(
-            {
-                "k": self.k,
-                "return_scores": self.compute_score,
-            }
-        )
-        return config
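
Note: everything deleted here (`compute_score`, `get_config`, the `k`/`return_scores` attributes, input validation) now lives on the new `Retrieval` base class added below. A usage sketch under that refactor — embeddings are made up, and behavior follows the API shown in the diff:

```python
import numpy as np
import keras_rs

# 100 candidates with 16-dim embeddings. `k` and `return_scores` are
# forwarded to the `Retrieval` base class per the new `__init__`.
candidates = np.random.randn(100, 16).astype("float32")
layer = keras_rs.layers.BruteForceRetrieval(
    candidate_embeddings=candidates, k=5, return_scores=True
)

# Query with a batch of 2 query embeddings.
queries = np.random.randn(2, 16).astype("float32")
top_scores, top_ids = layer(queries)  # both shaped (2, 5)

# Candidates can be swapped out later; validation is now delegated to
# the base-class helper `_validate_candidate_embeddings_and_ids`.
layer.update_candidates(np.random.randn(200, 16).astype("float32"))
```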
--- /dev/null
+++ b/keras_rs/src/layers/retrieval/hard_negative_mining.py
@@ -0,0 +1,94 @@
+from typing import Any
+
+import keras
+import ml_dtypes
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+
+MAX_FLOAT = ml_dtypes.finfo("float32").max / 100.0
+
+
+@keras_rs_export("keras_rs.layers.HardNegativeMining")
+class HardNegativeMining(keras.layers.Layer):
+    """Filter logits and labels to return hard negatives.
+
+    The output will include logits and labels for the requested number of hard
+    negatives as well as the positive candidate.
+
+    Args:
+        num_hard_negatives: How many hard negatives to return.
+        **kwargs: Args to pass to the base class.
+
+    Example:
+
+    ```python
+    # Create layer with the configured number of hard negatives to mine.
+    hard_negative_mining = keras_rs.layers.HardNegativeMining(
+        num_hard_negatives=10
+    )
+
+    # This will retrieve the top 10 negative candidates plus the positive
+    # candidate from `labels` for each row.
+    out_logits, out_labels = hard_negative_mining(in_logits, in_labels)
+    ```
+    """
+
+    def __init__(self, num_hard_negatives: int, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self._num_hard_negatives = num_hard_negatives
+        self.built = True
+
+    def call(
+        self, logits: types.Tensor, labels: types.Tensor
+    ) -> tuple[types.Tensor, types.Tensor]:
+        """Filters logits and labels with per-query hard negative mining.
+
+        The result will include logits and labels for `num_hard_negatives`
+        negatives as well as the positive candidate.
+
+        Args:
+            logits: The logits tensor, typically `[batch_size, num_candidates]`
+                but can have more dimensions or be 1D as `[num_candidates]`.
+            labels: The one-hot labels tensor, must be the same shape as
+                `logits`.
+
+        Returns:
+            A tuple containing two tensors with the last dimension of
+            `num_candidates` replaced with `num_hard_negatives + 1`.
+
+            * logits: `[..., num_hard_negatives + 1]` tensor of logits.
+            * labels: `[..., num_hard_negatives + 1]` one-hot tensor of labels.
+        """
+
+        # Number of sampled logits, i.e, the number of hard negatives to be
+        # sampled (k) + number of true logit (1) per query, capped by batch
+        # size.
+        num_logits = ops.shape(logits)[-1]
+        if isinstance(num_logits, int):
+            num_sampled = min(self._num_hard_negatives + 1, num_logits)
+        else:
+            num_sampled = ops.minimum(self._num_hard_negatives + 1, num_logits)
+        # To gather indices of top k negative logits per row (query) in logits,
+        # true logits need to be excluded. First replace the true logits
+        # (corresponding to positive labels) with a large score value and then
+        # select the top k + 1 logits from each row so that selected indices
+        # include the indices of true logit + top k negative logits. This
+        # approach is to avoid using inefficient masking when excluding true
+        # logits.
+
+        # For each query, get the indices of the logits which have the highest
+        # k + 1 logit values, including the highest k negative logits and one
+        # true logit.
+        _, indices = ops.top_k(
+            ops.add(logits, ops.multiply(labels, MAX_FLOAT)),
+            k=num_sampled,
+            sorted=False,
+        )
+
+        # Gather sampled logits and corresponding labels.
+        logits = ops.take_along_axis(logits, indices, axis=-1)
+        labels = ops.take_along_axis(labels, indices, axis=-1)
+
+        return logits, labels
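
Note: the `MAX_FLOAT` trick is the heart of this new layer — boosting the positive logit by a huge constant guarantees `top_k` keeps it, so one pass selects the positive plus the `k` hardest negatives with no masking. A self-contained NumPy re-enactment of that selection (illustrative only, not the layer itself):

```python
import numpy as np

MAX_FLOAT = np.finfo("float32").max / 100.0
num_hard_negatives = 2

logits = np.array([[0.1, 3.0, 0.5, 2.0, -1.0]], dtype="float32")
labels = np.array([[0.0, 0.0, 1.0, 0.0, 0.0]], dtype="float32")  # positive: index 2

# Boost the positive so it always survives top-k, then keep k + 1 entries.
boosted = logits + labels * MAX_FLOAT
k = num_hard_negatives + 1
indices = np.argsort(-boosted, axis=-1)[:, :k]  # stand-in for ops.top_k

out_logits = np.take_along_axis(logits, indices, axis=-1)
out_labels = np.take_along_axis(labels, indices, axis=-1)
print(indices)     # [[2 1 3]] -> the positive plus the two hardest negatives
print(out_logits)  # [[0.5 3.  2. ]]
print(out_labels)  # [[1. 0. 0.]]
```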
--- /dev/null
+++ b/keras_rs/src/layers/retrieval/remove_accidental_hits.py
@@ -0,0 +1,97 @@
+import keras
+import ml_dtypes
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+from keras_rs.src.utils import keras_utils
+
+SMALLEST_FLOAT = ml_dtypes.finfo("float32").smallest_normal / 100.0
+
+
+@keras_rs_export("keras_rs.layers.RemoveAccidentalHits")
+class RemoveAccidentalHits(keras.layers.Layer):
+    """Zeroes the logits of accidental negatives.
+
+    Zeroes the logits of negative candidates that have the same ID as the
+    positive candidate in that row.
+
+    Example:
+
+    ```python
+    # Create layer with the configured number of hard negatives to mine.
+    remove_accidental_hits = keras_rs.layers.RemoveAccidentalHits()
+
+    # This will zero the logits of negative candidates that have the same ID as
+    # the positive candidate from `labels` so as to not negatively impact the
+    # true positive.
+    logits = remove_accidental_hits(logits, labels, candidate_ids)
+    ```
+    """
+
+    def call(
+        self,
+        logits: types.Tensor,
+        labels: types.Tensor,
+        candidate_ids: types.Tensor,
+    ) -> types.Tensor:
+        """Zeroes selected logits.
+
+        For each row in the batch, zeroes the logits of negative candidates that
+        have the same ID as the positive candidate in that row.
+
+        Args:
+            logits: The logits tensor, typically `[batch_size, num_candidates]`
+                but can have more dimensions or be 1D as `[num_candidates]`.
+            labels: The one-hot labels tensor, must be the same shape as
+                `logits`.
+            candidate_ids: The candidate identifiers tensor, can be
+                `[num_candidates]` or `[batch_size, num_candidates]` or have
+                more dimensions as long as they match the last dimensions of
+                `labels`.
+
+        Returns:
+            The modified logits with the same shape as the input logits.
+        """
+        # A more principled way is to implement
+        # `softmax_cross_entropy_with_logits` with a input mask. Here we
+        # approximate so by letting accidental hits have extremely small logits
+        # (SMALLEST_FLOAT) for ease-of-implementation.
+
+        labels_shape = ops.shape(labels)
+        labels_rank = len(labels_shape)
+        logits_shape = ops.shape(logits)
+        candidate_ids_shape = ops.shape(candidate_ids)
+        candidate_ids_rank = len(candidate_ids_shape)
+
+        if not keras_utils.check_shapes_compatible(labels_shape, logits_shape):
+            raise ValueError(
+                "`labels` and `logits` should have the same shape. Received: "
+                f"`labels.shape` = {labels_shape}, "
+                f"`logits.shape` = {logits_shape}."
+            )
+
+        if not keras_utils.check_shapes_compatible(
+            labels_shape[-candidate_ids_rank:], candidate_ids_shape
+        ):
+            raise ValueError(
+                "`candidate_ids` should have the same shape as the last "
+                "dimensions of `labels`. Received: "
+                f"`candidate_ids.shape` = {candidate_ids_shape}, "
+                f"`labels.shape` = {labels_shape}."
+            )
+
+        # Add dimensions to `candidate_ids` to have the same rank as `labels`.
+        if candidate_ids_rank < labels_rank:
+            candidate_ids = ops.expand_dims(
+                candidate_ids, list(range(labels_rank - candidate_ids_rank))
+            )
+        positive_indices = ops.expand_dims(ops.argmax(labels, axis=-1), -1)
+        positive_candidate_ids = ops.take(candidate_ids, positive_indices)
+
+        duplicate = ops.cast(
+            ops.equal(positive_candidate_ids, candidate_ids), labels.dtype
+        )
+        duplicate = ops.subtract(duplicate, labels)
+
+        return ops.add(logits, ops.multiply(duplicate, SMALLEST_FLOAT))
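
Note: the pivotal step in this new layer is the `duplicate` mask — `equal(positive_candidate_ids, candidate_ids)` flags every candidate sharing the positive's ID, and subtracting `labels` drops the positive itself so only the accidental negatives remain. A NumPy illustration of just that mask (not the full layer):

```python
import numpy as np

candidate_ids = np.array([7, 3, 7, 9])                      # candidate 2 repeats ID 7
labels = np.array([[1.0, 0.0, 0.0, 0.0]], dtype="float32")  # positive: candidate 0

positive_index = np.argmax(labels, axis=-1, keepdims=True)  # [[0]]
positive_id = np.take(candidate_ids, positive_index)        # [[7]]

duplicate = (candidate_ids == positive_id).astype("float32")  # [[1. 0. 1. 0.]]
duplicate -= labels                                           # [[0. 0. 1. 0.]]
print(duplicate)  # only candidate 2, the accidental hit, stays flagged
```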
--- /dev/null
+++ b/keras_rs/src/layers/retrieval/retrieval.py
@@ -0,0 +1,127 @@
+import abc
+from typing import Any
+
+import keras
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+
+
+@keras_rs_export("keras_rs.layers.Retrieval")
+class Retrieval(keras.layers.Layer, abc.ABC):
+    """Retrieval base abstract class.
+
+    This layer provides a common interface for all retrieval layers. In order
+    to implement a custom retrieval layer, this abstract class should be
+    subclassed.
+
+    Args:
+        k: int. Number of candidates to retrieve.
+        return_scores: bool. When `True`, this layer returns a tuple with the
+            top scores and the top identifiers. When `False`, this layer returns
+            a single tensor with the top identifiers.
+    """
+
+    def __init__(
+        self,
+        k: int = 10,
+        return_scores: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.k = k
+        self.return_scores = return_scores
+
+    def _validate_candidate_embeddings_and_ids(
+        self,
+        candidate_embeddings: types.Tensor,
+        candidate_ids: types.Tensor | None = None,
+    ) -> None:
+        """Validates inputs to `update_candidates()`."""
+
+        if candidate_embeddings is None:
+            raise ValueError("`candidate_embeddings` is required.")
+
+        if len(candidate_embeddings.shape) != 2:
+            raise ValueError(
+                "`candidate_embeddings` must be a tensor of rank 2 "
+                "(num_candidates, embedding_size), received "
+                "`candidate_embeddings` with shape "
+                f"{candidate_embeddings.shape}"
+            )
+
+        if candidate_embeddings.shape[0] < self.k:
+            raise ValueError(
+                "The number of candidates provided "
+                f"({candidate_embeddings.shape[0]}) is less than the number of "
+                f"candidates to retrieve (k={self.k})."
+            )
+
+        if (
+            candidate_ids is not None
+            and candidate_ids.shape[0] != candidate_embeddings.shape[0]
+        ):
+            raise ValueError(
+                "The `candidate_embeddings` and `candidate_is` tensors must "
+                "have the same number of rows, got tensors of shape "
+                f"{candidate_embeddings.shape} and {candidate_ids.shape}."
+            )
+
+    @abc.abstractmethod
+    def update_candidates(
+        self,
+        candidate_embeddings: types.Tensor,
+        candidate_ids: types.Tensor | None = None,
+    ) -> None:
+        """Update the set of candidates and optionally their candidate IDs.
+
+        Args:
+            candidate_embeddings: The candidate embeddings.
+            candidate_ids: The identifiers for the candidates. If `None`, the
+                indices of the candidates are returned instead.
+        """
+        pass
+
+    @abc.abstractmethod
+    def call(
+        self, inputs: types.Tensor
+    ) -> types.Tensor | tuple[types.Tensor, types.Tensor]:
+        """Returns the top candidates for the query passed as input.
+
+        Args:
+            inputs: the query for which to return top candidates.
+
+        Returns:
+            A tuple with the top scores and the top identifiers if
+            `returns_scores` is True, otherwise a tensor with the top
+            identifiers.
+        """
+        pass
+
+    def compute_score(
+        self, query_embedding: types.Tensor, candidate_embedding: types.Tensor
+    ) -> types.Tensor:
+        """Computes the standard dot product score from queries and candidates.
+
+        Args:
+            query_embedding: Tensor of query embedding corresponding to the
+                queries for which to retrieve top candidates.
+            candidate_embedding: Tensor of candidate embeddings.
+
+        Returns:
+            The dot product of queries and candidates.
+        """
+
+        return keras.ops.matmul(
+            query_embedding, keras.ops.transpose(candidate_embedding)
+        )
+
+    def get_config(self) -> dict[str, Any]:
+        config: dict[str, Any] = super().get_config()
+        config.update(
+            {
+                "k": self.k,
+                "return_scores": self.compute_score,
+            }
+        )
+        return config
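
Note: with this abstract base in place, a custom retrieval backend only needs to supply `update_candidates` and `call`; validation, dot-product scoring, and config handling come along for free. A hypothetical subclass sketch — `ExactRetrieval` and its in-memory storage are illustrative, not part of the package:

```python
import keras
import keras_rs


class ExactRetrieval(keras_rs.layers.Retrieval):
    """Toy subclass that keeps candidates in a plain attribute."""

    def update_candidates(self, candidate_embeddings, candidate_ids=None):
        # Reuse the base-class checks shown in the hunk above.
        self._validate_candidate_embeddings_and_ids(
            candidate_embeddings, candidate_ids
        )
        self._embeddings = candidate_embeddings

    def call(self, inputs):
        # Inherited dot-product scoring, then top-k selection.
        scores = self.compute_score(inputs, self._embeddings)
        top_scores, top_ids = keras.ops.top_k(scores, k=self.k)
        if self.return_scores:
            return top_scores, top_ids
        return top_ids
```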
--- /dev/null
+++ b/keras_rs/src/layers/retrieval/sampling_probability_correction.py
@@ -0,0 +1,63 @@
+from typing import Any
+
+import keras
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+
+
+@keras_rs_export("keras_rs.layers.SamplingProbabilityCorrection")
+class SamplingProbabilityCorrection(keras.layers.Layer):
+    """Sampling probability correction.
+
+    Corrects the logits to reflect the sampling probability of negatives.
+
+    Args:
+        epsilon: float. Small float added to sampling probability to avoid
+            taking the log of zero. Defaults to 1e-6.
+        **kwargs: Args to pass to the base class.
+
+    Example:
+
+    ```python
+    # Create the layer.
+    sampling_probability_correction = (
+        keras_rs.layers.SamplingProbabilityCorrection()
+    )
+
+    # Correct the logits based on the provided candidate sampling probability.
+    logits = sampling_probability_correction(logits, probabilities)
+    ```
+    """
+
+    def __init__(self, epsilon: float = 1e-6, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.epsilon = epsilon
+        self.built = True
+
+    def call(
+        self,
+        logits: types.Tensor,
+        candidate_sampling_probability: types.Tensor,
+    ) -> types.Tensor:
+        """Corrects input logits to account for candidate sampling probability.
+
+        Args:
+            logits: The logits tensor to correct, typically
+                `[batch_size, num_candidates]` but can have more dimensions or
+                be 1D as `[num_candidates]`.
+            candidate_sampling_probability: The sampling probability with the
+                same shape as `logits`.
+
+        Returns:
+            The corrected logits with the same shape as the input logits.
+        """
+        return logits - ops.log(
+            ops.clip(candidate_sampling_probability, self.epsilon, 1.0)
+        )
+
+    def get_config(self) -> dict[str, Any]:
+        config: dict[str, Any] = super().get_config()
+        config.update({"epsilon": self.epsilon})
+        return config
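
Note: this new layer is the standard logQ correction for sampled softmax — subtracting `log(p)` from each logit so candidates that are sampled often are not unfairly penalized as negatives. A quick numeric check of the formula (probabilities invented):

```python
import numpy as np

epsilon = 1e-6
logits = np.array([[2.0, 1.0, 0.5]], dtype="float32")
# One popular candidate sampled 10% of the time, two rare ones at 0.1%.
probs = np.array([[0.1, 0.001, 0.001]], dtype="float32")

corrected = logits - np.log(np.clip(probs, epsilon, 1.0))
print(corrected)  # [[4.3026 7.9078 7.4078]] -- rare negatives get a bigger boost
```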
File without changes