keras-rs-nightly 0.2.2.dev202508190331__py3-none-any.whl → 0.4.1.dev202601250348__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_rs/losses/__init__.py +1 -0
- keras_rs/src/layers/embedding/base_distributed_embedding.py +19 -10
- keras_rs/src/layers/embedding/distributed_embedding_config.py +2 -2
- keras_rs/src/layers/embedding/jax/distributed_embedding.py +133 -201
- keras_rs/src/layers/embedding/jax/embedding_lookup.py +25 -4
- keras_rs/src/layers/embedding/jax/embedding_utils.py +22 -401
- keras_rs/src/layers/embedding/tensorflow/config_conversion.py +26 -19
- keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +22 -5
- keras_rs/src/losses/list_mle_loss.py +212 -0
- keras_rs/src/metrics/ranking_metrics_utils.py +21 -2
- keras_rs/src/utils/tpu_test_utils.py +120 -0
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.2.2.dev202508190331.dist-info → keras_rs_nightly-0.4.1.dev202601250348.dist-info}/METADATA +4 -3
- {keras_rs_nightly-0.2.2.dev202508190331.dist-info → keras_rs_nightly-0.4.1.dev202601250348.dist-info}/RECORD +16 -14
- {keras_rs_nightly-0.2.2.dev202508190331.dist-info → keras_rs_nightly-0.4.1.dev202601250348.dist-info}/WHEEL +1 -1
- {keras_rs_nightly-0.2.2.dev202508190331.dist-info → keras_rs_nightly-0.4.1.dev202601250348.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
import keras
|
|
4
|
+
from keras import ops
|
|
5
|
+
|
|
6
|
+
from keras_rs.src import types
|
|
7
|
+
from keras_rs.src.api_export import keras_rs_export
|
|
8
|
+
from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
|
|
9
|
+
from keras_rs.src.metrics.utils import standardize_call_inputs_ranks
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@keras_rs_export("keras_rs.losses.ListMLELoss")
class ListMLELoss(keras.losses.Loss):
    """Implements ListMLE (Maximum Likelihood Estimation) loss for ranking.

    ListMLE is a listwise ranking loss. For every list it:

    1. orders the items by their relevance labels (via `sort_by_scores`),
    2. computes the probability of observing that ordering given the
       temperature-scaled predicted scores, and
    3. returns the negative log of that probability.

    With `s_i` the predicted score of the item at position `i` of the
    label-sorted order and `T` the temperature, the per-list loss is:

    ```
    loss = -sum_i log(exp(s_i / T) / sum_{j >= i} exp(s_j / T))
    ```

    Only valid items enter the sums. An item is valid when its label is
    non-negative and, if a mask is supplied, its mask entry is true. A list
    with no valid item gets a loss of 0.

    Args:
        temperature: Positive float that divides the predicted scores before
            normalization. Higher values make the probability distribution
            more uniform. Defaults to `1.0`.
        reduction: Type of reduction to apply to the loss. In almost all
            cases this should be `"sum_over_batch_size"`. Supported options
            are `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. Defaults to
            `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`.

    Examples:
    ```python
    # Basic usage
    loss_fn = ListMLELoss()

    # With temperature scaling
    loss_fn = ListMLELoss(temperature=0.5)

    # Example with synthetic data
    y_true = [[3, 2, 1, 0]]  # Relevance labels
    y_pred = [[0.8, 0.6, 0.4, 0.2]]  # Predicted scores
    loss = loss_fn(y_true, y_pred)
    ```
    """

    def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None:
        super().__init__(**kwargs)

        # `<=` (rather than `not >`) is kept so validation is unchanged.
        if temperature <= 0.0:
            raise ValueError(
                f"`temperature` should be a positive float. Received: "
                f"`temperature` = {temperature}."
            )

        self.temperature = temperature
        # Added inside the log so an all-zero tail sum never yields log(0).
        self._epsilon = 1e-10

    @staticmethod
    def _fill_invalid(
        valid: types.Tensor, values: types.Tensor
    ) -> types.Tensor:
        """Return `values` with every entry where `valid` is false set to
        -1e9."""
        return ops.where(valid, values, ops.full_like(values, -1e9))

    @staticmethod
    def _reverse_cumsum(values: types.Tensor) -> types.Tensor:
        """Cumulative sum along axis 1, accumulated right-to-left."""
        flipped = ops.flip(values, axis=1)
        return ops.flip(ops.cumsum(flipped, axis=1), axis=1)

    def compute_unreduced_loss(
        self,
        labels: types.Tensor,
        logits: types.Tensor,
        mask: types.Tensor | None = None,
    ) -> tuple[types.Tensor, types.Tensor]:
        """Compute the unreduced ListMLE loss, one value per list.

        Args:
            labels: Ground-truth relevance labels of shape
                `[batch_size, list_size]`. Negative labels mark items to
                ignore.
            logits: Predicted scores of shape `[batch_size, list_size]`.
            mask: Optional mask of shape `[batch_size, list_size]`; entries
                that cast to `False` are ignored.

        Returns:
            Tuple `(losses, weights)`, both of shape `[batch_size, 1]`;
            `weights` is all ones.
        """
        is_valid = ops.greater_equal(labels, ops.cast(0.0, labels.dtype))
        if mask is not None:
            is_valid = ops.logical_and(is_valid, ops.cast(mask, dtype="bool"))

        valid_count = ops.sum(
            ops.cast(is_valid, dtype=labels.dtype), axis=1, keepdims=True
        )
        row_has_items = ops.greater(valid_count, 0.0)

        # Invalid items get a huge negative label (and score). Assuming
        # `sort_by_scores` orders by descending score, they end up after
        # every valid item.
        scores, is_valid_sorted = sort_by_scores(
            tensors_to_sort=[self._fill_invalid(is_valid, logits), is_valid],
            scores=self._fill_invalid(is_valid, labels),
            mask=None,
            shuffle_ties=False,
            seed=None,
        )
        temperature = ops.cast(self.temperature, dtype=scores.dtype)
        scores = ops.divide(scores, temperature)

        # Shift by the row-wise max over valid items for numerical
        # stability; rows without valid items are shifted by zero.
        row_max = ops.max(
            self._fill_invalid(is_valid_sorted, scores), axis=1, keepdims=True
        )
        row_max = ops.where(row_has_items, row_max, ops.zeros_like(row_max))
        # Re-fill invalid positions after the shift so exp() maps them to 0.
        scores = self._fill_invalid(
            is_valid_sorted, ops.subtract(scores, row_max)
        )

        # tail_sums[:, i] == sum_{j >= i} exp(scores[:, j])
        tail_sums = self._reverse_cumsum(ops.exp(scores))
        # NOTE(review): if a valid score lies so far below the row max that
        # exp() underflows to 0 for its whole tail (e.g. the last valid
        # item), epsilon dominates the log and that log-prob is far off;
        # a log-cumsum-exp formulation would avoid this.
        log_normalizers = ops.log(tail_sums + self._epsilon)
        log_probs = ops.subtract(scores, log_normalizers)
        log_probs = ops.where(
            is_valid_sorted, log_probs, ops.zeros_like(log_probs)
        )

        nll = ops.negative(ops.sum(log_probs, axis=1, keepdims=True))
        nll = ops.where(row_has_items, nll, ops.zeros_like(nll))
        return nll, ops.ones_like(nll)

    def call(
        self,
        y_true: types.Tensor,
        y_pred: types.Tensor,
    ) -> types.Tensor:
        """Compute the ListMLE loss for each list.

        Args:
            y_true: tensor or dict. Ground truth values. If a tensor, of
                shape `(list_size)` for unbatched inputs or
                `(batch_size, list_size)` for batched inputs; items with a
                negative label (e.g. `-1`) are ignored. If a dict, it must
                have a `"labels"` key and may have a `"mask"` key whose
                false entries are ignored in the loss computation.
            y_pred: tensor. The predicted values, of shape `(list_size)` for
                unbatched inputs or `(batch_size, list_size)` for batched
                inputs. Should be of the same shape as `y_true`.

        Returns:
            The loss tensor of shape `[batch_size]`.

        Raises:
            ValueError: If `y_true` is a dict without a `"labels"` key.
        """
        mask = None
        if isinstance(y_true, dict):
            if "labels" not in y_true:
                raise ValueError(
                    '`"labels"` should be present in `y_true`. Received: '
                    f"`y_true` = {y_true}"
                )
            mask = y_true.get("mask", None)
            y_true = y_true["labels"]

        labels = ops.convert_to_tensor(y_true)
        scores = ops.convert_to_tensor(y_pred)
        if mask is not None:
            mask = ops.convert_to_tensor(mask)

        labels, scores, mask, _ = standardize_call_inputs_ranks(
            labels, scores, mask
        )

        per_list_loss, per_list_weight = self.compute_unreduced_loss(
            labels=labels, logits=scores, mask=mask
        )
        weighted = ops.multiply(per_list_loss, per_list_weight)
        return ops.squeeze(weighted, axis=-1)

    def get_config(self) -> dict[str, Any]:
        """Return the base loss config plus `temperature`."""
        config: dict[str, Any] = super().get_config()
        config["temperature"] = self.temperature
        return config
|
|
@@ -85,6 +85,25 @@ def sort_by_scores(
|
|
|
85
85
|
else:
|
|
86
86
|
k = ops.minimum(k, max_possible_k)
|
|
87
87
|
|
|
88
|
+
# --- Work around for PyTorch instability ---
|
|
89
|
+
# Torch's `topk` is not stable with `sorted=True`, unlike JAX and TF.
|
|
90
|
+
# See:
|
|
91
|
+
# - https://github.com/pytorch/pytorch/issues/27542
|
|
92
|
+
# - https://github.com/pytorch/pytorch/issues/88227
|
|
93
|
+
#
|
|
94
|
+
# This small "stable offset" ensures deterministic tie-breaking for
|
|
95
|
+
# equal scores. We can remove this workaround once PyTorch adds a
|
|
96
|
+
# `stable=True` flag for topk.
|
|
97
|
+
|
|
98
|
+
if keras.backend.backend() == "torch" and not shuffle_ties:
|
|
99
|
+
list_size = ops.shape(scores)[1]
|
|
100
|
+
indices = ops.arange(list_size)
|
|
101
|
+
indices = ops.expand_dims(indices, axis=0)
|
|
102
|
+
indices = ops.broadcast_to(indices, ops.shape(scores))
|
|
103
|
+
stable_offset = ops.cast(indices, scores.dtype) * 1e-6
|
|
104
|
+
scores = ops.subtract(scores, stable_offset)
|
|
105
|
+
# --- End FIX ---
|
|
106
|
+
|
|
88
107
|
# Shuffle ties randomly, and push masked values to the beginning.
|
|
89
108
|
shuffled_indices = None
|
|
90
109
|
if shuffle_ties or mask is not None:
|
|
@@ -205,12 +224,12 @@ def get_list_weights(
|
|
|
205
224
|
return final_weights
|
|
206
225
|
|
|
207
226
|
|
|
208
|
-
@keras.saving.register_keras_serializable() # type: ignore[
|
|
227
|
+
@keras.saving.register_keras_serializable()  # type: ignore[untyped-decorator]
def default_gain_fn(label: types.Tensor) -> types.Tensor:
    """Return the exponential gain `2**label - 1` for each label."""
    exponential = ops.power(2.0, label)
    return ops.subtract(exponential, 1.0)
|
|
211
230
|
|
|
212
231
|
|
|
213
|
-
@keras.saving.register_keras_serializable() # type: ignore[
|
|
232
|
+
@keras.saving.register_keras_serializable() # type: ignore[untyped-decorator]
|
|
214
233
|
def default_rank_discount_fn(rank: types.Tensor) -> types.Tensor:
|
|
215
234
|
return ops.divide(
|
|
216
235
|
ops.cast(1, dtype=rank.dtype),
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import contextlib
|
|
2
|
+
import os
|
|
3
|
+
import threading
|
|
4
|
+
from typing import Any, Callable, ContextManager, Optional, Tuple, Union
|
|
5
|
+
|
|
6
|
+
import keras
|
|
7
|
+
import tensorflow as tf
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class DummyStrategy:
    """Single-replica stand-in for a `tf.distribute.Strategy`.

    Used when no TPU is configured. It provides the parts of the strategy
    API this module relies on (`scope`, `num_replicas_in_sync`, `run`,
    `experimental_distribute_dataset`) with pass-through behavior.
    """

    def scope(self) -> ContextManager[None]:
        """Return a no-op context manager."""
        return contextlib.nullcontext()

    @property
    def num_replicas_in_sync(self) -> int:
        """Always 1: everything runs on a single replica."""
        return 1

    def run(
        self,
        fn: Callable[..., Any],
        args: Tuple[Any, ...] = (),
        kwargs: Optional[dict[str, Any]] = None,
    ) -> Any:
        """Call `fn(*args, **kwargs)` directly.

        Mirrors the `args`/`kwargs` parameters of
        `tf.distribute.Strategy.run`. Callers such as `run_with_strategy`
        forward keyword arguments as `kwargs=...`, which previously raised
        a `TypeError` here.
        """
        return fn(*args, **(kwargs or {}))

    def experimental_distribute_dataset(
        self, dataset: Any, options: Optional[Any] = None
    ) -> Any:
        """Return `dataset` unchanged; `options` is ignored."""
        del options
        return dataset
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class JaxDummyStrategy(DummyStrategy):
    """`DummyStrategy` variant for the JAX backend.

    Identical to `DummyStrategy` except that the replica count reports the
    number of TPU devices visible to JAX.
    """

    @property
    def num_replicas_in_sync(self) -> Any:
        """Number of TPU devices, as given by `jax.device_count("tpu")`."""
        # Imported here so this module does not need JAX at import time.
        from jax import device_count

        return device_count("tpu")
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# Guards the lazy, one-time creation of `_shared_strategy`.
_lock = threading.Lock()

# Any strategy object handed out by this module: a real `tf.distribute`
# strategy or a `DummyStrategy` (including its subclasses).
StrategyType = Union[tf.distribute.Strategy, DummyStrategy]

# Process-wide cached strategy; `None` until first requested via
# `get_shared_tpu_strategy()`.
_shared_strategy: Optional[StrategyType] = None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def create_tpu_strategy() -> Optional[StrategyType]:
    """Initialize the TPU system and return a `tf.distribute.TPUStrategy`.

    This is a best-effort helper. Any exception raised while resolving,
    connecting to, or initializing the TPU is printed, and `None` is
    returned instead of raising.

    Returns:
        The newly created `TPUStrategy`, or `None` if creation failed.
    """
    print("Attempting to create TPUStrategy...")
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)
    except Exception as e:  # Best-effort: report and fall back to None.
        print(f"Error creating TPUStrategy: {e}")
        return None

    # Must be an f-string so the actual replica count is printed, and the
    # two sentences need a separating space.
    print(
        "TPUStrategy created successfully. "
        f"Devices: {strategy.num_replicas_in_sync}"
    )
    return strategy
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def get_shared_tpu_strategy() -> Optional[StrategyType]:
    """Return the strategy shared by the whole process, creating it lazily.

    The first call chooses the strategy:

    * `DummyStrategy` when the `TPU_NAME` environment variable is unset;
    * otherwise, on the TensorFlow backend, a `tf.distribute.TPUStrategy`
      whose device assignment has one replica per TPU host;
    * otherwise, on the JAX backend, a `JaxDummyStrategy`;
    * otherwise a `DummyStrategy`.

    Later calls return the cached object. Creation happens under a lock,
    and a lock-free fast path is used once the strategy exists.

    Returns:
        The shared strategy object. Errors raised while initializing the
        TPU propagate to the caller.
    """
    global _shared_strategy

    # Fast path: already created, no need to take the lock.
    if _shared_strategy is not None:
        return _shared_strategy

    with _lock:
        # Re-check under the lock: another thread may have created it.
        if _shared_strategy is not None:
            return _shared_strategy

        if "TPU_NAME" not in os.environ:
            _shared_strategy = DummyStrategy()
            return _shared_strategy

        backend = keras.backend.backend()
        if backend == "tensorflow":
            resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
            tf.config.experimental_connect_to_cluster(resolver)
            topology = tf.tpu.experimental.initialize_tpu_system(resolver)
            num_hosts = resolver.get_tpu_system_metadata().num_hosts
            assignment = tf.tpu.experimental.DeviceAssignment.build(
                topology, num_replicas=num_hosts
            )
            _shared_strategy = tf.distribute.TPUStrategy(
                resolver, experimental_device_assignment=assignment
            )
            print("### num_replicas", _shared_strategy.num_replicas_in_sync)
        elif backend == "jax":
            _shared_strategy = JaxDummyStrategy()
            print("### num_replicas", _shared_strategy.num_replicas_in_sync)
        else:
            _shared_strategy = DummyStrategy()
        return _shared_strategy
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def run_with_strategy(
    strategy: Any,
    fn: Callable[..., Any],
    *args: Any,
    jit_compile: bool = False,
    **kwargs: Any,
) -> Any:
    """Run `fn(*args, **kwargs)` through `strategy` on the active backend.

    On the TensorFlow backend, the positional and keyword arguments are
    packed into a single `(args, kwargs)` tuple. That tuple is passed to a
    `tf.function`, which dispatches `fn` via `strategy.run`. On every other
    backend `strategy` is not used and `fn` is called directly.

    Args:
        strategy: Object exposing `run(fn, args=..., kwargs=...)`, such as
            a `tf.distribute.Strategy`. Only used on the TensorFlow backend.
        fn: The function to run.
        *args: Positional arguments for `fn`.
        jit_compile: Forwarded to `tf.function(jit_compile=...)`. Only
            supported on the TensorFlow backend.
        **kwargs: Keyword arguments for `fn`.

    Returns:
        The result of `strategy.run` (TensorFlow) or of `fn` (other
        backends).

    Raises:
        ValueError: If `jit_compile=True` on a non-TensorFlow backend.
    """
    backend = keras.backend.backend()
    if backend != "tensorflow":
        # Explicit check instead of `assert`, which is stripped under -O.
        if jit_compile:
            raise ValueError(
                "`jit_compile=True` is only supported with the TensorFlow "
                f"backend. Current backend: {backend}."
            )
        return fn(*args, **kwargs)

    # Note: a new `tf.function` is created, and therefore traced, on every
    # call to `run_with_strategy`.
    @tf.function(jit_compile=jit_compile)  # type: ignore[untyped-decorator]
    def tf_function_wrapper(input_tuple: Tuple[Any, Any]) -> Any:
        core_args, core_kwargs = input_tuple
        if core_kwargs:
            return strategy.run(fn, args=core_args, kwargs=core_kwargs)
        return strategy.run(fn, args=core_args)

    return tf_function_wrapper((args, kwargs))
|
keras_rs/src/version.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: keras-rs-nightly
|
|
3
|
-
Version: 0.2.2.dev202508190331
|
|
3
|
+
Version: 0.4.1.dev202601250348
|
|
4
4
|
Summary: Multi-backend recommender systems with Keras 3.
|
|
5
5
|
Author-email: Keras team <keras-users@googlegroups.com>
|
|
6
6
|
License: Apache License 2.0
|
|
@@ -8,8 +8,9 @@ Project-URL: Home, https://keras.io/keras_rs
|
|
|
8
8
|
Project-URL: Repository, https://github.com/keras-team/keras-rs
|
|
9
9
|
Classifier: Development Status :: 3 - Alpha
|
|
10
10
|
Classifier: Programming Language :: Python :: 3
|
|
11
|
-
Classifier: Programming Language :: Python :: 3.10
|
|
12
11
|
Classifier: Programming Language :: Python :: 3.11
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
13
14
|
Classifier: Programming Language :: Python :: 3 :: Only
|
|
14
15
|
Classifier: Operating System :: Unix
|
|
15
16
|
Classifier: Operating System :: Microsoft :: Windows
|
|
@@ -17,7 +18,7 @@ Classifier: Operating System :: MacOS
|
|
|
17
18
|
Classifier: Intended Audience :: Science/Research
|
|
18
19
|
Classifier: Topic :: Scientific/Engineering
|
|
19
20
|
Classifier: Topic :: Software Development
|
|
20
|
-
Requires-Python: >=3.
|
|
21
|
+
Requires-Python: >=3.11
|
|
21
22
|
Description-Content-Type: text/markdown
|
|
22
23
|
Requires-Dist: keras
|
|
23
24
|
Requires-Dist: ml-dtypes
|
|
@@ -1,26 +1,26 @@
|
|
|
1
1
|
keras_rs/__init__.py,sha256=8sjHiPN2GhUqAq4V7Vh4FLLqYw20-jgdI26ZKX5sg6M,350
|
|
2
2
|
keras_rs/layers/__init__.py,sha256=ERqFu1R8FgeES5rO5QwauArbNCm8auj-AiCURtsG6Ro,1332
|
|
3
|
-
keras_rs/losses/__init__.py,sha256=
|
|
3
|
+
keras_rs/losses/__init__.py,sha256=WyyrxhWrayt-Hm6gSmZ5CPZifbPx0egDIothGi0Dpjk,646
|
|
4
4
|
keras_rs/metrics/__init__.py,sha256=Qxpf6OFooIL9TIn2l3WgOea3HFRG0hq02glPAxtMZ9c,580
|
|
5
5
|
keras_rs/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
6
6
|
keras_rs/src/api_export.py,sha256=RsmG-DvO-cdFeAF9W6LRzms0kvtm-Yp9BAA_d-952zI,510
|
|
7
7
|
keras_rs/src/types.py,sha256=1A-oLRdX1-f2DsVZBcNl8qNsaH8pM-gnleLT9FWZWBw,1189
|
|
8
|
-
keras_rs/src/version.py,sha256=
|
|
8
|
+
keras_rs/src/version.py,sha256=7F19b6JBtXTYKOxUk6K2Y_nS6JGidlVq305CdH3935o,224
|
|
9
9
|
keras_rs/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
10
|
keras_rs/src/layers/embedding/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
|
-
keras_rs/src/layers/embedding/base_distributed_embedding.py,sha256=
|
|
11
|
+
keras_rs/src/layers/embedding/base_distributed_embedding.py,sha256=REpNKTvkS4eoAi9t1DohnDMfWwMgfxDN4ByWji3aALM,45906
|
|
12
12
|
keras_rs/src/layers/embedding/distributed_embedding.py,sha256=94jxUHoGK3Gs9yfV0KxFTuqPo7XFnhgCNlO2FEeiSgM,1072
|
|
13
|
-
keras_rs/src/layers/embedding/distributed_embedding_config.py,sha256=
|
|
13
|
+
keras_rs/src/layers/embedding/distributed_embedding_config.py,sha256=L41x6W1xcXI-3m94nOh_OsHn6OYpoynakKJvNboJnvE,5762
|
|
14
14
|
keras_rs/src/layers/embedding/embed_reduce.py,sha256=c-MnEw1-KWs0jTf0JJ_ZBOY-9hRkiFyu989Dof3DnS8,12343
|
|
15
15
|
keras_rs/src/layers/embedding/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
16
16
|
keras_rs/src/layers/embedding/jax/checkpoint_utils.py,sha256=wZ4I5WZVNg5WnrD2j7nhAXgLzDc7xMrUEkSAOx5Sz5c,3495
|
|
17
17
|
keras_rs/src/layers/embedding/jax/config_conversion.py,sha256=Di1UzRwLgGHd7RuWYJMj2mCOr1u9MseFEWaYKnwD9Bs,16742
|
|
18
|
-
keras_rs/src/layers/embedding/jax/distributed_embedding.py,sha256=
|
|
19
|
-
keras_rs/src/layers/embedding/jax/embedding_lookup.py,sha256=
|
|
20
|
-
keras_rs/src/layers/embedding/jax/embedding_utils.py,sha256=
|
|
18
|
+
keras_rs/src/layers/embedding/jax/distributed_embedding.py,sha256=D2tJehcy10w30KzYKnnV30WwCUvgLUm2YQM3Twwge9M,32338
|
|
19
|
+
keras_rs/src/layers/embedding/jax/embedding_lookup.py,sha256=a90tWTbU9tkFdESG3xir9PTtcvb1cmYR8vl5dDK9PSY,8703
|
|
20
|
+
keras_rs/src/layers/embedding/jax/embedding_utils.py,sha256=5rQGli4Qflg0BU2-j_-4xbBxSqopqbtjkY2KKYWq64Y,7329
|
|
21
21
|
keras_rs/src/layers/embedding/tensorflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
22
|
-
keras_rs/src/layers/embedding/tensorflow/config_conversion.py,sha256=
|
|
23
|
-
keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py,sha256
|
|
22
|
+
keras_rs/src/layers/embedding/tensorflow/config_conversion.py,sha256=HpuDthRQQ3X0EO8dW6OAdcgTODkujZlx_swgreVwXyk,13220
|
|
23
|
+
keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py,sha256=rkxnPzMHmq82FEzrLrO13NhDHPiX-3PxRM3AUE6Rv10,18050
|
|
24
24
|
keras_rs/src/layers/feature_interaction/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
25
25
|
keras_rs/src/layers/feature_interaction/dot_interaction.py,sha256=Rs8xIHXNWQNiwjp_xzvQRmTSV1AyhJjDgVc3K5pTmrQ,8530
|
|
26
26
|
keras_rs/src/layers/feature_interaction/feature_cross.py,sha256=Wq_eQvO0WTRlep69QbKi8TwY8bnFoF9vreP_j6ZHNFE,8666
|
|
@@ -31,6 +31,7 @@ keras_rs/src/layers/retrieval/remove_accidental_hits.py,sha256=WKoIhUSc6SvbgLXcS
|
|
|
31
31
|
keras_rs/src/layers/retrieval/retrieval.py,sha256=SFxMdooUhZy854SLZbpoyZR1Md4bHnpf7P077oVjjtU,4162
|
|
32
32
|
keras_rs/src/layers/retrieval/sampling_probability_correction.py,sha256=3zD6LInxhyIvyujMleGqiuoPKsna2oaTN6JU6xMnW_M,1977
|
|
33
33
|
keras_rs/src/losses/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
34
|
+
keras_rs/src/losses/list_mle_loss.py,sha256=NKRV_ZJUXFI1qG9_ugqxRyafRHreuUIokA7mbnsogBo,7433
|
|
34
35
|
keras_rs/src/losses/pairwise_hinge_loss.py,sha256=tONOJpcwCw1mybwvyx8dAy5t6dDmlIn00enzWfQLXpQ,3049
|
|
35
36
|
keras_rs/src/losses/pairwise_logistic_loss.py,sha256=40PFdCHDM7CLunT_PE3RbgxROVImw13dgVL3o3nzeNg,3473
|
|
36
37
|
keras_rs/src/losses/pairwise_loss.py,sha256=Oydk8e7AGU0Mc9yvm6ccr_XDDfUe8EZlS4JJgyxKUm4,6197
|
|
@@ -44,13 +45,14 @@ keras_rs/src/metrics/mean_reciprocal_rank.py,sha256=vr3ZZjpGYy2N-N7stcIm5elfHe9A
|
|
|
44
45
|
keras_rs/src/metrics/ndcg.py,sha256=ZBaKqV57K7jlto6ZVMxFNNRLdhzbLhdAR8TgDexjSjg,6922
|
|
45
46
|
keras_rs/src/metrics/precision_at_k.py,sha256=Dj5R-rT_Yd5hAsk4f-BlNMujfgIdPXnFVGOw9u7BIZQ,4038
|
|
46
47
|
keras_rs/src/metrics/ranking_metric.py,sha256=Lcl-Tt6HlI0f2wQpvAJ2M4mm5qCTZm-IgnLjjSEeNXE,10655
|
|
47
|
-
keras_rs/src/metrics/ranking_metrics_utils.py,sha256=
|
|
48
|
+
keras_rs/src/metrics/ranking_metrics_utils.py,sha256=rtCy_8T5CbQ4UKr3ykFRaM_wZrQ-IqEJa_VnVMTH8nQ,9644
|
|
48
49
|
keras_rs/src/metrics/recall_at_k.py,sha256=ssnQJC42KLN28cGrmzM-qR4M4iPqiQzWM2MfwYMq4ZE,3701
|
|
49
50
|
keras_rs/src/metrics/utils.py,sha256=fGTo8j0ykVE5Y3yQCS2orSFcHY20Uxt0NazyPsybUsw,2471
|
|
50
51
|
keras_rs/src/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
51
52
|
keras_rs/src/utils/doc_string_utils.py,sha256=CmqomepmaYcvpACpXEXkrJb8DMnvIgmYK-lJ53lYarY,1675
|
|
52
53
|
keras_rs/src/utils/keras_utils.py,sha256=dc-NFzs3a-qmRw0vBDiMslPLfrm9yymGduLWesXPhuY,2123
|
|
53
|
-
|
|
54
|
-
keras_rs_nightly-0.
|
|
55
|
-
keras_rs_nightly-0.
|
|
56
|
-
keras_rs_nightly-0.
|
|
54
|
+
keras_rs/src/utils/tpu_test_utils.py,sha256=mQVBrI-CCBbXwQxBq1yDKGMwYhm4g4k3_AaSy_sCs0U,4028
|
|
55
|
+
keras_rs_nightly-0.4.1.dev202601250348.dist-info/METADATA,sha256=APNEvzS76AMD7Km5fXAWMnmc7VJspHKopOCOAd2xM1s,5324
|
|
56
|
+
keras_rs_nightly-0.4.1.dev202601250348.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
|
|
57
|
+
keras_rs_nightly-0.4.1.dev202601250348.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
|
|
58
|
+
keras_rs_nightly-0.4.1.dev202601250348.dist-info/RECORD,,
|
|
File without changes
|