keras-rs-nightly 0.0.1.dev2025041603__py3-none-any.whl → 0.0.1.dev2025041803__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of keras-rs-nightly might be problematic. Click here for more details.

@@ -10,41 +10,6 @@ from keras_rs.src.api_export import keras_rs_export
10
10
  MAX_FLOAT = ml_dtypes.finfo("float32").max / 100.0
11
11
 
12
12
 
13
- def _gather_elements_along_row(
14
- data: types.Tensor, column_indices: types.Tensor
15
- ) -> types.Tensor:
16
- """Gathers elements from a 2D tensor given the column indices of each row.
17
-
18
- First, gets the flat 1D indices to gather from. Then flattens the data to 1D
19
- and uses `ops.take()` to generate 1D output and finally reshapes the output
20
- back to 2D.
21
-
22
- Args:
23
- data: A [N, M] 2D `Tensor`.
24
- column_indices: A [N, K] 2D `Tensor` denoting for each row, the K column
25
- indices to gather elements from the data `Tensor`.
26
-
27
- Returns:
28
- A [N, K] `Tensor` including output elements gathered from data `Tensor`.
29
-
30
- Raises:
31
- ValueError: if the first dimensions of data and column_indices don't
32
- match.
33
- """
34
- num_row, num_column, *_ = ops.shape(data)
35
- num_gathered = ops.shape(column_indices)[1]
36
- row_indices = ops.tile(
37
- ops.expand_dims(ops.arange(num_row), -1), [1, num_gathered]
38
- )
39
- flat_data = ops.reshape(data, [-1])
40
- flat_indices = ops.reshape(
41
- ops.add(ops.multiply(row_indices, num_column), column_indices), [-1]
42
- )
43
- return ops.reshape(
44
- ops.take(flat_data, flat_indices), [num_row, num_gathered]
45
- )
46
-
47
-
48
13
  @keras_rs_export("keras_rs.layers.HardNegativeMining")
49
14
  class HardNegativeMining(keras.layers.Layer):
50
15
  """Transforms logits and labels to return hard negatives.
@@ -68,21 +33,21 @@ class HardNegativeMining(keras.layers.Layer):
68
33
  negatives as well as the positive candidate.
69
34
 
70
35
  Args:
71
- logits: `[batch_size, number_of_candidates]` tensor of logits.
72
- labels: `[batch_size, number_of_candidates]` one-hot tensor of
73
- labels.
36
+ logits: logits tensor, typically `[batch_size, num_candidates]` but
37
+ can have more dimensions or be 1D as `[num_candidates]`.
38
+ labels: one-hot labels tensor, must be the same shape as `logits`.
74
39
 
75
40
  Returns:
76
- tuple containing:
77
- - logits: `[batch_size, num_hard_negatives + 1]` tensor of logits.
78
- - labels: `[batch_size, num_hard_negatives + 1]` one-hot tensor of
79
- labels.
41
+ tuple containing two tensors with the last dimension of
42
+ `num_candidates` replaced with `num_hard_negatives + 1`.
43
+ - logits: `[..., num_hard_negatives + 1]` tensor of logits.
44
+ - labels: `[..., num_hard_negatives + 1]` one-hot tensor of labels.
80
45
  """
81
46
 
82
47
  # Number of sampled logits, i.e., the number of hard negatives to be
83
48
  # sampled (k) + number of true logits (1) per query, capped by batch
84
49
  # size.
85
- num_logits = ops.shape(logits)[1]
50
+ num_logits = ops.shape(logits)[-1]
86
51
  if isinstance(num_logits, int):
87
52
  num_sampled = min(self._num_hard_negatives + 1, num_logits)
88
53
  else:
@@ -98,14 +63,14 @@ class HardNegativeMining(keras.layers.Layer):
98
63
  # For each query, get the indices of the logits which have the highest
99
64
  # k + 1 logit values, including the highest k negative logits and one
100
65
  # true logit.
