keras-rs-nightly 0.0.1.dev2025040703__py3-none-any.whl → 0.0.1.dev2025040803__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of keras-rs-nightly might be problematic; see this release's page on the package registry for more details.
- keras_rs/api/layers/__init__.py +3 -0
- keras_rs/src/layers/retrieval/hard_negative_mining.py +2 -2
- keras_rs/src/layers/retrieval/remove_accidental_hits.py +84 -0
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.0.1.dev2025040703.dist-info → keras_rs_nightly-0.0.1.dev2025040803.dist-info}/METADATA +2 -1
- {keras_rs_nightly-0.0.1.dev2025040703.dist-info → keras_rs_nightly-0.0.1.dev2025040803.dist-info}/RECORD +8 -7
- {keras_rs_nightly-0.0.1.dev2025040703.dist-info → keras_rs_nightly-0.0.1.dev2025040803.dist-info}/WHEEL +0 -0
- {keras_rs_nightly-0.0.1.dev2025040703.dist-info → keras_rs_nightly-0.0.1.dev2025040803.dist-info}/top_level.txt +0 -0
keras_rs/api/layers/__init__.py
CHANGED
|
@@ -14,6 +14,9 @@ from keras_rs.src.layers.retrieval.brute_force_retrieval import (
|
|
|
14
14
|
from keras_rs.src.layers.retrieval.hard_negative_mining import (
|
|
15
15
|
HardNegativeMining,
|
|
16
16
|
)
|
|
17
|
+
from keras_rs.src.layers.retrieval.remove_accidental_hits import (
|
|
18
|
+
RemoveAccidentalHits,
|
|
19
|
+
)
|
|
17
20
|
from keras_rs.src.layers.retrieval.sampling_probability_correction import (
|
|
18
21
|
SamplingProbabilityCorrection,
|
|
19
22
|
)
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
from typing import Any
|
|
2
2
|
|
|
3
3
|
import keras
|
|
4
|
-
import
|
|
4
|
+
import ml_dtypes
|
|
5
5
|
from keras import ops
|
|
6
6
|
|
|
7
7
|
from keras_rs.src import types
|
|
8
8
|
from keras_rs.src.api_export import keras_rs_export
|
|
9
9
|
|
|
10
|
-
MAX_FLOAT =
|
|
10
|
+
MAX_FLOAT = ml_dtypes.finfo("float32").max / 100.0
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
def _gather_elements_along_row(
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
import keras
|
|
2
|
+
import ml_dtypes
|
|
3
|
+
from keras import ops
|
|
4
|
+
|
|
5
|
+
from keras_rs.src import types
|
|
6
|
+
from keras_rs.src.api_export import keras_rs_export
|
|
7
|
+
from keras_rs.src.utils import keras_utils
|
|
8
|
+
|
|
9
|
+
# Very large negative value added to the logits of accidental hits so that,
# after a softmax, they receive effectively zero probability. It is the most
# negative float32 divided by 100, which leaves headroom so that adding it to
# a finite float32 logit does not overflow to -inf.
# NOTE(review): this is a float32 constant; with lower-precision logits
# (e.g. float16) it may overflow when cast -- confirm mixed-precision behavior.
SMALLEST_FLOAT = ml_dtypes.finfo("float32").min / 100.0


@keras_rs_export("keras_rs.layers.RemoveAccidentalHits")
class RemoveAccidentalHits(keras.layers.Layer):
    """Masks out the logits of accidental negatives.

    Pushes the logits of negative candidates that have the same ID as the
    positive candidate in that row to a very large negative value, so that
    they contribute effectively nothing to a softmax over the candidates.
    """

    def call(
        self,
        labels: types.Tensor,
        logits: types.Tensor,
        candidate_ids: types.Tensor,
    ) -> types.Tensor:
        """Masks the logits of accidental hits.

        For each row in the batch, the positive candidate is the argmax of
        `labels` along the last axis. Every other candidate in that row whose
        ID equals the positive candidate's ID gets `SMALLEST_FLOAT` added to
        its logit. With one-hot labels, the positive candidate's own logit and
        the logits of candidates with different IDs are left unchanged.

        Args:
            labels: one-hot labels tensor, typically
                `[batch_size, num_candidates]` but can have more dimensions or
                be 1D as `[num_candidates]`.
            logits: logits tensor. Must have the same shape as `labels`.
            candidate_ids: candidate identifiers tensor, can be
                `[num_candidates]` or `[batch_size, num_candidates]` or have
                more dimensions as long as they match the last dimensions of
                `labels`.

        Returns:
            logits: Modified logits.

        Raises:
            ValueError: if `labels` and `logits` shapes are incompatible, or if
                `candidate_ids` does not match the trailing dimensions of
                `labels`.
        """
        # A more principled way is to implement
        # `softmax_cross_entropy_with_logits` with an input mask. Here we
        # approximate it by giving accidental hits extremely negative logits
        # (SMALLEST_FLOAT) for ease of implementation.

        labels_shape = ops.shape(labels)
        labels_rank = len(labels_shape)
        logits_shape = ops.shape(logits)
        candidate_ids_shape = ops.shape(candidate_ids)
        candidate_ids_rank = len(candidate_ids_shape)

        if not keras_utils.check_shapes_compatible(labels_shape, logits_shape):
            raise ValueError(
                "`labels` and `logits` should have the same shape. Received: "
                f"`labels.shape` = {labels_shape}, "
                f"`logits.shape` = {logits_shape}."
            )

        if not keras_utils.check_shapes_compatible(
            labels_shape[-candidate_ids_rank:], candidate_ids_shape
        ):
            raise ValueError(
                "`candidate_ids` should have the same shape as the last "
                "dimensions of `labels`. Received: "
                f"`candidate_ids.shape` = {candidate_ids_shape}, "
                f"`labels.shape` = {labels_shape}."
            )

        # Add leading dimensions to `candidate_ids` to have the same rank as
        # `labels`; the size-1 dimensions broadcast against the batch dims.
        if candidate_ids_rank < labels_rank:
            candidate_ids = ops.expand_dims(
                candidate_ids, list(range(labels_rank - candidate_ids_rank))
            )

        # Index of the positive candidate in each row, kept as a trailing
        # size-1 axis so it can be used with `take_along_axis`.
        positive_indices = ops.expand_dims(ops.argmax(labels, axis=-1), -1)
        # Gather along the last axis, row by row. A plain `ops.take` without
        # an axis would flatten `candidate_ids` and, for per-row candidate
        # IDs, always read the positive ID from the first row.
        positive_candidate_ids = ops.take_along_axis(
            candidate_ids, positive_indices, axis=-1
        )

        duplicate = ops.cast(
            ops.equal(positive_candidate_ids, candidate_ids), labels.dtype
        )
        # Remove the positive itself: 1 only for accidental hits (one-hot).
        duplicate = ops.subtract(duplicate, labels)

        return ops.add(logits, ops.multiply(duplicate, SMALLEST_FLOAT))
|
keras_rs/src/version.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: keras-rs-nightly
|
|
3
|
-
Version: 0.0.1.dev2025040703
|
|
3
|
+
Version: 0.0.1.dev2025040803
|
|
4
4
|
Summary: Multi-backend recommender systems with Keras 3.
|
|
5
5
|
Author-email: Keras RS team <keras-rs@google.com>
|
|
6
6
|
License: Apache License 2.0
|
|
@@ -20,6 +20,7 @@ Classifier: Topic :: Software Development
|
|
|
20
20
|
Requires-Python: >=3.9
|
|
21
21
|
Description-Content-Type: text/markdown
|
|
22
22
|
Requires-Dist: keras
|
|
23
|
+
Requires-Dist: ml-dtypes
|
|
23
24
|
|
|
24
25
|
# Keras Recommenders
|
|
25
26
|
|
|
@@ -1,18 +1,19 @@
|
|
|
1
1
|
keras_rs/__init__.py,sha256=X3VNKb_6VDEs5GqcbEc_l8mAsefWb5UgSu8krnQdFcM,794
|
|
2
2
|
keras_rs/api/__init__.py,sha256=9Xf-uH9j_SBaTc5RU0pkxrOEgHWPwSKjf4_maySH_nU,272
|
|
3
|
-
keras_rs/api/layers/__init__.py,sha256=
|
|
3
|
+
keras_rs/api/layers/__init__.py,sha256=VvsmrU1T198D8svbjAk7bysRRfDeT71A-9wv7NyPC28,685
|
|
4
4
|
keras_rs/api/losses/__init__.py,sha256=LGW7eHQh8FbQXdMV1s9zJpbloVlz_Zlo51sorWAvFwE,455
|
|
5
5
|
keras_rs/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
6
6
|
keras_rs/src/api_export.py,sha256=RsmG-DvO-cdFeAF9W6LRzms0kvtm-Yp9BAA_d-952zI,510
|
|
7
7
|
keras_rs/src/types.py,sha256=UyOdgjqrqg_b58opnY8n6gTiDHKVR8z_bmEruehERBk,514
|
|
8
|
-
keras_rs/src/version.py,sha256=
|
|
8
|
+
keras_rs/src/version.py,sha256=a9BItWgRhEYrbx2zCLxStYnDVn5qAEy8Ovsszh2TS_o,222
|
|
9
9
|
keras_rs/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
10
|
keras_rs/src/layers/feature_interaction/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
11
|
keras_rs/src/layers/feature_interaction/dot_interaction.py,sha256=jGHcg0EiWxth6LTxG2yWgHcyx_GXrxvA61uQqpPfnDQ,6900
|
|
12
12
|
keras_rs/src/layers/feature_interaction/feature_cross.py,sha256=5OCSI0vFYzJNmgkKcuHIbVv8U2q3UvS80-qZjPimDjM,8155
|
|
13
13
|
keras_rs/src/layers/retrieval/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
14
14
|
keras_rs/src/layers/retrieval/brute_force_retrieval.py,sha256=mohILOt6PC6jHBztaowDbj3QBnSetuvkq55FmE39PlY,7321
|
|
15
|
-
keras_rs/src/layers/retrieval/hard_negative_mining.py,sha256=
|
|
15
|
+
keras_rs/src/layers/retrieval/hard_negative_mining.py,sha256=CY8-3W52ZBIFcEfvjXJxbFltD6ulXl4-sZCRF6stIEc,4119
|
|
16
|
+
keras_rs/src/layers/retrieval/remove_accidental_hits.py,sha256=fiFQLlkMBXhG8V7a8mv_hKOwlqEJeUiMBYUVQw1woTE,3270
|
|
16
17
|
keras_rs/src/layers/retrieval/sampling_probability_correction.py,sha256=80vgOPfBiF-PC0dSyqS57IcIxOxi_Q_R7eSXHn1G0yI,1437
|
|
17
18
|
keras_rs/src/losses/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
18
19
|
keras_rs/src/losses/pairwise_hinge_loss.py,sha256=vqDGd-OnZxiqdeE6vuabE8BKDfill3D2GM0lW5JUmsg,922
|
|
@@ -23,7 +24,7 @@ keras_rs/src/losses/pairwise_soft_zero_one_loss.py,sha256=XBej5nybFXEQ-Vp6GLvNmq
|
|
|
23
24
|
keras_rs/src/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
24
25
|
keras_rs/src/utils/keras_utils.py,sha256=IjWSRieBkv7UX12qgUoI1tcOeISstCLRSTqSHpT06yE,1275
|
|
25
26
|
keras_rs/src/utils/pairwise_loss_utils.py,sha256=5SAqA3z30A1awzV9l5oVbcno5Z6HXARkNcUFTPL7_jg,3380
|
|
26
|
-
keras_rs_nightly-0.0.1.
|
|
27
|
-
keras_rs_nightly-0.0.1.
|
|
28
|
-
keras_rs_nightly-0.0.1.
|
|
29
|
-
keras_rs_nightly-0.0.1.
|
|
27
|
+
keras_rs_nightly-0.0.1.dev2025040803.dist-info/METADATA,sha256=oC2tgoUja1GWeMPNM6ZFmrBJNnuPof-Eb8LioVF3kYs,3547
|
|
28
|
+
keras_rs_nightly-0.0.1.dev2025040803.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
|
29
|
+
keras_rs_nightly-0.0.1.dev2025040803.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
|
|
30
|
+
keras_rs_nightly-0.0.1.dev2025040803.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|