keras-rs-nightly 0.0.1.dev2025041303__tar.gz → 0.0.1.dev2025041503__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of keras-rs-nightly might be problematic. See the package's registry page for more details.

Files changed (36):
  1. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/PKG-INFO +1 -1
  2. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/api/layers/__init__.py +1 -0
  3. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/src/layers/retrieval/brute_force_retrieval.py +11 -60
  4. keras_rs_nightly-0.0.1.dev2025041503/keras_rs/src/layers/retrieval/retrieval.py +127 -0
  5. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/src/version.py +1 -1
  6. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs_nightly.egg-info/PKG-INFO +1 -1
  7. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs_nightly.egg-info/SOURCES.txt +1 -0
  8. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/README.md +0 -0
  9. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/__init__.py +0 -0
  10. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/api/__init__.py +0 -0
  11. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/api/losses/__init__.py +0 -0
  12. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/src/__init__.py +0 -0
  13. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/src/api_export.py +0 -0
  14. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/src/layers/__init__.py +0 -0
  15. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/src/layers/feature_interaction/__init__.py +0 -0
  16. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/src/layers/feature_interaction/dot_interaction.py +0 -0
  17. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/src/layers/feature_interaction/feature_cross.py +0 -0
  18. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/src/layers/retrieval/__init__.py +0 -0
  19. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/src/layers/retrieval/hard_negative_mining.py +0 -0
  20. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/src/layers/retrieval/remove_accidental_hits.py +0 -0
  21. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/src/layers/retrieval/sampling_probability_correction.py +0 -0
  22. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/src/losses/__init__.py +0 -0
  23. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/src/losses/pairwise_hinge_loss.py +0 -0
  24. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/src/losses/pairwise_logistic_loss.py +0 -0
  25. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/src/losses/pairwise_loss.py +0 -0
  26. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/src/losses/pairwise_mean_squared_error.py +0 -0
  27. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/src/losses/pairwise_soft_zero_one_loss.py +0 -0
  28. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/src/types.py +0 -0
  29. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/src/utils/__init__.py +0 -0
  30. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/src/utils/keras_utils.py +0 -0
  31. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs/src/utils/pairwise_loss_utils.py +0 -0
  32. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs_nightly.egg-info/dependency_links.txt +0 -0
  33. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs_nightly.egg-info/requires.txt +0 -0
  34. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/keras_rs_nightly.egg-info/top_level.txt +0 -0
  35. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/pyproject.toml +0 -0
  36. {keras_rs_nightly-0.0.1.dev2025041303 → keras_rs_nightly-0.0.1.dev2025041503}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: keras-rs-nightly
3
- Version: 0.0.1.dev2025041303
3
+ Version: 0.0.1.dev2025041503
4
4
  Summary: Multi-backend recommender systems with Keras 3.
5
5
  Author-email: Keras RS team <keras-rs@google.com>
6
6
  License: Apache License 2.0
