keras-rs-nightly 0.0.1.dev2025040703__py3-none-any.whl → 0.0.1.dev2025040803__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of keras-rs-nightly might be problematic; see the package's registry page for more details.

@@ -14,6 +14,9 @@ from keras_rs.src.layers.retrieval.brute_force_retrieval import (
14
14
  from keras_rs.src.layers.retrieval.hard_negative_mining import (
15
15
  HardNegativeMining,
16
16
  )
17
+ from keras_rs.src.layers.retrieval.remove_accidental_hits import (
18
+ RemoveAccidentalHits,
19
+ )
17
20
  from keras_rs.src.layers.retrieval.sampling_probability_correction import (
18
21
  SamplingProbabilityCorrection,
19
22
  )
@@ -1,13 +1,13 @@
1
1
  from typing import Any
2
2
 
3
3
  import keras
4
- import numpy as np
4
+ import ml_dtypes
5
5
  from keras import ops
6
6
 
7
7
  from keras_rs.src import types
8
8
  from keras_rs.src.api_export import keras_rs_export
9
9
 
10
- MAX_FLOAT = np.finfo(np.float32).max / 100.0
10
+ MAX_FLOAT = ml_dtypes.finfo("float32").max / 100.0  # 1% of float32's largest finite value; the /100 presumably leaves headroom against overflow — TODO confirm
11
11
 
12
12
 
13
13
  def _gather_elements_along_row(
@@ -0,0 +1,84 @@
1
+ import keras
2
+ import ml_dtypes
3
+ from keras import ops
4
+
5
+ from keras_rs.src import types
6
+ from keras_rs.src.api_export import keras_rs_export
7
+ from keras_rs.src.utils import keras_utils
8
+
9
# Large negative (but finite) offset added to the logits of accidental hits so
# that they receive effectively zero probability after a softmax. This must be
# very negative: a tiny positive value such as `finfo.smallest_normal / 100`
# leaves the logits unchanged. Dividing `min` by 100 mirrors `MAX_FLOAT` in
# `hard_negative_mining.py`, presumably to leave headroom against overflow.
SMALLEST_FLOAT = ml_dtypes.finfo("float32").min / 100.0


@keras_rs_export("keras_rs.layers.RemoveAccidentalHits")
class RemoveAccidentalHits(keras.layers.Layer):
    """Zeroes the logits of accidental negatives.

    An accidental hit is a negative candidate that has the same ID as the
    positive candidate in its row. Its logit is pushed to a very large negative
    value so that its softmax probability is effectively zero, while the
    positive candidate and all other negatives keep their logits.
    """

    def call(
        self,
        labels: types.Tensor,
        logits: types.Tensor,
        candidate_ids: types.Tensor,
    ) -> types.Tensor:
        """Masks out the logits of accidental hits.

        For each row in the batch, the positive candidate is located with
        `argmax(labels)`. `SMALLEST_FLOAT` is then added to the logit of every
        other candidate in that row whose ID equals the positive candidate's
        ID. With one-hot `labels`, the positive logit and the logits of
        non-duplicate negatives are left unchanged.

        Args:
            labels: one-hot labels tensor, typically
                `[batch_size, num_candidates]` but can have more dimensions or
                be 1D as `[num_candidates]`.
            logits: logits tensor. Must have the same shape as `labels`.
            candidate_ids: candidate identifiers tensor, can be
                `[num_candidates]` or `[batch_size, num_candidates]` or have
                more dimensions as long as they match the last dimensions of
                `labels`.

        Returns:
            logits: Modified logits.

        Raises:
            ValueError: if `labels` and `logits` have incompatible shapes, or
                if the shape of `candidate_ids` does not match the trailing
                dimensions of `labels`.
        """
        # A more principled way is to implement
        # `softmax_cross_entropy_with_logits` with an input mask. Here we
        # approximate it by giving accidental hits extremely negative logits
        # (SMALLEST_FLOAT) for ease of implementation.
        # NOTE(review): SMALLEST_FLOAT is not representable in float16; confirm
        # the behavior for half-precision `labels`/`logits`.

        labels_shape = ops.shape(labels)
        labels_rank = len(labels_shape)
        logits_shape = ops.shape(logits)
        candidate_ids_shape = ops.shape(candidate_ids)
        candidate_ids_rank = len(candidate_ids_shape)

        if not keras_utils.check_shapes_compatible(labels_shape, logits_shape):
            raise ValueError(
                "`labels` and `logits` should have the same shape. Received: "
                f"`labels.shape` = {labels_shape}, "
                f"`logits.shape` = {logits_shape}."
            )

        if not keras_utils.check_shapes_compatible(
            labels_shape[-candidate_ids_rank:], candidate_ids_shape
        ):
            raise ValueError(
                "`candidate_ids` should have the same shape as the last "
                "dimensions of `labels`. Received: "
                f"`candidate_ids.shape` = {candidate_ids_shape}, "
                f"`labels.shape` = {labels_shape}."
            )

        # Add leading size-1 dimensions to `candidate_ids` so that it has the
        # same rank as `labels` and broadcasts against it.
        if candidate_ids_rank < labels_rank:
            candidate_ids = ops.expand_dims(
                candidate_ids, list(range(labels_rank - candidate_ids_rank))
            )

        # Position of the positive candidate in each row, kept as a trailing
        # axis of size 1 so it can be used with `take_along_axis`.
        positive_indices = ops.expand_dims(ops.argmax(labels, axis=-1), -1)
        # Gather the positive ID row by row. `ops.take` without `axis` would
        # index the *flattened* `candidate_ids` and read every row's ID from
        # the first row when `candidate_ids` has a batch dimension.
        positive_candidate_ids = ops.take_along_axis(
            candidate_ids, positive_indices, axis=-1
        )

        # 1 where a candidate shares the positive's ID. Subtracting the labels
        # clears the positive itself, leaving 1 only on accidental hits.
        duplicate = ops.cast(
            ops.equal(positive_candidate_ids, candidate_ids), labels.dtype
        )
        duplicate = ops.subtract(duplicate, labels)

        return ops.add(logits, ops.multiply(duplicate, SMALLEST_FLOAT))
keras_rs/src/version.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from keras_rs.src.api_export import keras_rs_export
2
2
 
3
3
  # Unique source of truth for the version number.
4
- __version__ = "0.0.1.dev2025040703"
4
+ __version__ = "0.0.1.dev2025040803"
5
5
 
6
6
 
7
7
  @keras_rs_export("keras_rs.version")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: keras-rs-nightly
3
- Version: 0.0.1.dev2025040703
3
+ Version: 0.0.1.dev2025040803
4
4
  Summary: Multi-backend recommender systems with Keras 3.
5
5
  Author-email: Keras RS team <keras-rs@google.com>
6
6
  License: Apache License 2.0
@@ -20,6 +20,7 @@ Classifier: Topic :: Software Development
20
20
  Requires-Python: >=3.9
21
21
  Description-Content-Type: text/markdown
22
22
  Requires-Dist: keras
23
+ Requires-Dist: ml-dtypes
23
24
 
24
25
  # Keras Recommenders
25
26
 
@@ -1,18 +1,19 @@
1
1
  keras_rs/__init__.py,sha256=X3VNKb_6VDEs5GqcbEc_l8mAsefWb5UgSu8krnQdFcM,794
2
2
  keras_rs/api/__init__.py,sha256=9Xf-uH9j_SBaTc5RU0pkxrOEgHWPwSKjf4_maySH_nU,272
3
- keras_rs/api/layers/__init__.py,sha256=uZgnycSMTUalc2WfFaynmV3PpvZ4XeTttkwjcjsNMvk,590
3
+ keras_rs/api/layers/__init__.py,sha256=VvsmrU1T198D8svbjAk7bysRRfDeT71A-9wv7NyPC28,685
4
4
  keras_rs/api/losses/__init__.py,sha256=LGW7eHQh8FbQXdMV1s9zJpbloVlz_Zlo51sorWAvFwE,455
5
5
  keras_rs/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  keras_rs/src/api_export.py,sha256=RsmG-DvO-cdFeAF9W6LRzms0kvtm-Yp9BAA_d-952zI,510
7
7
  keras_rs/src/types.py,sha256=UyOdgjqrqg_b58opnY8n6gTiDHKVR8z_bmEruehERBk,514
8
- keras_rs/src/version.py,sha256=BT4v6Amhk3aylUK7NppTKzinXu04ers7bM8r-DlIpGs,222
8
+ keras_rs/src/version.py,sha256=a9BItWgRhEYrbx2zCLxStYnDVn5qAEy8Ovsszh2TS_o,222
9
9
  keras_rs/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
10
  keras_rs/src/layers/feature_interaction/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  keras_rs/src/layers/feature_interaction/dot_interaction.py,sha256=jGHcg0EiWxth6LTxG2yWgHcyx_GXrxvA61uQqpPfnDQ,6900
12
12
  keras_rs/src/layers/feature_interaction/feature_cross.py,sha256=5OCSI0vFYzJNmgkKcuHIbVv8U2q3UvS80-qZjPimDjM,8155
13
13
  keras_rs/src/layers/retrieval/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
14
  keras_rs/src/layers/retrieval/brute_force_retrieval.py,sha256=mohILOt6PC6jHBztaowDbj3QBnSetuvkq55FmE39PlY,7321
15
- keras_rs/src/layers/retrieval/hard_negative_mining.py,sha256=8c44iEUmK_SwfHPdrVtg96ycBZzYf62ee66a49mqCbU,4115
15
+ keras_rs/src/layers/retrieval/hard_negative_mining.py,sha256=CY8-3W52ZBIFcEfvjXJxbFltD6ulXl4-sZCRF6stIEc,4119
16
+ keras_rs/src/layers/retrieval/remove_accidental_hits.py,sha256=fiFQLlkMBXhG8V7a8mv_hKOwlqEJeUiMBYUVQw1woTE,3270
16
17
  keras_rs/src/layers/retrieval/sampling_probability_correction.py,sha256=80vgOPfBiF-PC0dSyqS57IcIxOxi_Q_R7eSXHn1G0yI,1437
17
18
  keras_rs/src/losses/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
19
  keras_rs/src/losses/pairwise_hinge_loss.py,sha256=vqDGd-OnZxiqdeE6vuabE8BKDfill3D2GM0lW5JUmsg,922
@@ -23,7 +24,7 @@ keras_rs/src/losses/pairwise_soft_zero_one_loss.py,sha256=XBej5nybFXEQ-Vp6GLvNmq
23
24
  keras_rs/src/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
24
25
  keras_rs/src/utils/keras_utils.py,sha256=IjWSRieBkv7UX12qgUoI1tcOeISstCLRSTqSHpT06yE,1275
25
26
  keras_rs/src/utils/pairwise_loss_utils.py,sha256=5SAqA3z30A1awzV9l5oVbcno5Z6HXARkNcUFTPL7_jg,3380
26
- keras_rs_nightly-0.0.1.dev2025040703.dist-info/METADATA,sha256=W0bJDXs5KnTBjcilC9QnfkBXT-clmCaxg79UOmU83OE,3522
27
- keras_rs_nightly-0.0.1.dev2025040703.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
28
- keras_rs_nightly-0.0.1.dev2025040703.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
29
- keras_rs_nightly-0.0.1.dev2025040703.dist-info/RECORD,,
27
+ keras_rs_nightly-0.0.1.dev2025040803.dist-info/METADATA,sha256=oC2tgoUja1GWeMPNM6ZFmrBJNnuPof-Eb8LioVF3kYs,3547
28
+ keras_rs_nightly-0.0.1.dev2025040803.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
29
+ keras_rs_nightly-0.0.1.dev2025040803.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
30
+ keras_rs_nightly-0.0.1.dev2025040803.dist-info/RECORD,,