101
- _, col_indices = ops.top_k(
66
+ _, indices = ops.top_k(
102
67
  ops.add(logits, ops.multiply(labels, MAX_FLOAT)),
103
68
  k=num_sampled,
104
69
  sorted=False,
105
70
  )
106
71
 
107
72
  # Gather sampled logits and corresponding labels.
108
- logits = _gather_elements_along_row(logits, col_indices)
109
- labels = _gather_elements_along_row(labels, col_indices)
73
+ logits = ops.take_along_axis(logits, indices, axis=-1)
74
+ labels = ops.take_along_axis(labels, indices, axis=-1)
110
75
 
111
76
  return logits, labels
keras_rs/src/version.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from keras_rs.src.api_export import keras_rs_export
2
2
 
3
3
  # Unique source of truth for the version number.
4
- __version__ = "0.0.1.dev2025041603"
4
+ __version__ = "0.0.1.dev2025041803"
5
5
 
6
6
 
7
7
  @keras_rs_export("keras_rs.version")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: keras-rs-nightly
3
- Version: 0.0.1.dev2025041603
3
+ Version: 0.0.1.dev2025041803
4
4
  Summary: Multi-backend recommender systems with Keras 3.
5
5
  Author-email: Keras RS team <keras-rs@google.com>
6
6
  License: Apache License 2.0
@@ -5,14 +5,14 @@ keras_rs/api/losses/__init__.py,sha256=LGW7eHQh8FbQXdMV1s9zJpbloVlz_Zlo51sorWAvF
5
5
  keras_rs/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  keras_rs/src/api_export.py,sha256=RsmG-DvO-cdFeAF9W6LRzms0kvtm-Yp9BAA_d-952zI,510
7
7
  keras_rs/src/types.py,sha256=UyOdgjqrqg_b58opnY8n6gTiDHKVR8z_bmEruehERBk,514
8
- keras_rs/src/version.py,sha256=wgkFHQtzZaQah52nHJja4pYng_h75BXqZEskD1h29LI,222
8
+ keras_rs/src/version.py,sha256=a6XGs2YLi6kf7exH9ycMMJESRdEYW2cDXyjyETsQNGQ,222
9
9
  keras_rs/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
10
  keras_rs/src/layers/feature_interaction/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  keras_rs/src/layers/feature_interaction/dot_interaction.py,sha256=jGHcg0EiWxth6LTxG2yWgHcyx_GXrxvA61uQqpPfnDQ,6900
12
12
  keras_rs/src/layers/feature_interaction/feature_cross.py,sha256=5OCSI0vFYzJNmgkKcuHIbVv8U2q3UvS80-qZjPimDjM,8155
13
13
  keras_rs/src/layers/retrieval/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
14
  keras_rs/src/layers/retrieval/brute_force_retrieval.py,sha256=izdppBXxJH0KqYEg7Zsr-SL-SHgAmnFopXMPalEO3uw,5676
15
- keras_rs/src/layers/retrieval/hard_negative_mining.py,sha256=CY8-3W52ZBIFcEfvjXJxbFltD6ulXl4-sZCRF6stIEc,4119
15
+ keras_rs/src/layers/retrieval/hard_negative_mining.py,sha256=IWFrbw1h9z3AUw4oUBKf5_Aud4MTHO_AKdHfoyFa5As,3031
16
16
  keras_rs/src/layers/retrieval/remove_accidental_hits.py,sha256=Z84z2YgKspKeNdc5id8lf9TAyFsbCCz3acJxiKXYipc,3324
17
17
  keras_rs/src/layers/retrieval/retrieval.py,sha256=hVOBF10SF2q_TgJdVUqztbnw5qQF-cxVRGdJbOKoL9M,4191
18
18
  keras_rs/src/layers/retrieval/sampling_probability_correction.py,sha256=80vgOPfBiF-PC0dSyqS57IcIxOxi_Q_R7eSXHn1G0yI,1437
@@ -25,7 +25,7 @@ keras_rs/src/losses/pairwise_soft_zero_one_loss.py,sha256=XBej5nybFXEQ-Vp6GLvNmq
25
25
  keras_rs/src/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
26
  keras_rs/src/utils/keras_utils.py,sha256=IjWSRieBkv7UX12qgUoI1tcOeISstCLRSTqSHpT06yE,1275
27
27
  keras_rs/src/utils/pairwise_loss_utils.py,sha256=6eF4CTJubCySO8M5nd3_gdTlJsta_YMnwDCcdqWYGHA,3435
28
- keras_rs_nightly-0.0.1.dev2025041603.dist-info/METADATA,sha256=9EV3mNVpTEuZyIu5Ihha5KhjRwRJnstN4vm1iXMVvQA,3547
29
- keras_rs_nightly-0.0.1.dev2025041603.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
30
- keras_rs_nightly-0.0.1.dev2025041603.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
31
- keras_rs_nightly-0.0.1.dev2025041603.dist-info/RECORD,,
28
+ keras_rs_nightly-0.0.1.dev2025041803.dist-info/METADATA,sha256=MQ8Edtr90TUCgLCF7VcS-D3FkP33iQmPB0R_yU5BNws,3547
29
+ keras_rs_nightly-0.0.1.dev2025041803.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
30
+ keras_rs_nightly-0.0.1.dev2025041803.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
31
+ keras_rs_nightly-0.0.1.dev2025041803.dist-info/RECORD,,