keras-rs-nightly 0.0.1.dev2025030803__tar.gz → 0.0.1.dev2025031003__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of keras-rs-nightly might be problematic. Click here for more details.
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/PKG-INFO +1 -1
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/keras_rs/api/__init__.py +1 -0
- keras_rs_nightly-0.0.1.dev2025031003/keras_rs/api/losses/__init__.py +14 -0
- keras_rs_nightly-0.0.1.dev2025031003/keras_rs/src/losses/pairwise_hinge_loss.py +26 -0
- keras_rs_nightly-0.0.1.dev2025031003/keras_rs/src/losses/pairwise_logistic_loss.py +35 -0
- keras_rs_nightly-0.0.1.dev2025031003/keras_rs/src/losses/pairwise_loss.py +134 -0
- keras_rs_nightly-0.0.1.dev2025031003/keras_rs/src/losses/pairwise_mean_squared_error.py +71 -0
- keras_rs_nightly-0.0.1.dev2025031003/keras_rs/src/losses/pairwise_soft_zero_one_loss.py +31 -0
- keras_rs_nightly-0.0.1.dev2025031003/keras_rs/src/utils/__init__.py +0 -0
- keras_rs_nightly-0.0.1.dev2025031003/keras_rs/src/utils/pairwise_loss_utils.py +100 -0
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/keras_rs_nightly.egg-info/PKG-INFO +1 -1
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/keras_rs_nightly.egg-info/SOURCES.txt +8 -0
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/README.md +0 -0
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/keras_rs/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/keras_rs/api/layers/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/keras_rs/src/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/keras_rs/src/api_export.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/keras_rs/src/layers/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/keras_rs/src/layers/modeling/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/keras_rs/src/layers/modeling/dot_interaction.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/keras_rs/src/layers/modeling/feature_cross.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/keras_rs/src/layers/retrieval/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/keras_rs/src/layers/retrieval/brute_force_retrieval.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/keras_rs/src/layers/retrieval/hard_negative_mining.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/keras_rs/src/layers/retrieval/sampling_probability_correction.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025030803/keras_rs/src/utils → keras_rs_nightly-0.0.1.dev2025031003/keras_rs/src/losses}/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/keras_rs/src/types.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/keras_rs/src/utils/keras_utils.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/keras_rs_nightly.egg-info/dependency_links.txt +0 -0
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/keras_rs_nightly.egg-info/requires.txt +0 -0
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/keras_rs_nightly.egg-info/top_level.txt +0 -0
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/pyproject.toml +0 -0
- {keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/setup.cfg +0 -0
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""DO NOT EDIT.
|
|
2
|
+
|
|
3
|
+
This file was autogenerated. Do not edit it by hand,
|
|
4
|
+
since your modifications would be overwritten.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from keras_rs.src.losses.pairwise_hinge_loss import PairwiseHingeLoss
|
|
8
|
+
from keras_rs.src.losses.pairwise_logistic_loss import PairwiseLogisticLoss
|
|
9
|
+
from keras_rs.src.losses.pairwise_mean_squared_error import (
|
|
10
|
+
PairwiseMeanSquaredError,
|
|
11
|
+
)
|
|
12
|
+
from keras_rs.src.losses.pairwise_soft_zero_one_loss import (
|
|
13
|
+
PairwiseSoftZeroOneLoss,
|
|
14
|
+
)
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from keras import ops
|
|
2
|
+
|
|
3
|
+
from keras_rs.src import types
|
|
4
|
+
from keras_rs.src.api_export import keras_rs_export
|
|
5
|
+
from keras_rs.src.losses.pairwise_loss import PairwiseLoss
|
|
6
|
+
from keras_rs.src.losses.pairwise_loss import pairwise_loss_subclass_doc_string
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@keras_rs_export("keras_rs.losses.PairwiseHingeLoss")
class PairwiseHingeLoss(PairwiseLoss):
    # The public class docstring is attached right after this class from the
    # shared pairwise-loss template, so none is declared here.

    def pairwise_loss(self, pairwise_logits: types.Tensor) -> types.Tensor:
        """Returns the hinge penalty `max(0, 1 - pairwise_logits)` per pair."""
        unit_margin = ops.array(1)
        # How far each score difference falls short of the unit margin;
        # negative shortfalls (margin satisfied) are clipped to zero by relu.
        shortfall = ops.subtract(unit_margin, pairwise_logits)
        return ops.relu(shortfall)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Build the public docstring of `PairwiseHingeLoss` by filling the shared
# pairwise-loss template with the hinge-specific formula and explanation.
# `extra_args` is empty because the class defines no `__init__` of its own.
formula = "loss = sum_{i} sum_{j} I(y_i > y_j) * max(0, 1 - (s_i - s_j))"
explanation = """
      - `max(0, 1 - (s_i - s_j))` is the hinge loss, which penalizes cases where
        the score difference `s_i - s_j` is not sufficiently large when
        `y_i > y_j`.
"""
extra_args = ""
PairwiseHingeLoss.__doc__ = pairwise_loss_subclass_doc_string.format(
    formula=formula,
    explanation=explanation,
    extra_args=extra_args,
)
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
from keras import ops
|
|
2
|
+
|
|
3
|
+
from keras_rs.src import types
|
|
4
|
+
from keras_rs.src.api_export import keras_rs_export
|
|
5
|
+
from keras_rs.src.losses.pairwise_loss import PairwiseLoss
|
|
6
|
+
from keras_rs.src.losses.pairwise_loss import pairwise_loss_subclass_doc_string
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@keras_rs_export("keras_rs.losses.PairwiseLogisticLoss")
class PairwiseLogisticLoss(PairwiseLoss):
    # The public class docstring is attached right after this class from the
    # shared pairwise-loss template, so none is declared here.

    def pairwise_loss(self, pairwise_logits: types.Tensor) -> types.Tensor:
        """Returns `log(1 + exp(-pairwise_logits))` per pair, computed stably.

        Uses the identity `log(1 + exp(-x)) = relu(-x) + log1p(exp(-|x|))`:
        - splitting off `relu(-x)` keeps `exp` from overflowing for very
          negative logits;
        - `log1p` keeps precision when `exp(-|x|)` is tiny, where
          `log(1 + tiny)` would round to exactly zero.
        """
        return ops.add(
            ops.relu(ops.negative(pairwise_logits)),
            ops.log1p(ops.exp(ops.negative(ops.abs(pairwise_logits)))),
        )
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Build the public docstring of `PairwiseLogisticLoss` by filling the shared
# pairwise-loss template with the logistic-specific formula and explanation.
# `extra_args` is empty because the class defines no `__init__` of its own.
formula = "loss = sum_{i} sum_{j} I(y_i > y_j) * log(1 + exp(-(s_i - s_j)))"
explanation = """
      - `log(1 + exp(-(s_i - s_j)))` is the logistic loss, which penalizes
        cases where the score difference `s_i - s_j` is not sufficiently large
        when `y_i > y_j`. This function provides a smooth approximation of the
        ideal step function, making it suitable for gradient-based optimization.
"""
extra_args = ""
PairwiseLogisticLoss.__doc__ = pairwise_loss_subclass_doc_string.format(
    formula=formula,
    explanation=explanation,
    extra_args=extra_args,
)
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
import keras
|
|
5
|
+
from keras import ops
|
|
6
|
+
|
|
7
|
+
from keras_rs.src import types
|
|
8
|
+
from keras_rs.src.utils.pairwise_loss_utils import pairwise_comparison
|
|
9
|
+
from keras_rs.src.utils.pairwise_loss_utils import process_loss_call_inputs
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PairwiseLoss(keras.losses.Loss, abc.ABC):
    """Base class for pairwise ranking losses.

    Pairwise loss functions are designed for ranking tasks, where the goal is to
    correctly order items within each list. Any pairwise loss function computes
    the loss value by comparing pairs of items within each list, penalizing
    cases where an item with a higher true label has a lower predicted score
    than an item with a lower true label.

    In order to implement any kind of pairwise loss, override the
    `pairwise_loss` method.
    """

    # TODO: Add `temperature`, `lambda_weights`.

    @abc.abstractmethod
    def pairwise_loss(self, pairwise_logits: types.Tensor) -> types.Tensor:
        """Maps pairwise score differences to per-pair loss values.

        Args:
            pairwise_logits: tensor whose entry `[..., i, j]` holds
                `s_i - s_j`, as built by `compute_unreduced_loss`.

        Returns:
            Per-pair losses. The result is multiplied by the pairwise label
            weights in `call`, so it is expected to be elementwise (same shape
            as `pairwise_logits`).
        """
        raise NotImplementedError(
            "All subclasses of `keras_rs.losses.pairwise_loss.PairwiseLoss` "
            "must implement the `pairwise_loss()` method."
        )

    def compute_unreduced_loss(
        self,
        labels: types.Tensor,
        logits: types.Tensor,
        mask: Optional[types.Tensor] = None,
    ) -> tuple[types.Tensor, types.Tensor]:
        """Computes per-pair losses and the weight of each pair.

        Args:
            labels: batched true labels, `(batch_size, list_size)` as produced
                by `process_loss_call_inputs`. Negative labels mark invalid
                items.
            logits: batched predicted scores, same shape as `labels`.
            mask: optional per-item mask; items where it is false are
                excluded from every pair.

        Returns:
            `(pairwise_losses, pairwise_weights)`. The weight of pair `(i, j)`
            is 1 when `labels_i > labels_j` and both items are valid, else 0.
        """
        # Mask all values less than 0 (since less than 0 implies invalid
        # labels).
        valid_mask = ops.greater_equal(labels, ops.cast(0.0, labels.dtype))

        if mask is not None:
            valid_mask = ops.logical_and(valid_mask, mask)

        pairwise_labels, pairwise_logits = pairwise_comparison(
            labels=labels,
            logits=logits,
            mask=valid_mask,
            logits_op=ops.subtract,
        )

        return self.pairwise_loss(pairwise_logits), pairwise_labels

    def call(
        self,
        y_true: types.Tensor,
        y_pred: types.Tensor,
    ) -> types.Tensor:
        """
        Args:
            y_true: tensor or dict. Ground truth values. If tensor, of shape
                `(list_size)` for unbatched inputs or `(batch_size, list_size)`
                for batched inputs. Items with a negative label (e.g. -1) are
                ignored in loss computation. If it is a dictionary, it should
                have the key `"labels"` and, optionally, `"mask"`. `"mask"` can
                be used to ignore elements in loss computation, i.e., pairs
                will not be formed with those items. Note that the final mask
                is the logical AND of the passed mask and `labels >= 0`.
            y_pred: tensor. The predicted values, of shape `(list_size)` for
                unbatched inputs or `(batch_size, list_size)` for batched
                inputs. Should be of the same shape as `y_true`.

        Returns:
            The per-item loss, `(batch_size, list_size)`: the weighted
            per-pair losses summed over the last pair axis.

        Raises:
            ValueError: if `y_true` is a dict without a `"labels"` key, or if
                the inputs fail the rank/shape checks.
        """
        mask = None
        if isinstance(y_true, dict):
            if "labels" not in y_true:
                raise ValueError(
                    '`"labels"` should be present in `y_true`. Received: '
                    f"`y_true` = {y_true}"
                )

            mask = y_true.get("mask", None)
            y_true = y_true["labels"]

        y_true, y_pred, mask = process_loss_call_inputs(y_true, y_pred, mask)

        losses, weights = self.compute_unreduced_loss(
            labels=y_true, logits=y_pred, mask=mask
        )
        # Zero out pairs that are not `y_i > y_j` or involve an invalid item,
        # then sum over `j` for each item `i`.
        losses = ops.multiply(losses, weights)
        losses = ops.sum(losses, axis=-1)
        return losses
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
# Shared docstring template for concrete pairwise losses. Subclass modules
# fill `{formula}`, `{explanation}` and `{extra_args}` via `str.format` and
# assign the result to the class `__doc__`. The summary line is deliberately
# loss-agnostic because every subclass reuses it.
pairwise_loss_subclass_doc_string = (
    "Computes a pairwise loss between true labels and predicted scores."
    """
    This loss function is designed for ranking tasks, where the goal is to
    correctly order items within each list. It computes the loss by comparing
    pairs of items within each list, penalizing cases where an item with a
    higher true label has a lower predicted score than an item with a lower
    true label.

    For each list of predicted scores `s` in `y_pred` and the corresponding list
    of true labels `y` in `y_true`, the loss is computed as follows:

    ```
    {formula}
    ```

    where:
      - `y_i` and `y_j` are the true labels of items `i` and `j`, respectively.
      - `s_i` and `s_j` are the predicted scores of items `i` and `j`,
        respectively.
      - `I(y_i > y_j)` is an indicator function that equals 1 if `y_i > y_j`,
        and 0 otherwise.{explanation}
    Args:{extra_args}
        reduction: Type of reduction to apply to the loss. In almost all cases
            this should be `"sum_over_batch_size"`. Supported options are
            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"`, `"none"` or `None`. `"sum"` sums the
            loss, `"sum_over_batch_size"` and `"mean"` sum the loss and divide
            by the sample size, and `"mean_with_sample_weight"` sums the loss
            and divides by the sum of the sample weights. `"none"` and `None`
            perform no aggregation. Defaults to `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`, which
            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a
            `"float32"` unless set to a different value
            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
            provided, then the `compute_dtype` will be utilized.
    """
)
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from keras import ops
|
|
4
|
+
|
|
5
|
+
from keras_rs.src import types
|
|
6
|
+
from keras_rs.src.api_export import keras_rs_export
|
|
7
|
+
from keras_rs.src.losses.pairwise_loss import PairwiseLoss
|
|
8
|
+
from keras_rs.src.losses.pairwise_loss import pairwise_loss_subclass_doc_string
|
|
9
|
+
from keras_rs.src.utils.pairwise_loss_utils import apply_pairwise_op
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@keras_rs_export("keras_rs.losses.PairwiseMeanSquaredError")
class PairwiseMeanSquaredError(PairwiseLoss):
    # The public class docstring is attached right after this class from the
    # shared pairwise-loss template, so none is declared here.

    def pairwise_loss(self, pairwise_logits: types.Tensor) -> types.Tensor:
        # Since we override `compute_unreduced_loss`, we do not need to
        # implement this method.
        pass

    def compute_unreduced_loss(
        self,
        labels: types.Tensor,
        logits: types.Tensor,
        mask: Optional[types.Tensor] = None,
    ) -> tuple[types.Tensor, types.Tensor]:
        """Computes squared pairwise-difference errors and their weights.

        Overrides `PairwiseLoss.compute_unreduced_loss`, since pairwise
        weights for MSE are computed differently: every pair of distinct valid
        items gets weight 1, not only pairs with `labels_i > labels_j`.

        Args:
            labels: batched true labels, `(batch_size, list_size)`. Negative
                labels mark invalid items.
            logits: batched predicted scores, same shape as `labels`.
            mask: optional per-item mask; items where it is false are
                excluded from every pair.

        Returns:
            `(pairwise_mse, pairwise_weights)`, both
            `(batch_size, list_size, list_size)`. Entry `[:, i, j]` of
            `pairwise_mse` is `((y_i - y_j) - (s_i - s_j))^2`.
        """
        list_size = ops.shape(labels)[-1]

        # Mask all values less than 0 (since less than 0 implies invalid
        # labels).
        valid_mask = ops.greater_equal(labels, ops.cast(0.0, labels.dtype))

        if mask is not None:
            valid_mask = ops.logical_and(valid_mask, mask)

        # Compute the difference for all pairs in a list. The output is a tensor
        # with shape `(batch_size, list_size, list_size)`, where `[:, i, j]`
        # stores information for pair `(i, j)`.
        pairwise_labels_diff = apply_pairwise_op(labels, ops.subtract)
        pairwise_logits_diff = apply_pairwise_op(logits, ops.subtract)
        valid_pair = apply_pairwise_op(valid_mask, ops.logical_and)
        pairwise_mse = ops.square(pairwise_labels_diff - pairwise_logits_diff)

        # Compute weights: 1 for every pair except self pairs (i == j). The
        # identity is created in the loss dtype (so the weights are not
        # silently upcast to floatx) and broadcasts over the batch axis
        # instead of being tiled to the full batch.
        pairwise_weights = ops.subtract(
            ops.ones_like(pairwise_mse),
            ops.eye(list_size, dtype=pairwise_mse.dtype),
        )
        # Include only valid pairs.
        pairwise_weights = ops.multiply(
            pairwise_weights, ops.cast(valid_pair, dtype=pairwise_weights.dtype)
        )

        return pairwise_mse, pairwise_weights
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# Build the public docstring of `PairwiseMeanSquaredError` from the shared
# pairwise-loss template. The formula mirrors `compute_unreduced_loss`, which
# squares the gap between label differences and score differences and
# weights every pair of distinct valid items equally.
formula = "loss = sum_{i} sum_{j} ((y_i - y_j) - (s_i - s_j))^2"
explanation = """
      - `(y_i - y_j) - (s_i - s_j)` is the mismatch between the true label
        difference and the predicted score difference of items `i` and `j`;
        squaring it penalizes predicted score gaps that deviate from the true
        label gaps. Unlike the other pairwise losses, every pair of distinct
        valid items contributes (self pairs `i == j` are excluded), so the
        indicator `I(y_i > y_j)` is not applied by this loss.
"""
extra_args = ""
PairwiseMeanSquaredError.__doc__ = pairwise_loss_subclass_doc_string.format(
    formula=formula,
    explanation=explanation,
    extra_args=extra_args,
)
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from keras import ops
|
|
2
|
+
|
|
3
|
+
from keras_rs.src import types
|
|
4
|
+
from keras_rs.src.api_export import keras_rs_export
|
|
5
|
+
from keras_rs.src.losses.pairwise_loss import PairwiseLoss
|
|
6
|
+
from keras_rs.src.losses.pairwise_loss import pairwise_loss_subclass_doc_string
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@keras_rs_export("keras_rs.losses.PairwiseSoftZeroOneLoss")
class PairwiseSoftZeroOneLoss(PairwiseLoss):
    # The public class docstring is attached right after this class from the
    # shared pairwise-loss template, so none is declared here.

    def pairwise_loss(self, pairwise_logits: types.Tensor) -> types.Tensor:
        """Returns `1 - sigmoid(pairwise_logits)` per pair.

        Evaluated as `sigmoid(-pairwise_logits)`, which is mathematically
        identical. The subtraction form cancels catastrophically for large
        positive logits (`sigmoid(x)` rounds to 1, giving exactly 0); the
        negated form stays accurate for every input. The previous
        `where(x > 0, 1 - sigmoid(x), sigmoid(-x))` picked the inaccurate
        branch exactly where it mattered.
        """
        return ops.sigmoid(ops.negative(pairwise_logits))
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Build the public docstring of `PairwiseSoftZeroOneLoss` by filling the shared
# pairwise-loss template with the soft zero-one formula and explanation.
# `extra_args` is empty because the class defines no `__init__` of its own.
formula = "loss = sum_{i} sum_{j} I(y_i > y_j) * (1 - sigmoid(s_i - s_j))"
explanation = """
      - `(1 - sigmoid(s_i - s_j))` represents the soft zero-one loss, which
        approximates the ideal zero-one loss (which would be 1 if `s_i < s_j`
        and 0 otherwise) with a smooth, differentiable function. This makes it
        suitable for gradient-based optimization.
"""
extra_args = ""
PairwiseSoftZeroOneLoss.__doc__ = pairwise_loss_subclass_doc_string.format(
    formula=formula,
    explanation=explanation,
    extra_args=extra_args,
)
|
|
File without changes
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
from typing import Callable, Optional
|
|
2
|
+
|
|
3
|
+
from keras import ops
|
|
4
|
+
|
|
5
|
+
from keras_rs.src import types
|
|
6
|
+
from keras_rs.src.utils.keras_utils import check_shapes_compatible
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def apply_pairwise_op(
    x: types.Tensor,
    op: Callable[[types.Tensor, types.Tensor], types.Tensor],
) -> types.Tensor:
    """Applies a binary op to every pair of entries along the last axis.

    For an input of shape `(..., n)` the result has shape `(..., n, n)`, and
    entry `[..., i, j]` is `op(x[..., i], x[..., j])`.

    Args:
        x: input tensor; pairs are formed along its last axis.
        op: a broadcasting binary op such as `ops.subtract` or
            `ops.logical_and`.

    Returns:
        The pairwise result tensor.
    """
    # `(..., n, 1)` op `(..., 1, n)` broadcasts to `(..., n, n)`, with the
    # left operand varying along `i` and the right operand along `j`.
    return op(
        ops.expand_dims(x, axis=-1),
        ops.expand_dims(x, axis=-2),
    )
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def pairwise_comparison(
    labels: types.Tensor,
    logits: types.Tensor,
    mask: Optional[types.Tensor],
    logits_op: Callable[[types.Tensor, types.Tensor], types.Tensor],
) -> tuple[types.Tensor, types.Tensor]:
    """Builds pairwise label indicators and pairwise logits.

    Args:
        labels: true labels; pairs are formed along the last axis (typically
            `(batch_size, list_size)`).
        logits: predicted scores, same shape as `labels`.
        mask: optional per-item validity mask. A pair is kept only when both
            of its items are valid. `None` keeps all pairs.
        logits_op: binary op applied to every `(logits_i, logits_j)` pair.

    Returns:
        `(pairwise_labels, pairwise_logits)`:
        - `pairwise_labels[..., i, j]` is 1 (in `labels.dtype`) when
          `labels_i > labels_j` and the pair is valid under `mask`, else 0.
        - `pairwise_logits[..., i, j]` is `logits_op(logits_i, logits_j)`.
    """
    # Compute the difference for all pairs in a list. The output is a tensor
    # with shape `(batch_size, list_size, list_size)`, where `[:, i, j]` stores
    # information for pair `(i, j)`.
    pairwise_labels_diff = apply_pairwise_op(labels, ops.subtract)
    pairwise_logits = apply_pairwise_op(logits, logits_op)

    # Keep only those cases where `l_i > l_j` (i.e. `l_i - l_j > 0`).
    pairwise_labels = ops.cast(
        ops.greater(pairwise_labels_diff, 0), dtype=labels.dtype
    )
    if mask is not None:
        valid_pairs = apply_pairwise_op(mask, ops.logical_and)
        pairwise_labels = ops.multiply(
            pairwise_labels, ops.cast(valid_pairs, dtype=pairwise_labels.dtype)
        )

    return pairwise_labels, pairwise_logits
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def process_loss_call_inputs(
    y_true: types.Tensor,
    y_pred: types.Tensor,
    mask: Optional[types.Tensor] = None,
) -> tuple[types.Tensor, types.Tensor, Optional[types.Tensor]]:
    """
    Utility function for processing inputs for pairwise losses.

    This utility function does three things:

    - Checks that `y_true`, `y_pred` (and `mask`, if given) are of rank 1 or 2;
    - Checks that `y_true`, `y_pred`, `mask` have the same shape;
    - Adds batch dimension if rank = 1.

    Args:
        y_true: true labels, rank 1 or 2.
        y_pred: predicted scores, same shape as `y_true`.
        mask: optional mask, same shape as `y_true`.

    Returns:
        `(y_true, y_pred, mask)`, each of rank 2 (`mask` stays `None` if it
        was not given).

    Raises:
        ValueError: if a rank is not 1 or 2, or if the shapes are not
            compatible.
    """

    def check_rank(
        x_rank: int,
        allowed_ranks: tuple[int, ...] = (1, 2),
        tensor_name: Optional[str] = None,
    ) -> None:
        if x_rank not in allowed_ranks:
            raise ValueError(
                f"`{tensor_name}` should have a rank from `{allowed_ranks}`. "
                f"Received: `{x_rank}`."
            )

    y_true_shape = ops.shape(y_true)
    y_true_rank = len(y_true_shape)
    y_pred_shape = ops.shape(y_pred)
    y_pred_rank = len(y_pred_shape)
    if mask is not None:
        mask_shape = ops.shape(mask)
        mask_rank = len(mask_shape)

    # Check ranks and shapes.
    check_rank(y_true_rank, tensor_name="y_true")
    check_rank(y_pred_rank, tensor_name="y_pred")
    if mask is not None:
        check_rank(mask_rank, tensor_name="mask")
    if not check_shapes_compatible(y_true_shape, y_pred_shape):
        raise ValueError(
            "`y_true` and `y_pred` should have the same shape. Received: "
            f"`y_true.shape` = {y_true_shape}, `y_pred.shape` = {y_pred_shape}."
        )
    if mask is not None and not check_shapes_compatible(
        y_true_shape, mask_shape
    ):
        raise ValueError(
            "`y_true['labels']` and `y_true['mask']` should have the same "
            f"shape. Received: `y_true['labels'].shape` = {y_true_shape}, "
            f"`y_true['mask'].shape` = {mask_shape}."
        )

    # Unbatched inputs: add a leading batch axis of size 1 so downstream
    # pairwise ops can always assume `(batch_size, list_size)`.
    if y_true_rank == 1:
        y_true = ops.expand_dims(y_true, axis=0)
        y_pred = ops.expand_dims(y_pred, axis=0)
        if mask is not None:
            mask = ops.expand_dims(mask, axis=0)

    return y_true, y_pred, mask
|
|
@@ -3,6 +3,7 @@ pyproject.toml
|
|
|
3
3
|
keras_rs/__init__.py
|
|
4
4
|
keras_rs/api/__init__.py
|
|
5
5
|
keras_rs/api/layers/__init__.py
|
|
6
|
+
keras_rs/api/losses/__init__.py
|
|
6
7
|
keras_rs/src/__init__.py
|
|
7
8
|
keras_rs/src/api_export.py
|
|
8
9
|
keras_rs/src/types.py
|
|
@@ -15,8 +16,15 @@ keras_rs/src/layers/retrieval/__init__.py
|
|
|
15
16
|
keras_rs/src/layers/retrieval/brute_force_retrieval.py
|
|
16
17
|
keras_rs/src/layers/retrieval/hard_negative_mining.py
|
|
17
18
|
keras_rs/src/layers/retrieval/sampling_probability_correction.py
|
|
19
|
+
keras_rs/src/losses/__init__.py
|
|
20
|
+
keras_rs/src/losses/pairwise_hinge_loss.py
|
|
21
|
+
keras_rs/src/losses/pairwise_logistic_loss.py
|
|
22
|
+
keras_rs/src/losses/pairwise_loss.py
|
|
23
|
+
keras_rs/src/losses/pairwise_mean_squared_error.py
|
|
24
|
+
keras_rs/src/losses/pairwise_soft_zero_one_loss.py
|
|
18
25
|
keras_rs/src/utils/__init__.py
|
|
19
26
|
keras_rs/src/utils/keras_utils.py
|
|
27
|
+
keras_rs/src/utils/pairwise_loss_utils.py
|
|
20
28
|
keras_rs_nightly.egg-info/PKG-INFO
|
|
21
29
|
keras_rs_nightly.egg-info/SOURCES.txt
|
|
22
30
|
keras_rs_nightly.egg-info/dependency_links.txt
|
|
File without changes
|
{keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/keras_rs/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/keras_rs/src/types.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{keras_rs_nightly-0.0.1.dev2025030803 → keras_rs_nightly-0.0.1.dev2025031003}/pyproject.toml
RENAMED
|
File without changes
|
|
File without changes
|