keras-rs-nightly 0.0.1.dev2025041403__py3-none-any.whl → 0.0.1.dev2025041603__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of keras-rs-nightly might be problematic. Click here for more details.
- keras_rs/api/layers/__init__.py +1 -0
- keras_rs/src/layers/retrieval/brute_force_retrieval.py +11 -60
- keras_rs/src/layers/retrieval/retrieval.py +127 -0
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.0.1.dev2025041403.dist-info → keras_rs_nightly-0.0.1.dev2025041603.dist-info}/METADATA +1 -1
- {keras_rs_nightly-0.0.1.dev2025041403.dist-info → keras_rs_nightly-0.0.1.dev2025041603.dist-info}/RECORD +8 -7
- {keras_rs_nightly-0.0.1.dev2025041403.dist-info → keras_rs_nightly-0.0.1.dev2025041603.dist-info}/WHEEL +0 -0
- {keras_rs_nightly-0.0.1.dev2025041403.dist-info → keras_rs_nightly-0.0.1.dev2025041603.dist-info}/top_level.txt +0 -0
keras_rs/api/layers/__init__.py
CHANGED
|
@@ -17,6 +17,7 @@ from keras_rs.src.layers.retrieval.hard_negative_mining import (
|
|
|
17
17
|
from keras_rs.src.layers.retrieval.remove_accidental_hits import (
|
|
18
18
|
RemoveAccidentalHits,
|
|
19
19
|
)
|
|
20
|
+
from keras_rs.src.layers.retrieval.retrieval import Retrieval
|
|
20
21
|
from keras_rs.src.layers.retrieval.sampling_probability_correction import (
|
|
21
22
|
SamplingProbabilityCorrection,
|
|
22
23
|
)
|
keras_rs/src/layers/retrieval/brute_force_retrieval.py
CHANGED
|
@@ -4,10 +4,11 @@ import keras
|
|
|
4
4
|
|
|
5
5
|
from keras_rs.src import types
|
|
6
6
|
from keras_rs.src.api_export import keras_rs_export
|
|
7
|
+
from keras_rs.src.layers.retrieval.retrieval import Retrieval
|
|
7
8
|
|
|
8
9
|
|
|
9
10
|
@keras_rs_export("keras_rs.layers.BruteForceRetrieval")
|
|
10
|
-
class BruteForceRetrieval(
|
|
11
|
+
class BruteForceRetrieval(Retrieval):
|
|
11
12
|
"""Brute force top-k retrieval.
|
|
12
13
|
|
|
13
14
|
This layer maintains a set of candidates and is able to exactly retrieve the
|
|
@@ -60,11 +61,13 @@ class BruteForceRetrieval(keras.layers.Layer):
|
|
|
60
61
|
return_scores: bool = True,
|
|
61
62
|
**kwargs: Any,
|
|
62
63
|
) -> None:
|
|
63
|
-
|
|
64
|
+
# Keep `k`, `return_scores` as separately passed args instead of keeping
|
|
65
|
+
# them in `kwargs`. This is to ensure the user does not have to hop
|
|
66
|
+
# to the base class to check which other args can be passed.
|
|
67
|
+
super().__init__(k=k, return_scores=return_scores, **kwargs)
|
|
68
|
+
|
|
64
69
|
self.candidate_embeddings = None
|
|
65
70
|
self.candidate_ids = None
|
|
66
|
-
self.k = k
|
|
67
|
-
self.return_scores = return_scores
|
|
68
71
|
|
|
69
72
|
if candidate_embeddings is None:
|
|
70
73
|
if candidate_ids is not None:
|
|
@@ -84,36 +87,12 @@ class BruteForceRetrieval(keras.layers.Layer):
|
|
|
84
87
|
|
|
85
88
|
Args:
|
|
86
89
|
candidate_embeddings: The candidate embeddings.
|
|
87
|
-
candidate_ids: The identifiers for the candidates. If `None
|
|
90
|
+
candidate_ids: The identifiers for the candidates. If `None`, the
|
|
88
91
|
indices of the candidates are returned instead.
|
|
89
92
|
"""
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
if len(candidate_embeddings.shape) != 2:
|
|
94
|
-
raise ValueError(
|
|
95
|
-
"`candidate_embeddings` must be a tensor of rank 2 "
|
|
96
|
-
"(num_candidates, embedding_size), received "
|
|
97
|
-
"`candidate_embeddings` with shape "
|
|
98
|
-
f"{candidate_embeddings.shape}"
|
|
99
|
-
)
|
|
100
|
-
|
|
101
|
-
if candidate_embeddings.shape[0] < self.k:
|
|
102
|
-
raise ValueError(
|
|
103
|
-
"The number of candidates provided "
|
|
104
|
-
f"({candidate_embeddings.shape[0]}) is less than the number of "
|
|
105
|
-
f"candidates to retrieve (k={self.k})."
|
|
106
|
-
)
|
|
107
|
-
|
|
108
|
-
if (
|
|
109
|
-
candidate_ids is not None
|
|
110
|
-
and candidate_ids.shape[0] != candidate_embeddings.shape[0]
|
|
111
|
-
):
|
|
112
|
-
raise ValueError(
|
|
113
|
-
"The `candidate_embeddings` and `candidate_is` tensors must "
|
|
114
|
-
"have the same number of rows, got tensors of shape "
|
|
115
|
-
f"{candidate_embeddings.shape} and {candidate_ids.shape}."
|
|
116
|
-
)
|
|
93
|
+
self._validate_candidate_embeddings_and_ids(
|
|
94
|
+
candidate_embeddings, candidate_ids
|
|
95
|
+
)
|
|
117
96
|
|
|
118
97
|
if self.candidate_embeddings is not None:
|
|
119
98
|
# Update of existing variables.
|
|
@@ -167,31 +146,3 @@ class BruteForceRetrieval(keras.layers.Layer):
|
|
|
167
146
|
return top_scores, top_ids
|
|
168
147
|
else:
|
|
169
148
|
return top_ids
|
|
170
|
-
|
|
171
|
-
def compute_score(
|
|
172
|
-
self, query_embedding: types.Tensor, candidate_embedding: types.Tensor
|
|
173
|
-
) -> types.Tensor:
|
|
174
|
-
"""Computes the standard dot product score from queries and candidates.
|
|
175
|
-
|
|
176
|
-
Args:
|
|
177
|
-
query_embedding: Tensor of query embedding corresponding to the
|
|
178
|
-
queries for which to retrieve top candidates.
|
|
179
|
-
candidate_embedding: Tensor of candidate embeddings.
|
|
180
|
-
|
|
181
|
-
Returns:
|
|
182
|
-
The dot product of queries and candidates.
|
|
183
|
-
"""
|
|
184
|
-
|
|
185
|
-
return keras.ops.matmul(
|
|
186
|
-
query_embedding, keras.ops.transpose(candidate_embedding)
|
|
187
|
-
)
|
|
188
|
-
|
|
189
|
-
def get_config(self) -> dict[str, Any]:
|
|
190
|
-
config: dict[str, Any] = super().get_config()
|
|
191
|
-
config.update(
|
|
192
|
-
{
|
|
193
|
-
"k": self.k,
|
|
194
|
-
"return_scores": self.compute_score,
|
|
195
|
-
}
|
|
196
|
-
)
|
|
197
|
-
return config
|
keras_rs/src/layers/retrieval/retrieval.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
from typing import Any, Optional, Union
|
|
3
|
+
|
|
4
|
+
import keras
|
|
5
|
+
|
|
6
|
+
from keras_rs.src import types
|
|
7
|
+
from keras_rs.src.api_export import keras_rs_export
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@keras_rs_export("keras_rs.layers.Retrieval")
|
|
11
|
+
class Retrieval(keras.layers.Layer, abc.ABC):
|
|
12
|
+
"""Retrieval base abstract class.
|
|
13
|
+
|
|
14
|
+
This layer provides a common interface for all retrieval layers. In order
|
|
15
|
+
to implement a custom retrieval layer, this abstract class should be
|
|
16
|
+
subclassed.
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
k: int. Number of candidates to retrieve.
|
|
20
|
+
return_scores: bool. When `True`, this layer returns a tuple with the
|
|
21
|
+
top scores and the top identifiers. When `False`, this layer returns
|
|
22
|
+
a single tensor with the top identifiers.
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
def __init__(
|
|
26
|
+
self,
|
|
27
|
+
k: int = 10,
|
|
28
|
+
return_scores: bool = True,
|
|
29
|
+
**kwargs: Any,
|
|
30
|
+
) -> None:
|
|
31
|
+
super().__init__(**kwargs)
|
|
32
|
+
self.k = k
|
|
33
|
+
self.return_scores = return_scores
|
|
34
|
+
|
|
35
|
+
def _validate_candidate_embeddings_and_ids(
|
|
36
|
+
self,
|
|
37
|
+
candidate_embeddings: types.Tensor,
|
|
38
|
+
candidate_ids: Optional[types.Tensor] = None,
|
|
39
|
+
) -> None:
|
|
40
|
+
"""Validates inputs to `update_candidates()`."""
|
|
41
|
+
|
|
42
|
+
if candidate_embeddings is None:
|
|
43
|
+
raise ValueError("`candidate_embeddings` is required.")
|
|
44
|
+
|
|
45
|
+
if len(candidate_embeddings.shape) != 2:
|
|
46
|
+
raise ValueError(
|
|
47
|
+
"`candidate_embeddings` must be a tensor of rank 2 "
|
|
48
|
+
"(num_candidates, embedding_size), received "
|
|
49
|
+
"`candidate_embeddings` with shape "
|
|
50
|
+
f"{candidate_embeddings.shape}"
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
if candidate_embeddings.shape[0] < self.k:
|
|
54
|
+
raise ValueError(
|
|
55
|
+
"The number of candidates provided "
|
|
56
|
+
f"({candidate_embeddings.shape[0]}) is less than the number of "
|
|
57
|
+
f"candidates to retrieve (k={self.k})."
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
if (
|
|
61
|
+
candidate_ids is not None
|
|
62
|
+
and candidate_ids.shape[0] != candidate_embeddings.shape[0]
|
|
63
|
+
):
|
|
64
|
+
raise ValueError(
|
|
65
|
+
"The `candidate_embeddings` and `candidate_is` tensors must "
|
|
66
|
+
"have the same number of rows, got tensors of shape "
|
|
67
|
+
f"{candidate_embeddings.shape} and {candidate_ids.shape}."
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
@abc.abstractmethod
|
|
71
|
+
def update_candidates(
|
|
72
|
+
self,
|
|
73
|
+
candidate_embeddings: types.Tensor,
|
|
74
|
+
candidate_ids: Optional[types.Tensor] = None,
|
|
75
|
+
) -> None:
|
|
76
|
+
"""Update the set of candidates and optionally their candidate IDs.
|
|
77
|
+
|
|
78
|
+
Args:
|
|
79
|
+
candidate_embeddings: The candidate embeddings.
|
|
80
|
+
candidate_ids: The identifiers for the candidates. If `None`, the
|
|
81
|
+
indices of the candidates are returned instead.
|
|
82
|
+
"""
|
|
83
|
+
pass
|
|
84
|
+
|
|
85
|
+
@abc.abstractmethod
|
|
86
|
+
def call(
|
|
87
|
+
self, inputs: types.Tensor
|
|
88
|
+
) -> Union[types.Tensor, tuple[types.Tensor, types.Tensor]]:
|
|
89
|
+
"""Returns the top candidates for the query passed as input.
|
|
90
|
+
|
|
91
|
+
Args:
|
|
92
|
+
inputs: the query for which to return top candidates.
|
|
93
|
+
|
|
94
|
+
Returns:
|
|
95
|
+
A tuple with the top scores and the top identifiers if
|
|
96
|
+
`returns_scores` is True, otherwise a tensor with the top
|
|
97
|
+
identifiers.
|
|
98
|
+
"""
|
|
99
|
+
pass
|
|
100
|
+
|
|
101
|
+
def compute_score(
|
|
102
|
+
self, query_embedding: types.Tensor, candidate_embedding: types.Tensor
|
|
103
|
+
) -> types.Tensor:
|
|
104
|
+
"""Computes the standard dot product score from queries and candidates.
|
|
105
|
+
|
|
106
|
+
Args:
|
|
107
|
+
query_embedding: Tensor of query embedding corresponding to the
|
|
108
|
+
queries for which to retrieve top candidates.
|
|
109
|
+
candidate_embedding: Tensor of candidate embeddings.
|
|
110
|
+
|
|
111
|
+
Returns:
|
|
112
|
+
The dot product of queries and candidates.
|
|
113
|
+
"""
|
|
114
|
+
|
|
115
|
+
return keras.ops.matmul(
|
|
116
|
+
query_embedding, keras.ops.transpose(candidate_embedding)
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
def get_config(self) -> dict[str, Any]:
|
|
120
|
+
config: dict[str, Any] = super().get_config()
|
|
121
|
+
config.update(
|
|
122
|
+
{
|
|
123
|
+
"k": self.k,
|
|
124
|
+
"return_scores": self.compute_score,
|
|
125
|
+
}
|
|
126
|
+
)
|
|
127
|
+
return config
|
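keras_rs/src/version.py
CHANGED
[The +1 -1 hunk for this file was not captured; the hunk below belongs to the dist-info RECORD file.]
{keras_rs_nightly-0.0.1.dev2025041403.dist-info → keras_rs_nightly-0.0.1.dev2025041603.dist-info}/RECORD
CHANGED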
|
@@ -1,19 +1,20 @@
|
|
|
1
1
|
keras_rs/__init__.py,sha256=X3VNKb_6VDEs5GqcbEc_l8mAsefWb5UgSu8krnQdFcM,794
|
|
2
2
|
keras_rs/api/__init__.py,sha256=9Xf-uH9j_SBaTc5RU0pkxrOEgHWPwSKjf4_maySH_nU,272
|
|
3
|
-
keras_rs/api/layers/__init__.py,sha256=
|
|
3
|
+
keras_rs/api/layers/__init__.py,sha256=SB7_QOBPizvbbyQAMb8mPl7vAx0gCxJBPm6V7H67SgU,747
|
|
4
4
|
keras_rs/api/losses/__init__.py,sha256=LGW7eHQh8FbQXdMV1s9zJpbloVlz_Zlo51sorWAvFwE,455
|
|
5
5
|
keras_rs/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
6
6
|
keras_rs/src/api_export.py,sha256=RsmG-DvO-cdFeAF9W6LRzms0kvtm-Yp9BAA_d-952zI,510
|
|
7
7
|
keras_rs/src/types.py,sha256=UyOdgjqrqg_b58opnY8n6gTiDHKVR8z_bmEruehERBk,514
|
|
8
|
-
keras_rs/src/version.py,sha256=
|
|
8
|
+
keras_rs/src/version.py,sha256=wgkFHQtzZaQah52nHJja4pYng_h75BXqZEskD1h29LI,222
|
|
9
9
|
keras_rs/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
10
|
keras_rs/src/layers/feature_interaction/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
11
|
keras_rs/src/layers/feature_interaction/dot_interaction.py,sha256=jGHcg0EiWxth6LTxG2yWgHcyx_GXrxvA61uQqpPfnDQ,6900
|
|
12
12
|
keras_rs/src/layers/feature_interaction/feature_cross.py,sha256=5OCSI0vFYzJNmgkKcuHIbVv8U2q3UvS80-qZjPimDjM,8155
|
|
13
13
|
keras_rs/src/layers/retrieval/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
14
|
-
keras_rs/src/layers/retrieval/brute_force_retrieval.py,sha256=
|
|
14
|
+
keras_rs/src/layers/retrieval/brute_force_retrieval.py,sha256=izdppBXxJH0KqYEg7Zsr-SL-SHgAmnFopXMPalEO3uw,5676
|
|
15
15
|
keras_rs/src/layers/retrieval/hard_negative_mining.py,sha256=CY8-3W52ZBIFcEfvjXJxbFltD6ulXl4-sZCRF6stIEc,4119
|
|
16
16
|
keras_rs/src/layers/retrieval/remove_accidental_hits.py,sha256=Z84z2YgKspKeNdc5id8lf9TAyFsbCCz3acJxiKXYipc,3324
|
|
17
|
+
keras_rs/src/layers/retrieval/retrieval.py,sha256=hVOBF10SF2q_TgJdVUqztbnw5qQF-cxVRGdJbOKoL9M,4191
|
|
17
18
|
keras_rs/src/layers/retrieval/sampling_probability_correction.py,sha256=80vgOPfBiF-PC0dSyqS57IcIxOxi_Q_R7eSXHn1G0yI,1437
|
|
18
19
|
keras_rs/src/losses/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
19
20
|
keras_rs/src/losses/pairwise_hinge_loss.py,sha256=vqDGd-OnZxiqdeE6vuabE8BKDfill3D2GM0lW5JUmsg,922
|
|
@@ -24,7 +25,7 @@ keras_rs/src/losses/pairwise_soft_zero_one_loss.py,sha256=XBej5nybFXEQ-Vp6GLvNmq
|
|
|
24
25
|
keras_rs/src/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
25
26
|
keras_rs/src/utils/keras_utils.py,sha256=IjWSRieBkv7UX12qgUoI1tcOeISstCLRSTqSHpT06yE,1275
|
|
26
27
|
keras_rs/src/utils/pairwise_loss_utils.py,sha256=6eF4CTJubCySO8M5nd3_gdTlJsta_YMnwDCcdqWYGHA,3435
|
|
27
|
-
keras_rs_nightly-0.0.1.
|
|
28
|
-
keras_rs_nightly-0.0.1.
|
|
29
|
-
keras_rs_nightly-0.0.1.
|
|
30
|
-
keras_rs_nightly-0.0.1.
|
|
28
|
+
keras_rs_nightly-0.0.1.dev2025041603.dist-info/METADATA,sha256=9EV3mNVpTEuZyIu5Ihha5KhjRwRJnstN4vm1iXMVvQA,3547
|
|
29
|
+
keras_rs_nightly-0.0.1.dev2025041603.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
|
30
|
+
keras_rs_nightly-0.0.1.dev2025041603.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
|
|
31
|
+
keras_rs_nightly-0.0.1.dev2025041603.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|