keras-rs-nightly: 0.3.1.dev202510280332-py3-none-any.whl → 0.3.1.dev202511090334-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
keras_rs/losses/__init__.py CHANGED
@@ -4,6 +4,7 @@ This file was autogenerated. Do not edit it by hand,
 since your modifications would be overwritten.
 """

+from keras_rs.src.losses.list_mle_loss import ListMLELoss as ListMLELoss
 from keras_rs.src.losses.pairwise_hinge_loss import (
     PairwiseHingeLoss as PairwiseHingeLoss,
 )
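With this re-export, the new loss becomes available from the public `keras_rs.losses` namespace (the `keras_rs.losses.ListMLELoss` path is confirmed by the `@keras_rs_export` decorator in the new module below). A minimal usage sketch, mirroring the docstring example:

```python
import keras_rs

# The new listwise loss, now importable from the public namespace.
loss_fn = keras_rs.losses.ListMLELoss(temperature=1.0)
loss = loss_fn(
    y_true=[[3.0, 2.0, 1.0, 0.0]],  # relevance labels
    y_pred=[[0.8, 0.6, 0.4, 0.2]],  # predicted scores
)
```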
keras_rs/src/layers/embedding/jax/distributed_embedding.py CHANGED
@@ -445,6 +445,49 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         table_specs = embedding.get_table_specs(feature_specs)
         table_stacks = jte_table_stacking.get_table_stacks(table_specs)

+        # Update stacked table stats to max of values across involved tables.
+        max_ids_per_partition = {}
+        max_unique_ids_per_partition = {}
+        required_buffer_size_per_device = {}
+        id_drop_counters = {}
+        for stack_name, stack in table_stacks.items():
+            max_ids_per_partition[stack_name] = np.max(
+                np.asarray(
+                    [s.max_ids_per_partition for s in stack], dtype=np.int32
+                )
+            )
+            max_unique_ids_per_partition[stack_name] = np.max(
+                np.asarray(
+                    [s.max_unique_ids_per_partition for s in stack],
+                    dtype=np.int32,
+                )
+            )
+
+            # Only set the suggested buffer size if set on any individual table.
+            valid_buffer_sizes = [
+                s.suggested_coo_buffer_size_per_device
+                for s in stack
+                if s.suggested_coo_buffer_size_per_device is not None
+            ]
+            if valid_buffer_sizes:
+                required_buffer_size_per_device[stack_name] = np.max(
+                    np.asarray(valid_buffer_sizes, dtype=np.int32)
+                )
+
+            id_drop_counters[stack_name] = 0
+
+        aggregated_stats = embedding.SparseDenseMatmulInputStats(
+            max_ids_per_partition=max_ids_per_partition,
+            max_unique_ids_per_partition=max_unique_ids_per_partition,
+            required_buffer_size_per_sc=required_buffer_size_per_device,
+            id_drop_counters=id_drop_counters,
+        )
+        embedding.update_preprocessing_parameters(
+            feature_specs,
+            aggregated_stats,
+            num_sc_per_device,
+        )
+
         # Create variables for all stacked tables and slot variables.
         with sparsecore_distribution.scope():
             self._table_and_slot_variables = {
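The idea behind the added block: once several logical tables are stacked into one physical table, the stack's preprocessing limits must cover the worst case of any member table, so each statistic is reduced with a max across the stack, and the suggested COO buffer size is only propagated when at least one member sets it. A minimal sketch of that reduction, where `TableStats` is a hypothetical stand-in for the real table spec objects:

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class TableStats:  # hypothetical; the real specs carry more fields
    max_ids_per_partition: int
    max_unique_ids_per_partition: int
    suggested_coo_buffer_size_per_device: int | None = None


stack = [
    TableStats(64, 32, None),
    TableStats(128, 16, 4096),
]

# A stacked table must satisfy the worst case of any member table,
# hence the element-wise max across the stack.
max_ids = np.max(
    np.asarray([s.max_ids_per_partition for s in stack], dtype=np.int32)
)
max_unique = np.max(
    np.asarray([s.max_unique_ids_per_partition for s in stack], dtype=np.int32)
)
assert (max_ids, max_unique) == (128, 32)

# The buffer size is only propagated if at least one member sets it.
sizes = [
    s.suggested_coo_buffer_size_per_device
    for s in stack
    if s.suggested_coo_buffer_size_per_device is not None
]
buffer_size = np.max(np.asarray(sizes, dtype=np.int32)) if sizes else None
assert buffer_size == 4096
```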
keras_rs/src/losses/list_mle_loss.py ADDED
@@ -0,0 +1,212 @@
+from typing import Any
+
+import keras
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
+from keras_rs.src.metrics.utils import standardize_call_inputs_ranks
+
+
+@keras_rs_export("keras_rs.losses.ListMLELoss")
+class ListMLELoss(keras.losses.Loss):
+    """Implements ListMLE (Maximum Likelihood Estimation) loss for ranking.
+
+    ListMLE loss is a listwise ranking loss that maximizes the likelihood of
+    the ground truth ranking. It works by:
+    1. Sorting items by their relevance scores (labels).
+    2. Computing the probability of observing this ranking given the
+       predicted scores.
+    3. Maximizing this likelihood (minimizing the negative log-likelihood).
+
+    The loss is computed as the negative log-likelihood of the ground truth
+    ranking given the predicted scores:
+
+    ```
+    loss = -sum(log(exp(s_i) / sum(exp(s_j) for j >= i)))
+    ```
+
+    where `s_i` is the predicted score for item `i` in the sorted order.
+
+    Args:
+        temperature: Temperature parameter for scaling logits. Higher values
+            make the probability distribution more uniform. Defaults to 1.0.
+        reduction: Type of reduction to apply to the loss. In almost all cases
+            this should be `"sum_over_batch_size"`. Supported options are
+            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
+            `"mean_with_sample_weight"` or `None`. Defaults to
+            `"sum_over_batch_size"`.
+        name: Optional name for the loss instance.
+        dtype: The dtype of the loss's computations. Defaults to `None`.
+
+    Examples:
+    ```python
+    # Basic usage.
+    loss_fn = ListMLELoss()
+
+    # With temperature scaling.
+    loss_fn = ListMLELoss(temperature=0.5)
+
+    # Example with synthetic data.
+    y_true = [[3, 2, 1, 0]]  # Relevance scores.
+    y_pred = [[0.8, 0.6, 0.4, 0.2]]  # Predicted scores.
+    loss = loss_fn(y_true, y_pred)
+    ```
+    """
+
+    def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+
+        if temperature <= 0.0:
+            raise ValueError(
+                "`temperature` should be a positive float. Received: "
+                f"`temperature` = {temperature}."
+            )
+
+        self.temperature = temperature
+        self._epsilon = 1e-10
+
+    def compute_unreduced_loss(
+        self,
+        labels: types.Tensor,
+        logits: types.Tensor,
+        mask: types.Tensor | None = None,
+    ) -> tuple[types.Tensor, types.Tensor]:
+        """Compute the unreduced ListMLE loss.
+
+        Args:
+            labels: Ground truth relevance scores of shape
+                `[batch_size, list_size]`.
+            logits: Predicted scores of shape `[batch_size, list_size]`.
+            mask: Optional mask of shape `[batch_size, list_size]`.
+
+        Returns:
+            Tuple of `(losses, weights)`, where `losses` has shape
+            `[batch_size, 1]` and `weights` has the same shape.
+        """
+        valid_mask = ops.greater_equal(labels, ops.cast(0.0, labels.dtype))
+
+        if mask is not None:
+            valid_mask = ops.logical_and(
+                valid_mask, ops.cast(mask, dtype="bool")
+            )
+
+        num_valid_items = ops.sum(
+            ops.cast(valid_mask, dtype=labels.dtype), axis=1, keepdims=True
+        )
+
+        batch_has_valid_items = ops.greater(num_valid_items, 0.0)
+
+        labels_for_sorting = ops.where(
+            valid_mask, labels, ops.full_like(labels, -1e9)
+        )
+        logits_masked = ops.where(
+            valid_mask, logits, ops.full_like(logits, -1e9)
+        )
+
+        sorted_logits, sorted_valid_mask = sort_by_scores(
+            tensors_to_sort=[logits_masked, valid_mask],
+            scores=labels_for_sorting,
+            mask=None,
+            shuffle_ties=False,
+            seed=None,
+        )
+        sorted_logits = ops.divide(
+            sorted_logits, ops.cast(self.temperature, dtype=sorted_logits.dtype)
+        )
+
+        valid_logits_for_max = ops.where(
+            sorted_valid_mask, sorted_logits, ops.full_like(sorted_logits, -1e9)
+        )
+        raw_max = ops.max(valid_logits_for_max, axis=1, keepdims=True)
+        raw_max = ops.where(
+            batch_has_valid_items, raw_max, ops.zeros_like(raw_max)
+        )
+        sorted_logits = ops.subtract(sorted_logits, raw_max)
+
+        # Set invalid positions to a very negative value BEFORE `exp`.
+        sorted_logits = ops.where(
+            sorted_valid_mask, sorted_logits, ops.full_like(sorted_logits, -1e9)
+        )
+        exp_logits = ops.exp(sorted_logits)
+
+        reversed_exp = ops.flip(exp_logits, axis=1)
+        reversed_cumsum = ops.cumsum(reversed_exp, axis=1)
+        cumsum_from_right = ops.flip(reversed_cumsum, axis=1)
+
+        log_normalizers = ops.log(cumsum_from_right + self._epsilon)
+        log_probs = ops.subtract(sorted_logits, log_normalizers)
+
+        log_probs = ops.where(
+            sorted_valid_mask, log_probs, ops.zeros_like(log_probs)
+        )
+
+        negative_log_likelihood = ops.negative(
+            ops.sum(log_probs, axis=1, keepdims=True)
+        )
+
+        negative_log_likelihood = ops.where(
+            batch_has_valid_items,
+            negative_log_likelihood,
+            ops.zeros_like(negative_log_likelihood),
+        )
+
+        weights = ops.ones_like(negative_log_likelihood)
+
+        return negative_log_likelihood, weights
+
+    def call(
+        self,
+        y_true: types.Tensor,
+        y_pred: types.Tensor,
+    ) -> types.Tensor:
+        """Compute the ListMLE loss.
+
+        Args:
+            y_true: tensor or dict. Ground truth values. If a tensor, of shape
+                `(list_size)` for unbatched inputs or `(batch_size, list_size)`
+                for batched inputs. If an item has a label of -1, it is
+                ignored in the loss computation. If a dictionary, it should
+                have two keys: `"labels"` and `"mask"`. `"mask"` can be used
+                to ignore elements in the loss computation.
+            y_pred: tensor. The predicted values, of shape `(list_size)` for
+                unbatched inputs or `(batch_size, list_size)` for batched
+                inputs. Should be of the same shape as `y_true`.
+
+        Returns:
+            The loss tensor of shape `[batch_size]`.
+        """
+        mask = None
+        if isinstance(y_true, dict):
+            if "labels" not in y_true:
+                raise ValueError(
+                    '`"labels"` should be present in `y_true`. Received: '
+                    f"`y_true` = {y_true}"
+                )
+
+            mask = y_true.get("mask", None)
+            y_true = y_true["labels"]
+
+        y_true = ops.convert_to_tensor(y_true)
+        y_pred = ops.convert_to_tensor(y_pred)
+        if mask is not None:
+            mask = ops.convert_to_tensor(mask)
+
+        y_true, y_pred, mask, _ = standardize_call_inputs_ranks(
+            y_true, y_pred, mask
+        )
+
+        losses, weights = self.compute_unreduced_loss(
+            labels=y_true, logits=y_pred, mask=mask
+        )
+        losses = ops.multiply(losses, weights)
+        losses = ops.squeeze(losses, axis=-1)
+        return losses
+
+    def get_config(self) -> dict[str, Any]:
+        config: dict[str, Any] = super().get_config()
+        config.update({"temperature": self.temperature})
+        return config
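The flip/cumsum/flip sequence in `compute_unreduced_loss` computes, for each sorted position `i`, the normalizer `sum(exp(s_j) for j >= i)` in a single pass. A numpy re-derivation of the docstring formula for the synthetic example above (all items valid, `temperature=1.0`); illustrative only, not the library code:

```python
import numpy as np

y_true = np.array([3.0, 2.0, 1.0, 0.0])  # already in decreasing relevance
y_pred = np.array([0.8, 0.6, 0.4, 0.2])

# Sort predicted scores by decreasing label (order is unchanged here).
s = y_pred[np.argsort(-y_true)]

# loss = -sum_i log(exp(s_i) / sum_{j >= i} exp(s_j)): a cumulative
# log-sum-exp over the tail of the sorted list, built with the same
# flip/cumsum/flip trick as the Keras ops above.
tail_logsumexp = np.log(np.cumsum(np.exp(s)[::-1])[::-1])
print(-np.sum(s - tail_logsumexp))  # ~2.62
```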
keras_rs/src/metrics/ranking_metrics_utils.py CHANGED
@@ -85,6 +85,25 @@ def sort_by_scores(
     else:
         k = ops.minimum(k, max_possible_k)

+    # --- Workaround for PyTorch instability ---
+    # Torch's `topk` is not stable with `sorted=True`, unlike JAX and TF.
+    # See:
+    # - https://github.com/pytorch/pytorch/issues/27542
+    # - https://github.com/pytorch/pytorch/issues/88227
+    #
+    # This small "stable offset" ensures deterministic tie-breaking for
+    # equal scores. We can remove this workaround once PyTorch adds a
+    # `stable=True` flag for `topk`.
+
+    if keras.backend.backend() == "torch" and not shuffle_ties:
+        list_size = ops.shape(scores)[1]
+        indices = ops.arange(list_size)
+        indices = ops.expand_dims(indices, axis=0)
+        indices = ops.broadcast_to(indices, ops.shape(scores))
+        stable_offset = ops.cast(indices, scores.dtype) * 1e-6
+        scores = ops.subtract(scores, stable_offset)
+    # --- End of workaround ---
+
     # Shuffle ties randomly, and push masked values to the beginning.
     shuffled_indices = None
     if shuffle_ties or mask is not None:
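A small numpy illustration (not the library code) of what the offset buys: subtracting a tiny position-proportional value makes earlier positions win ties, so even an unstable `topk` or sort orders equal scores deterministically. The `1e-6` scale assumes real score gaps are much larger than `list_size * 1e-6`.

```python
import numpy as np

scores = np.array([[0.5, 0.9, 0.5, 0.9]])

# Position-proportional offset: index i loses i * 1e-6 from its score.
offset = np.arange(scores.shape[1], dtype=scores.dtype) * 1e-6
adjusted = scores - offset

# Descending sort: among tied raw scores, the lower index now ranks first.
print(np.argsort(-adjusted, axis=1))  # [[1 3 0 2]]
```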
keras_rs/src/version.py CHANGED
@@ -1,7 +1,7 @@
 from keras_rs.src.api_export import keras_rs_export

 # Unique source of truth for the version number.
-__version__ = "0.3.1.dev202510280332"
+__version__ = "0.3.1.dev202511090334"


 @keras_rs_export("keras_rs.version")
keras_rs_nightly-0.3.1.dev202511090334.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: keras-rs-nightly
-Version: 0.3.1.dev202510280332
+Version: 0.3.1.dev202511090334
 Summary: Multi-backend recommender systems with Keras 3.
 Author-email: Keras team <keras-users@googlegroups.com>
 License: Apache License 2.0
keras_rs_nightly-0.3.1.dev202511090334.dist-info/RECORD CHANGED
@@ -1,11 +1,11 @@
 keras_rs/__init__.py,sha256=8sjHiPN2GhUqAq4V7Vh4FLLqYw20-jgdI26ZKX5sg6M,350
 keras_rs/layers/__init__.py,sha256=ERqFu1R8FgeES5rO5QwauArbNCm8auj-AiCURtsG6Ro,1332
-keras_rs/losses/__init__.py,sha256=m04QOgxIUfJ2MvCUKLgEof-UbSNKgUYLPnY-D9NAclI,573
+keras_rs/losses/__init__.py,sha256=WyyrxhWrayt-Hm6gSmZ5CPZifbPx0egDIothGi0Dpjk,646
 keras_rs/metrics/__init__.py,sha256=Qxpf6OFooIL9TIn2l3WgOea3HFRG0hq02glPAxtMZ9c,580
 keras_rs/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_rs/src/api_export.py,sha256=RsmG-DvO-cdFeAF9W6LRzms0kvtm-Yp9BAA_d-952zI,510
 keras_rs/src/types.py,sha256=1A-oLRdX1-f2DsVZBcNl8qNsaH8pM-gnleLT9FWZWBw,1189
-keras_rs/src/version.py,sha256=LBNXhlFa6P1nQhY9SUb1spImnuIwMrfVQVmJfnUfGGM,224
+keras_rs/src/version.py,sha256=u35xmyEk_om8XN89zlRKP9LtB-7xy301fL3uSADKnRE,224
 keras_rs/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_rs/src/layers/embedding/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_rs/src/layers/embedding/base_distributed_embedding.py,sha256=RkXZ6notj3Cq6ryR9w30Wb8UlaWjLcUK2Os9ZUQvuhY,45568
@@ -15,7 +15,7 @@ keras_rs/src/layers/embedding/embed_reduce.py,sha256=c-MnEw1-KWs0jTf0JJ_ZBOY-9hR
 keras_rs/src/layers/embedding/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_rs/src/layers/embedding/jax/checkpoint_utils.py,sha256=wZ4I5WZVNg5WnrD2j7nhAXgLzDc7xMrUEkSAOx5Sz5c,3495
 keras_rs/src/layers/embedding/jax/config_conversion.py,sha256=Di1UzRwLgGHd7RuWYJMj2mCOr1u9MseFEWaYKnwD9Bs,16742
-keras_rs/src/layers/embedding/jax/distributed_embedding.py,sha256=O3G0AFRzukYdXPRyx7ZDqDvNgJrcbFwTCYTHigfdiKw,29628
+keras_rs/src/layers/embedding/jax/distributed_embedding.py,sha256=Pe2EOeB0xn8WefGtRo2ZqpM_gDktxNm_Qcqb8ANSAFM,31332
 keras_rs/src/layers/embedding/jax/embedding_lookup.py,sha256=8LigXjPr7uQaUOdZM6yoLGoPYdRcbkXkFeL_sJoQ6uQ,8223
 keras_rs/src/layers/embedding/jax/embedding_utils.py,sha256=slJ0XwkI1z4vTAnRXQwm39LFnK9AL3CODuGRn5BufgE,8292
 keras_rs/src/layers/embedding/tensorflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -31,6 +31,7 @@ keras_rs/src/layers/retrieval/remove_accidental_hits.py,sha256=WKoIhUSc6SvbgLXcS
 keras_rs/src/layers/retrieval/retrieval.py,sha256=SFxMdooUhZy854SLZbpoyZR1Md4bHnpf7P077oVjjtU,4162
 keras_rs/src/layers/retrieval/sampling_probability_correction.py,sha256=3zD6LInxhyIvyujMleGqiuoPKsna2oaTN6JU6xMnW_M,1977
 keras_rs/src/losses/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+keras_rs/src/losses/list_mle_loss.py,sha256=NKRV_ZJUXFI1qG9_ugqxRyafRHreuUIokA7mbnsogBo,7433
 keras_rs/src/losses/pairwise_hinge_loss.py,sha256=tONOJpcwCw1mybwvyx8dAy5t6dDmlIn00enzWfQLXpQ,3049
 keras_rs/src/losses/pairwise_logistic_loss.py,sha256=40PFdCHDM7CLunT_PE3RbgxROVImw13dgVL3o3nzeNg,3473
 keras_rs/src/losses/pairwise_loss.py,sha256=Oydk8e7AGU0Mc9yvm6ccr_XDDfUe8EZlS4JJgyxKUm4,6197
@@ -44,13 +45,13 @@ keras_rs/src/metrics/mean_reciprocal_rank.py,sha256=vr3ZZjpGYy2N-N7stcIm5elfHe9A
 keras_rs/src/metrics/ndcg.py,sha256=ZBaKqV57K7jlto6ZVMxFNNRLdhzbLhdAR8TgDexjSjg,6922
 keras_rs/src/metrics/precision_at_k.py,sha256=Dj5R-rT_Yd5hAsk4f-BlNMujfgIdPXnFVGOw9u7BIZQ,4038
 keras_rs/src/metrics/ranking_metric.py,sha256=Lcl-Tt6HlI0f2wQpvAJ2M4mm5qCTZm-IgnLjjSEeNXE,10655
-keras_rs/src/metrics/ranking_metrics_utils.py,sha256=0b03wiO9SjaHthtUYO4qezBFB8yLhFSwIRJhsL2fAJg,8785
+keras_rs/src/metrics/ranking_metrics_utils.py,sha256=PrndeM3vJojrDbbgmcBK_YUEKfIeb0riGgtyo0SdAcc,9618
 keras_rs/src/metrics/recall_at_k.py,sha256=ssnQJC42KLN28cGrmzM-qR4M4iPqiQzWM2MfwYMq4ZE,3701
 keras_rs/src/metrics/utils.py,sha256=fGTo8j0ykVE5Y3yQCS2orSFcHY20Uxt0NazyPsybUsw,2471
 keras_rs/src/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 keras_rs/src/utils/doc_string_utils.py,sha256=CmqomepmaYcvpACpXEXkrJb8DMnvIgmYK-lJ53lYarY,1675
 keras_rs/src/utils/keras_utils.py,sha256=dc-NFzs3a-qmRw0vBDiMslPLfrm9yymGduLWesXPhuY,2123
-keras_rs_nightly-0.3.1.dev202510280332.dist-info/METADATA,sha256=VVWxJMaJj1ItnuQmyfTAFAAdcBJrQ-sxUAg1SYs6c8Q,5324
-keras_rs_nightly-0.3.1.dev202510280332.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-keras_rs_nightly-0.3.1.dev202510280332.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
-keras_rs_nightly-0.3.1.dev202510280332.dist-info/RECORD,,
+keras_rs_nightly-0.3.1.dev202511090334.dist-info/METADATA,sha256=93jj01rtPsgRM4_tvafEXEDLEmO5wos5eX_zL2yLpn0,5324
+keras_rs_nightly-0.3.1.dev202511090334.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+keras_rs_nightly-0.3.1.dev202511090334.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
+keras_rs_nightly-0.3.1.dev202511090334.dist-info/RECORD,,