keras-rs-nightly 0.0.1.dev2025022603__tar.gz → 0.0.1.dev2025022803__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of keras-rs-nightly might be problematic. Click here for more details.
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/PKG-INFO +1 -1
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/keras_rs/api/layers/__init__.py +3 -0
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/keras_rs/src/layers/modeling/dot_interaction.py +1 -0
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/keras_rs/src/layers/modeling/feature_cross.py +1 -0
- keras_rs_nightly-0.0.1.dev2025022803/keras_rs/src/layers/retrieval/hard_negative_mining.py +111 -0
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/keras_rs_nightly.egg-info/PKG-INFO +1 -1
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/keras_rs_nightly.egg-info/SOURCES.txt +1 -0
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/README.md +0 -0
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/keras_rs/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/keras_rs/api/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/keras_rs/src/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/keras_rs/src/api_export.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/keras_rs/src/layers/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/keras_rs/src/layers/modeling/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/keras_rs/src/layers/retrieval/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/keras_rs/src/layers/retrieval/brute_force_retrieval.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/keras_rs/src/types.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/keras_rs/src/utils/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/keras_rs/src/utils/keras_utils.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/keras_rs_nightly.egg-info/dependency_links.txt +0 -0
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/keras_rs_nightly.egg-info/requires.txt +0 -0
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/keras_rs_nightly.egg-info/top_level.txt +0 -0
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/pyproject.toml +0 -0
- {keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/setup.cfg +0 -0
|
@@ -28,6 +28,7 @@ class DotInteraction(keras.layers.Layer):
|
|
|
28
28
|
entries will be zeros. Otherwise, the output will be only the lower
|
|
29
29
|
triangular part of the interaction matrix. The latter saves space
|
|
30
30
|
but is much slower.
|
|
31
|
+
**kwargs: Args to pass to the base class.
|
|
31
32
|
|
|
32
33
|
References:
|
|
33
34
|
- [M. Naumov et al.](https://arxiv.org/abs/1906.00091)
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
import keras
|
|
4
|
+
import numpy as np
|
|
5
|
+
from keras import ops
|
|
6
|
+
|
|
7
|
+
from keras_rs.src import types
|
|
8
|
+
from keras_rs.src.api_export import keras_rs_export
|
|
9
|
+
|
|
10
|
+
# Large offset that `HardNegativeMining.call` multiplies by `labels` and adds to
# `logits`, so that positive candidates outrank every negative in `top_k`.
# Presumably divided by 100 to leave headroom below the float32 maximum when a
# real logit is added on top — TODO confirm. Note it is not representable in
# float16 (it overflows to inf there).
MAX_FLOAT = np.finfo(np.float32).max / 100.0
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _gather_elements_along_row(
    data: types.Tensor, column_indices: types.Tensor
) -> types.Tensor:
    """Gathers elements from a 2D tensor given the column indices of each row.

    First, gets the flat 1D indices to gather from. Then flattens the data to 1D
    and uses `ops.take()` to generate 1D output and finally reshapes the output
    back to 2D.

    Args:
        data: A [N, M] 2D `Tensor`.
        column_indices: A [N, K] 2D `Tensor` denoting for each row, the K column
            indices to gather elements from the data `Tensor`.

    Returns:
        A [N, K] `Tensor` including output elements gathered from data `Tensor`.

    Raises:
        ValueError: if the first dimensions of data and column_indices don't
            match. The check is only performed when both dimensions are
            statically known (Python ints); dynamic dimensions cannot be
            compared here.
    """
    # `data` is expected to be rank 2 (see Args); any trailing dimensions are
    # not accounted for by the flat-index arithmetic below.
    num_row, num_column, *_ = ops.shape(data)
    indices_shape = ops.shape(column_indices)
    num_gathered = indices_shape[1]

    # Enforce the documented contract when shapes are known ahead of time.
    # Without this, mismatched inputs either fail deep inside `reshape` or
    # silently gather from the wrong rows.
    num_indices_row = indices_shape[0]
    if (
        isinstance(num_row, int)
        and isinstance(num_indices_row, int)
        and num_row != num_indices_row
    ):
        raise ValueError(
            "The first dimensions of `data` and `column_indices` must match. "
            f"Received: data.shape={tuple(ops.shape(data))}, "
            f"column_indices.shape={tuple(indices_shape)}"
        )

    # Row index of every element of `column_indices`, shape [N, K].
    row_indices = ops.tile(
        ops.expand_dims(ops.arange(num_row), -1), [1, num_gathered]
    )
    flat_data = ops.reshape(data, [-1])
    # In the row-major flattened `data`, element (row, col) lives at offset
    # `row * M + col`.
    flat_indices = ops.reshape(
        ops.add(ops.multiply(row_indices, num_column), column_indices), [-1]
    )
    return ops.reshape(
        ops.take(flat_data, flat_indices), [num_row, num_gathered]
    )
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@keras_rs_export("keras_rs.layers.HardNegativeMining")
class HardNegativeMining(keras.layers.Layer):
    """Transforms logits and labels to return hard negatives.

    For each query (row), keeps the positive candidate together with the
    `num_hard_negatives` negative candidates that have the highest logits, and
    drops all other candidates.

    Args:
        num_hard_negatives: How many hard negatives to return. Must be
            non-negative.
        **kwargs: Args to pass to the base class.

    Raises:
        ValueError: if `num_hard_negatives` is negative.
    """

    def __init__(self, num_hard_negatives: int, **kwargs: Any) -> None:
        if num_hard_negatives < 0:
            raise ValueError(
                "`num_hard_negatives` must be non-negative. Received: "
                f"num_hard_negatives={num_hard_negatives}"
            )
        super().__init__(**kwargs)
        self._num_hard_negatives = num_hard_negatives
        # The layer creates no weights, so it is marked as built immediately.
        self.built = True

    def call(
        self, logits: types.Tensor, labels: types.Tensor
    ) -> tuple[types.Tensor, types.Tensor]:
        """Filters logits and labels with per-query hard negative mining.

        The result will include logits and labels for `num_hard_negatives`
        negatives as well as the positive candidate. When a query has fewer
        than `num_hard_negatives + 1` candidates, all candidates are kept.
        The selected candidates are gathered with an unsorted top-k, so their
        column order is unspecified; in particular the positive candidate is
        not guaranteed to be in the first column.

        Args:
            logits: `[batch_size, number_of_candidates]` tensor of logits.
            labels: `[batch_size, number_of_candidates]` one-hot tensor of
                labels.

        Returns:
            tuple containing:
            - logits: `[batch_size, num_sampled]` tensor of logits.
            - labels: `[batch_size, num_sampled]` one-hot tensor of labels.
            where `num_sampled = min(num_hard_negatives + 1,
            number_of_candidates)`. Both keep their input dtypes.
        """

        # Number of sampled logits, i.e. the number of hard negatives to be
        # sampled (k) + the true logit (1) per query, capped by the number of
        # candidates per query.
        num_logits = ops.shape(logits)[1]
        if isinstance(num_logits, int):
            num_sampled = min(self._num_hard_negatives + 1, num_logits)
        else:
            # Dynamic candidate count: compute the cap as a tensor.
            num_sampled = ops.minimum(self._num_hard_negatives + 1, num_logits)

        # To gather indices of top k negative logits per row (query) in logits,
        # true logits need to be excluded. First replace the true logits
        # (corresponding to positive labels) with a large score value and then
        # select the top k + 1 logits from each row so that selected indices
        # include the indices of true logit + top k negative logits. This
        # approach is to avoid using inefficient masking when excluding true
        # logits.
        #
        # The ranking scores are computed in a dtype at least as wide as
        # float32. `MAX_FLOAT` is not representable in float16: it becomes
        # `inf`, and `0 * inf` produces NaN for every negative candidate,
        # which would corrupt the top-k selection when the inputs are float16
        # (e.g. under a `mixed_float16` policy). For float32/float64 logits
        # this is identical to adding `labels * MAX_FLOAT` directly. Only the
        # scores are upcast; the gathered outputs keep their original dtypes.
        scores_dtype = keras.backend.result_type(logits.dtype, "float32")
        scores = ops.add(
            ops.cast(logits, scores_dtype),
            ops.multiply(ops.cast(labels, scores_dtype), MAX_FLOAT),
        )

        # For each query, get the indices of the logits which have the highest
        # k + 1 scores, i.e. the highest k negative logits and one true logit.
        _, col_indices = ops.top_k(scores, k=num_sampled, sorted=False)

        # Gather sampled logits and corresponding labels.
        logits = _gather_elements_along_row(logits, col_indices)
        labels = _gather_elements_along_row(labels, col_indices)

        return logits, labels

    def get_config(self) -> dict[str, Any]:
        """Returns the layer config, including `num_hard_negatives`.

        The constructor requires `num_hard_negatives`, so it must be part of
        the config for `from_config` to recreate the layer.
        """
        config: dict[str, Any] = super().get_config()
        config.update({"num_hard_negatives": self._num_hard_negatives})
        return config
|
|
@@ -13,6 +13,7 @@ keras_rs/src/layers/modeling/dot_interaction.py
|
|
|
13
13
|
keras_rs/src/layers/modeling/feature_cross.py
|
|
14
14
|
keras_rs/src/layers/retrieval/__init__.py
|
|
15
15
|
keras_rs/src/layers/retrieval/brute_force_retrieval.py
|
|
16
|
+
keras_rs/src/layers/retrieval/hard_negative_mining.py
|
|
16
17
|
keras_rs/src/utils/__init__.py
|
|
17
18
|
keras_rs/src/utils/keras_utils.py
|
|
18
19
|
keras_rs_nightly.egg-info/PKG-INFO
|
|
File without changes
|
{keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/keras_rs/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/keras_rs/src/types.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{keras_rs_nightly-0.0.1.dev2025022603 → keras_rs_nightly-0.0.1.dev2025022803}/pyproject.toml
RENAMED
|
File without changes
|
|
File without changes
|