keras-rs-nightly 0.0.1.dev2025030403__py3-none-any.whl → 0.0.1.dev2025030503__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of keras-rs-nightly might be problematic. Click here for more details.
- keras_rs/api/layers/__init__.py +3 -0
- keras_rs/src/layers/retrieval/sampling_probability_correction.py +48 -0
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.0.1.dev2025030403.dist-info → keras_rs_nightly-0.0.1.dev2025030503.dist-info}/METADATA +1 -1
- {keras_rs_nightly-0.0.1.dev2025030403.dist-info → keras_rs_nightly-0.0.1.dev2025030503.dist-info}/RECORD +7 -6
- {keras_rs_nightly-0.0.1.dev2025030403.dist-info → keras_rs_nightly-0.0.1.dev2025030503.dist-info}/WHEEL +0 -0
- {keras_rs_nightly-0.0.1.dev2025030403.dist-info → keras_rs_nightly-0.0.1.dev2025030503.dist-info}/top_level.txt +0 -0
keras_rs/api/layers/__init__.py
CHANGED
|
@@ -12,3 +12,6 @@ from keras_rs.src.layers.retrieval.brute_force_retrieval import (
|
|
|
12
12
|
from keras_rs.src.layers.retrieval.hard_negative_mining import (
|
|
13
13
|
HardNegativeMining,
|
|
14
14
|
)
|
|
15
|
+
from keras_rs.src.layers.retrieval.sampling_probability_correction import (
|
|
16
|
+
SamplingProbabilityCorrection,
|
|
17
|
+
)
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
import keras
|
|
4
|
+
from keras import ops
|
|
5
|
+
|
|
6
|
+
from keras_rs.src import types
|
|
7
|
+
from keras_rs.src.api_export import keras_rs_export
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@keras_rs_export("keras_rs.layers.SamplingProbabilityCorrection")
class SamplingProbabilityCorrection(keras.layers.Layer):
    """Sampling probability correction.

    Corrects the logits to reflect the sampling probability of negatives.

    The layer subtracts the log of each candidate's sampling probability from
    the corresponding logit, computing
    `logits - log(clip(candidate_sampling_probability, epsilon, 1.0))`.

    Args:
        epsilon: float. Lower bound the sampling probability is clipped to
            before taking the log, so that zero probabilities do not produce
            `log(0)`. Note that this is a clip, not an addition:
            probabilities already at or above `epsilon` are unchanged. Must
            satisfy `0 < epsilon <= 1`. Defaults to 1e-6.
        **kwargs: Args to pass to the base class.

    Raises:
        ValueError: If `epsilon` is not in the range `(0, 1]`.
    """

    def __init__(self, epsilon: float = 1e-6, **kwargs: Any) -> None:
        # `epsilon <= 0` would still allow `log(0) = -inf`, and `epsilon > 1`
        # makes the clip range below empty (min > max), so reject both early.
        if not 0.0 < epsilon <= 1.0:
            raise ValueError(
                "`epsilon` must be in the range (0, 1]. "
                f"Received: epsilon={epsilon}"
            )
        super().__init__(**kwargs)
        self.epsilon = epsilon
        # The layer creates no weights, so it is marked as built immediately.
        self.built = True

    def call(
        self,
        logits: types.Tensor,
        candidate_sampling_probability: types.Tensor,
    ) -> types.Tensor:
        """Corrects input logits to account for candidate sampling probability.

        Args:
            logits: The logits to correct.
            candidate_sampling_probability: The sampling probability of each
                candidate. It is combined with `logits` element-wise, so it
                must have the same shape as `logits` or be broadcastable to
                it.

        Returns:
            The corrected logits,
            `logits - log(clip(candidate_sampling_probability, epsilon, 1.0))`.
        """
        # Clipping keeps the log finite for zero probabilities and caps the
        # probability at 1.0, so the correction never increases a logit.
        clipped_probability = ops.clip(
            candidate_sampling_probability, self.epsilon, 1.0
        )
        return logits - ops.log(clipped_probability)

    def get_config(self) -> dict[str, Any]:
        """Returns the layer config, including `epsilon`, for serialization."""
        config: dict[str, Any] = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config
|
keras_rs/src/version.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
keras_rs/__init__.py,sha256=X3VNKb_6VDEs5GqcbEc_l8mAsefWb5UgSu8krnQdFcM,794
|
|
2
2
|
keras_rs/api/__init__.py,sha256=BU34SgAwrjZh49ppXGxJPxreTOYW0C7vy3_x8nmvHUk,240
|
|
3
|
-
keras_rs/api/layers/__init__.py,sha256=
|
|
3
|
+
keras_rs/api/layers/__init__.py,sha256=MOWXxHFJQKriffbD-k8giN87CXdzrbRwayzmDipoBIo,559
|
|
4
4
|
keras_rs/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
5
5
|
keras_rs/src/api_export.py,sha256=RsmG-DvO-cdFeAF9W6LRzms0kvtm-Yp9BAA_d-952zI,510
|
|
6
6
|
keras_rs/src/types.py,sha256=UyOdgjqrqg_b58opnY8n6gTiDHKVR8z_bmEruehERBk,514
|
|
7
|
-
keras_rs/src/version.py,sha256=
|
|
7
|
+
keras_rs/src/version.py,sha256=aGxhmch--1ygcaY54yqn6fBeP77KQEG3cXmhlwv_c-4,222
|
|
8
8
|
keras_rs/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
9
9
|
keras_rs/src/layers/modeling/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
10
|
keras_rs/src/layers/modeling/dot_interaction.py,sha256=jGHcg0EiWxth6LTxG2yWgHcyx_GXrxvA61uQqpPfnDQ,6900
|
|
@@ -12,9 +12,10 @@ keras_rs/src/layers/modeling/feature_cross.py,sha256=5OCSI0vFYzJNmgkKcuHIbVv8U2q
|
|
|
12
12
|
keras_rs/src/layers/retrieval/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
13
13
|
keras_rs/src/layers/retrieval/brute_force_retrieval.py,sha256=mohILOt6PC6jHBztaowDbj3QBnSetuvkq55FmE39PlY,7321
|
|
14
14
|
keras_rs/src/layers/retrieval/hard_negative_mining.py,sha256=8c44iEUmK_SwfHPdrVtg96ycBZzYf62ee66a49mqCbU,4115
|
|
15
|
+
keras_rs/src/layers/retrieval/sampling_probability_correction.py,sha256=80vgOPfBiF-PC0dSyqS57IcIxOxi_Q_R7eSXHn1G0yI,1437
|
|
15
16
|
keras_rs/src/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
16
17
|
keras_rs/src/utils/keras_utils.py,sha256=IjWSRieBkv7UX12qgUoI1tcOeISstCLRSTqSHpT06yE,1275
|
|
17
|
-
keras_rs_nightly-0.0.1.
|
|
18
|
-
keras_rs_nightly-0.0.1.
|
|
19
|
-
keras_rs_nightly-0.0.1.
|
|
20
|
-
keras_rs_nightly-0.0.1.
|
|
18
|
+
keras_rs_nightly-0.0.1.dev2025030503.dist-info/METADATA,sha256=AA3mZmkM2S2joC0BI5EcuM_kwFk61MPMjqYEXYoguuE,3522
|
|
19
|
+
keras_rs_nightly-0.0.1.dev2025030503.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
|
|
20
|
+
keras_rs_nightly-0.0.1.dev2025030503.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
|
|
21
|
+
keras_rs_nightly-0.0.1.dev2025030503.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|