@@ -17,6 +17,7 @@ from keras_rs.src.layers.retrieval.hard_negative_mining import (
17
17
  from keras_rs.src.layers.retrieval.remove_accidental_hits import (
18
18
  RemoveAccidentalHits,
19
19
  )
20
+ from keras_rs.src.layers.retrieval.retrieval import Retrieval
20
21
  from keras_rs.src.layers.retrieval.sampling_probability_correction import (
21
22
  SamplingProbabilityCorrection,
22
23
  )
@@ -4,10 +4,11 @@ import keras
4
4
 
5
5
  from keras_rs.src import types
6
6
  from keras_rs.src.api_export import keras_rs_export
7
+ from keras_rs.src.layers.retrieval.retrieval import Retrieval
7
8
 
8
9
 
9
10
  @keras_rs_export("keras_rs.layers.BruteForceRetrieval")
10
- class BruteForceRetrieval(keras.layers.Layer):
11
+ class BruteForceRetrieval(Retrieval):
11
12
  """Brute force top-k retrieval.
12
13
 
13
14
  This layer maintains a set of candidates and is able to exactly retrieve the
@@ -60,11 +61,13 @@ class BruteForceRetrieval(keras.layers.Layer):
60
61
  return_scores: bool = True,
61
62
  **kwargs: Any,
62
63
  ) -> None:
63
- super().__init__(**kwargs)
64
+ # Keep `k`, `return_scores` as separately passed args instead of keeping
65
+ # them in `kwargs`. This is to ensure the user does not have to hop
66
+ # to the base class to check which other args can be passed.
67
+ super().__init__(k=k, return_scores=return_scores, **kwargs)
68
+
64
69
  self.candidate_embeddings = None
65
70
  self.candidate_ids = None
66
- self.k = k
67
- self.return_scores = return_scores
68
71
 
69
72
  if candidate_embeddings is None:
70
73
  if candidate_ids is not None:
@@ -84,36 +87,12 @@ class BruteForceRetrieval(keras.layers.Layer):
84
87
 
85
88
  Args:
86
89
  candidate_embeddings: The candidate embeddings.
87
- candidate_ids: The identifiers for the candidates. If `None` the
90
+ candidate_ids: The identifiers for the candidates. If `None`, the
88
91
  indices of the candidates are returned instead.
89
92
  """
90
- if candidate_embeddings is None:
91
- raise ValueError("`candidate_embeddings` is required")
92
-
93
- if len(candidate_embeddings.shape) != 2:
94
- raise ValueError(
95
- "`candidate_embeddings` must be a tensor of rank 2 "
96
- "(num_candidates, embedding_size), received "
97
- "`candidate_embeddings` with shape "
98
- f"{candidate_embeddings.shape}"
99
- )
100
-
101
- if candidate_embeddings.shape[0] < self.k:
102
- raise ValueError(
103
- "The number of candidates provided "
104
- f"({candidate_embeddings.shape[0]}) is less than the number of "
105
- f"candidates to retrieve (k={self.k})."
106
- )
107
-
108
- if (
109
- candidate_ids is not None
110
- and candidate_ids.shape[0] != candidate_embeddings.shape[0]
111
- ):
112
- raise ValueError(
113
- "The `candidate_embeddings` and `candidate_is` tensors must "
114
- "have the same number of rows, got tensors of shape "
115
- f"{candidate_embeddings.shape} and {candidate_ids.shape}."
116
- )
93
+ self._validate_candidate_embeddings_and_ids(
94
+ candidate_embeddings, candidate_ids
95
+ )
117
96
 
118
97
  if self.candidate_embeddings is not None:
119
98
  # Update of existing variables.
@@ -167,31 +146,3 @@ class BruteForceRetrieval(keras.layers.Layer):
167
146
  return top_scores, top_ids
168
147
  else:
169
148
  return top_ids
170
-
171
- def compute_score(
172
- self, query_embedding: types.Tensor, candidate_embedding: types.Tensor
173
- ) -> types.Tensor:
174
- """Computes the standard dot product score from queries and candidates.
175
-
176
- Args:
177
- query_embedding: Tensor of query embedding corresponding to the
178
- queries for which to retrieve top candidates.
179
- candidate_embedding: Tensor of candidate embeddings.
180
-
181
- Returns:
182
- The dot product of queries and candidates.
183
- """
184
-
185
- return keras.ops.matmul(
186
- query_embedding, keras.ops.transpose(candidate_embedding)
187
- )
188
-
189
- def get_config(self) -> dict[str, Any]:
190
- config: dict[str, Any] = super().get_config()
191
- config.update(
192
- {
193
- "k": self.k,
194
- "return_scores": self.compute_score,
195
- }
196
- )
197
- return config
@@ -0,0 +1,127 @@
1
+ import abc
2
+ from typing import Any, Optional, Union
3
+
4
+ import keras
5
+
6
+ from keras_rs.src import types
7
+ from keras_rs.src.api_export import keras_rs_export
8
+
9
+
@keras_rs_export("keras_rs.layers.Retrieval")
class Retrieval(keras.layers.Layer, abc.ABC):
    """Retrieval base abstract class.

    This layer provides a common interface for all retrieval layers. In order
    to implement a custom retrieval layer, this abstract class should be
    subclassed.

    Subclasses must implement `update_candidates()` and `call()`. They may use
    `_validate_candidate_embeddings_and_ids()` to check the arguments passed to
    `update_candidates()`, and `compute_score()` to score queries against
    candidates with a dot product.

    Args:
        k: int. Number of candidates to retrieve.
        return_scores: bool. When `True`, this layer returns a tuple with the
            top scores and the top identifiers. When `False`, this layer returns
            a single tensor with the top identifiers.
    """

    def __init__(
        self,
        k: int = 10,
        return_scores: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.k = k
        self.return_scores = return_scores

    def _validate_candidate_embeddings_and_ids(
        self,
        candidate_embeddings: types.Tensor,
        candidate_ids: Optional[types.Tensor] = None,
    ) -> None:
        """Validates inputs to `update_candidates()`.

        Args:
            candidate_embeddings: The candidate embeddings. Must be a rank 2
                tensor `(num_candidates, embedding_size)` with at least `k`
                rows.
            candidate_ids: Optional identifiers for the candidates. When
                provided, must have the same number of rows as
                `candidate_embeddings`.

        Raises:
            ValueError: If `candidate_embeddings` is `None` or not rank 2, if
                it has fewer than `k` rows, or if `candidate_ids` does not have
                the same number of rows as `candidate_embeddings`.
        """
        if candidate_embeddings is None:
            raise ValueError("`candidate_embeddings` is required.")

        if len(candidate_embeddings.shape) != 2:
            raise ValueError(
                "`candidate_embeddings` must be a tensor of rank 2 "
                "(num_candidates, embedding_size), received "
                "`candidate_embeddings` with shape "
                f"{candidate_embeddings.shape}"
            )

        # A dimension may be unknown (`None`), e.g. for symbolic tensors.
        # Only enforce row-count constraints when the size is known, so that
        # `None < k` does not raise a `TypeError` and `None != n` does not
        # report a spurious mismatch.
        num_candidates = candidate_embeddings.shape[0]
        if num_candidates is not None and num_candidates < self.k:
            raise ValueError(
                "The number of candidates provided "
                f"({num_candidates}) is less than the number of "
                f"candidates to retrieve (k={self.k})."
            )

        if candidate_ids is None:
            return

        num_ids = candidate_ids.shape[0]
        if (
            num_ids is not None
            and num_candidates is not None
            and num_ids != num_candidates
        ):
            raise ValueError(
                "The `candidate_embeddings` and `candidate_ids` tensors "
                "must have the same number of rows, got tensors of shape "
                f"{candidate_embeddings.shape} and {candidate_ids.shape}."
            )

    @abc.abstractmethod
    def update_candidates(
        self,
        candidate_embeddings: types.Tensor,
        candidate_ids: Optional[types.Tensor] = None,
    ) -> None:
        """Update the set of candidates and optionally their candidate IDs.

        Args:
            candidate_embeddings: The candidate embeddings.
            candidate_ids: The identifiers for the candidates. If `None`, the
                indices of the candidates are returned instead.
        """

    @abc.abstractmethod
    def call(
        self, inputs: types.Tensor
    ) -> Union[types.Tensor, tuple[types.Tensor, types.Tensor]]:
        """Returns the top candidates for the query passed as input.

        Args:
            inputs: the query for which to return top candidates.

        Returns:
            A tuple with the top scores and the top identifiers if
            `return_scores` is True, otherwise a tensor with the top
            identifiers.
        """

    def compute_score(
        self, query_embedding: types.Tensor, candidate_embedding: types.Tensor
    ) -> types.Tensor:
        """Computes the standard dot product score from queries and candidates.

        Args:
            query_embedding: Tensor of query embedding corresponding to the
                queries for which to retrieve top candidates.
            candidate_embedding: Tensor of candidate embeddings.

        Returns:
            The dot product of queries and candidates, i.e.
            `query_embedding @ transpose(candidate_embedding)`. For 2D inputs
            of shapes `(num_queries, embedding_size)` and
            `(num_candidates, embedding_size)` the result has shape
            `(num_queries, num_candidates)`.
        """
        return keras.ops.matmul(
            query_embedding, keras.ops.transpose(candidate_embedding)
        )

    def get_config(self) -> dict[str, Any]:
        """Returns the layer config, including `k` and `return_scores`.

        Both values are plain Python scalars, so the config can be serialized
        and passed back to the constructor via `from_config`.
        """
        config: dict[str, Any] = super().get_config()
        config.update(
            {
                "k": self.k,
                # Previously this stored the bound method `compute_score`,
                # which broke serialization and `from_config` round-trips.
                "return_scores": self.return_scores,
            }
        )
        return config
@@ -1,7 +1,7 @@
1
1
  from keras_rs.src.api_export import keras_rs_export
2
2
 
3
3
  # Unique source of truth for the version number.
4
- __version__ = "0.0.1.dev2025041303"
4
+ __version__ = "0.0.1.dev2025041503"
5
5
 
6
6
 
7
7
  @keras_rs_export("keras_rs.version")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: keras-rs-nightly
3
- Version: 0.0.1.dev2025041303
3
+ Version: 0.0.1.dev2025041503
4
4
  Summary: Multi-backend recommender systems with Keras 3.
5
5
  Author-email: Keras RS team <keras-rs@google.com>
6
6
  License: Apache License 2.0
@@ -16,6 +16,7 @@ keras_rs/src/layers/retrieval/__init__.py
16
16
  keras_rs/src/layers/retrieval/brute_force_retrieval.py
17
17
  keras_rs/src/layers/retrieval/hard_negative_mining.py
18
18
  keras_rs/src/layers/retrieval/remove_accidental_hits.py
19
+ keras_rs/src/layers/retrieval/retrieval.py
19
20
  keras_rs/src/layers/retrieval/sampling_probability_correction.py
20
21
  keras_rs/src/losses/__init__.py
21
22
  keras_rs/src/losses/pairwise_hinge_loss.py