keras-rs-nightly 0.0.1.dev2025041603__tar.gz → 0.0.1.dev2025041803__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of keras-rs-nightly might be problematic. Click here for more details.
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/PKG-INFO +1 -1
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/layers/retrieval/hard_negative_mining.py +11 -46
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs_nightly.egg-info/PKG-INFO +1 -1
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/README.md +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/api/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/api/layers/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/api/losses/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/api_export.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/layers/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/layers/feature_interaction/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/layers/feature_interaction/dot_interaction.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/layers/feature_interaction/feature_cross.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/layers/retrieval/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/layers/retrieval/brute_force_retrieval.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/layers/retrieval/remove_accidental_hits.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/layers/retrieval/retrieval.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/layers/retrieval/sampling_probability_correction.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/losses/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/losses/pairwise_hinge_loss.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/losses/pairwise_logistic_loss.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/losses/pairwise_loss.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/losses/pairwise_mean_squared_error.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/losses/pairwise_soft_zero_one_loss.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/types.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/utils/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/utils/keras_utils.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/utils/pairwise_loss_utils.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs_nightly.egg-info/SOURCES.txt +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs_nightly.egg-info/dependency_links.txt +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs_nightly.egg-info/requires.txt +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs_nightly.egg-info/top_level.txt +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/pyproject.toml +0 -0
- {keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/setup.cfg +0 -0
|
@@ -10,41 +10,6 @@ from keras_rs.src.api_export import keras_rs_export
|
|
|
10
10
|
MAX_FLOAT = ml_dtypes.finfo("float32").max / 100.0
|
|
11
11
|
|
|
12
12
|
|
|
13
|
-
def _gather_elements_along_row(
|
|
14
|
-
data: types.Tensor, column_indices: types.Tensor
|
|
15
|
-
) -> types.Tensor:
|
|
16
|
-
"""Gathers elements from a 2D tensor given the column indices of each row.
|
|
17
|
-
|
|
18
|
-
First, gets the flat 1D indices to gather from. Then flattens the data to 1D
|
|
19
|
-
and uses `ops.take()` to generate 1D output and finally reshapes the output
|
|
20
|
-
back to 2D.
|
|
21
|
-
|
|
22
|
-
Args:
|
|
23
|
-
data: A [N, M] 2D `Tensor`.
|
|
24
|
-
column_indices: A [N, K] 2D `Tensor` denoting for each row, the K column
|
|
25
|
-
indices to gather elements from the data `Tensor`.
|
|
26
|
-
|
|
27
|
-
Returns:
|
|
28
|
-
A [N, K] `Tensor` including output elements gathered from data `Tensor`.
|
|
29
|
-
|
|
30
|
-
Raises:
|
|
31
|
-
ValueError: if the first dimensions of data and column_indices don't
|
|
32
|
-
match.
|
|
33
|
-
"""
|
|
34
|
-
num_row, num_column, *_ = ops.shape(data)
|
|
35
|
-
num_gathered = ops.shape(column_indices)[1]
|
|
36
|
-
row_indices = ops.tile(
|
|
37
|
-
ops.expand_dims(ops.arange(num_row), -1), [1, num_gathered]
|
|
38
|
-
)
|
|
39
|
-
flat_data = ops.reshape(data, [-1])
|
|
40
|
-
flat_indices = ops.reshape(
|
|
41
|
-
ops.add(ops.multiply(row_indices, num_column), column_indices), [-1]
|
|
42
|
-
)
|
|
43
|
-
return ops.reshape(
|
|
44
|
-
ops.take(flat_data, flat_indices), [num_row, num_gathered]
|
|
45
|
-
)
|
|
46
|
-
|
|
47
|
-
|
|
48
13
|
@keras_rs_export("keras_rs.layers.HardNegativeMining")
|
|
49
14
|
class HardNegativeMining(keras.layers.Layer):
|
|
50
15
|
"""Transforms logits and labels to return hard negatives.
|
|
@@ -68,21 +33,21 @@ class HardNegativeMining(keras.layers.Layer):
|
|
|
68
33
|
negatives as well as the positive candidate.
|
|
69
34
|
|
|
70
35
|
Args:
|
|
71
|
-
logits: `[batch_size,
|
|
72
|
-
|
|
73
|
-
|
|
36
|
+
logits: logits tensor, typically `[batch_size, num_candidates]` but
|
|
37
|
+
can have more dimensions or be 1D as `[num_candidates]`.
|
|
38
|
+
labels: one-hot labels tensor, must be the same shape as `logits`.
|
|
74
39
|
|
|
75
40
|
Returns:
|
|
76
|
-
tuple containing
|
|
77
|
-
|
|
78
|
-
-
|
|
79
|
-
|
|
41
|
+
tuple containing two tensors with the last dimension of
|
|
42
|
+
`num_candidates` replaced with `num_hard_negatives + 1`.
|
|
43
|
+
- logits: `[..., num_hard_negatives + 1]` tensor of logits.
|
|
44
|
+
- labels: `[..., num_hard_negatives + 1]` one-hot tensor of labels.
|
|
80
45
|
"""
|
|
81
46
|
|
|
82
47
|
# Number of sampled logits, i.e, the number of hard negatives to be
|
|
83
48
|
# sampled (k) + number of true logit (1) per query, capped by batch
|
|
84
49
|
# size.
|
|
85
|
-
num_logits = ops.shape(logits)[1]
|
|
50
|
+
num_logits = ops.shape(logits)[-1]
|
|
86
51
|
if isinstance(num_logits, int):
|
|
87
52
|
num_sampled = min(self._num_hard_negatives + 1, num_logits)
|
|
88
53
|
else:
|
|
@@ -98,14 +63,14 @@ class HardNegativeMining(keras.layers.Layer):
|
|
|
98
63
|
# For each query, get the indices of the logits which have the highest
|
|
99
64
|
# k + 1 logit values, including the highest k negative logits and one
|
|
100
65
|
# true logit.
|
|
101
|
-
_,
|
|
66
|
+
_, indices = ops.top_k(
|
|
102
67
|
ops.add(logits, ops.multiply(labels, MAX_FLOAT)),
|
|
103
68
|
k=num_sampled,
|
|
104
69
|
sorted=False,
|
|
105
70
|
)
|
|
106
71
|
|
|
107
72
|
# Gather sampled logits and corresponding labels.
|
|
108
|
-
logits =
|
|
109
|
-
labels =
|
|
73
|
+
logits = ops.take_along_axis(logits, indices, axis=-1)
|
|
74
|
+
labels = ops.take_along_axis(labels, indices, axis=-1)
|
|
110
75
|
|
|
111
76
|
return logits, labels
|
|
File without changes
|
{keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/keras_rs/src/types.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{keras_rs_nightly-0.0.1.dev2025041603 → keras_rs_nightly-0.0.1.dev2025041803}/pyproject.toml
RENAMED
|
File without changes
|
|
File without changes
|