keras-rs-nightly 0.0.1.dev2025043003__tar.gz → 0.0.1.dev2025050103__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of keras-rs-nightly might be problematic; the package registry provides more details about this release.

Files changed (47)
  1. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/PKG-INFO +1 -1
  2. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/metrics/mean_average_precision.py +1 -1
  3. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/metrics/mean_reciprocal_rank.py +4 -4
  4. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/metrics/precision_at_k.py +3 -3
  5. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/metrics/ranking_metric.py +7 -1
  6. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/metrics/ranking_metrics_utils.py +2 -2
  7. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/metrics/recall_at_k.py +2 -2
  8. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/version.py +1 -1
  9. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs_nightly.egg-info/PKG-INFO +1 -1
  10. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/README.md +0 -0
  11. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/api/__init__.py +0 -0
  12. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/api/layers/__init__.py +0 -0
  13. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/api/losses/__init__.py +0 -0
  14. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/api/metrics/__init__.py +0 -0
  15. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/__init__.py +0 -0
  16. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/api_export.py +0 -0
  17. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/layers/__init__.py +0 -0
  18. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/layers/feature_interaction/__init__.py +0 -0
  19. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/layers/feature_interaction/dot_interaction.py +0 -0
  20. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/layers/feature_interaction/feature_cross.py +0 -0
  21. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/layers/retrieval/__init__.py +0 -0
  22. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/layers/retrieval/brute_force_retrieval.py +0 -0
  23. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/layers/retrieval/hard_negative_mining.py +0 -0
  24. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/layers/retrieval/remove_accidental_hits.py +0 -0
  25. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/layers/retrieval/retrieval.py +0 -0
  26. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/layers/retrieval/sampling_probability_correction.py +0 -0
  27. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/losses/__init__.py +0 -0
  28. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/losses/pairwise_hinge_loss.py +0 -0
  29. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/losses/pairwise_logistic_loss.py +0 -0
  30. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/losses/pairwise_loss.py +0 -0
  31. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/losses/pairwise_loss_utils.py +0 -0
  32. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/losses/pairwise_mean_squared_error.py +0 -0
  33. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/losses/pairwise_soft_zero_one_loss.py +0 -0
  34. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/metrics/__init__.py +0 -0
  35. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/metrics/dcg.py +0 -0
  36. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/metrics/ndcg.py +0 -0
  37. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/metrics/utils.py +0 -0
  38. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/types.py +0 -0
  39. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/utils/__init__.py +0 -0
  40. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/utils/doc_string_utils.py +0 -0
  41. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs/src/utils/keras_utils.py +0 -0
  42. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs_nightly.egg-info/SOURCES.txt +0 -0
  43. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs_nightly.egg-info/dependency_links.txt +0 -0
  44. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs_nightly.egg-info/requires.txt +0 -0
  45. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/keras_rs_nightly.egg-info/top_level.txt +0 -0
  46. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/pyproject.toml +0 -0
  47. {keras_rs_nightly-0.0.1.dev2025043003 → keras_rs_nightly-0.0.1.dev2025050103}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: keras-rs-nightly
3
- Version: 0.0.1.dev2025043003
3
+ Version: 0.0.1.dev2025050103
4
4
  Summary: Multi-backend recommender systems with Keras 3.
5
5
  Author-email: Keras team <keras-users@googlegroups.com>
6
6
  License: Apache License 2.0
@@ -25,7 +25,7 @@ class MeanAveragePrecision(RankingMetric):
25
25
  ) -> types.Tensor:
26
26
  relevance = ops.cast(
27
27
  ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
28
- dtype="float32",
28
+ dtype=y_pred.dtype,
29
29
  )
30
30
  sorted_relevance, sorted_weights = sort_by_scores(
31
31
  tensors_to_sort=[relevance, sample_weight],
@@ -44,13 +44,13 @@ class MeanReciprocalRank(RankingMetric):
44
44
  ops.greater_equal(
45
45
  sorted_y_true, ops.cast(1, dtype=sorted_y_true.dtype)
46
46
  ),
47
- dtype="float32",
47
+ dtype=y_pred.dtype,
48
48
  )
49
49
 
50
50
  # `reciprocal_rank = [1, 0.5, 0.33]`
51
51
  reciprocal_rank = ops.divide(
52
- ops.cast(1, dtype="float32"),
53
- ops.arange(1, list_length + 1, dtype="float32"),
52
+ ops.cast(1, dtype=y_pred.dtype),
53
+ ops.arange(1, list_length + 1, dtype=y_pred.dtype),
54
54
  )
55
55
 
56
56
  # `mrr` should be of shape `(batch_size, 1)`.
@@ -64,7 +64,7 @@ class MeanReciprocalRank(RankingMetric):
64
64
  # Get weights.
65
65
  overall_relevance = ops.cast(
66
66
  ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
67
- dtype="float32",
67
+ dtype=y_pred.dtype,
68
68
  )
69
69
  per_list_weights = get_list_weights(
70
70
  weights=sample_weight, relevance=overall_relevance
@@ -40,7 +40,7 @@ class PrecisionAtK(RankingMetric):
40
40
  ops.greater_equal(
41
41
  sorted_y_true, ops.cast(1, dtype=sorted_y_true.dtype)
42
42
  ),
43
- dtype="float32",
43
+ dtype=y_pred.dtype,
44
44
  )
45
45
  list_length = ops.shape(sorted_y_true)[1]
46
46
  # TODO: We do not do this for MRR, and the other metrics. Do we need to
@@ -52,13 +52,13 @@ class PrecisionAtK(RankingMetric):
52
52
 
53
53
  per_list_precision = ops.divide_no_nan(
54
54
  ops.sum(relevance, axis=1, keepdims=True),
55
- ops.cast(valid_list_length, dtype="float32"),
55
+ ops.cast(valid_list_length, dtype=y_pred.dtype),
56
56
  )
57
57
 
58
58
  # Get weights.
59
59
  overall_relevance = ops.cast(
60
60
  ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
61
- dtype="float32",
61
+ dtype=y_pred.dtype,
62
62
  )
63
63
  per_list_weights = get_list_weights(
64
64
  weights=sample_weight, relevance=overall_relevance
@@ -116,6 +116,12 @@ class RankingMetric(keras.metrics.Mean, abc.ABC):
116
116
  if passed_mask is not None:
117
117
  passed_mask = ops.convert_to_tensor(passed_mask)
118
118
 
119
+ # Cast to the correct dtype.
120
+ y_true = ops.cast(y_true, dtype=self.dtype)
121
+ y_pred = ops.cast(y_pred, dtype=self.dtype)
122
+ if sample_weight is not None:
123
+ sample_weight = ops.cast(sample_weight, dtype=self.dtype)
124
+
119
125
  # === Process `sample_weight` ===
120
126
  if sample_weight is None:
121
127
  sample_weight = ops.cast(1, dtype=y_pred.dtype)
@@ -152,7 +158,7 @@ class RankingMetric(keras.metrics.Mean, abc.ABC):
152
158
 
153
159
  # Mask all values less than 0 (since less than 0 implies invalid
154
160
  # labels).
155
- valid_mask = ops.greater_equal(y_true, ops.cast(0.0, y_true.dtype))
161
+ valid_mask = ops.greater_equal(y_true, ops.cast(0, y_true.dtype))
156
162
  if passed_mask is not None:
157
163
  valid_mask = ops.logical_and(valid_mask, passed_mask)
158
164
 
@@ -163,7 +163,7 @@ def get_list_weights(
163
163
  # Identify lists where both weights and relevance sums are non-zero.
164
164
  nonzero_relevance = ops.cast(
165
165
  ops.logical_and(nonzero_weights, nonzero_relevance_condition),
166
- dtype="float32",
166
+ dtype=weights.dtype,
167
167
  )
168
168
  # Count the number of lists with non-zero relevance and non-zero weights.
169
169
  nonzero_relevance_count = ops.sum(nonzero_relevance, axis=0, keepdims=True)
@@ -227,7 +227,7 @@ def compute_dcg(
227
227
  ] = default_rank_discount_fn,
228
228
  ) -> types.Tensor:
229
229
  list_size = ops.shape(y_true)[1]
230
- positions = ops.arange(1, list_size + 1, dtype="float32")
230
+ positions = ops.arange(1, list_size + 1, dtype=y_true.dtype)
231
231
  gain = gain_fn(y_true)
232
232
  discount = rank_discount_fn(positions)
233
233
 
@@ -38,11 +38,11 @@ class RecallAtK(RankingMetric):
38
38
  ops.greater_equal(
39
39
  sorted_y_true, ops.cast(1, dtype=sorted_y_true.dtype)
40
40
  ),
41
- dtype="float32",
41
+ dtype=y_pred.dtype,
42
42
  )
43
43
  overall_relevance = ops.cast(
44
44
  ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
45
- dtype="float32",
45
+ dtype=y_pred.dtype,
46
46
  )
47
47
  per_list_recall = ops.divide_no_nan(
48
48
  ops.sum(relevance, axis=1, keepdims=True),
@@ -1,7 +1,7 @@
1
1
  from keras_rs.src.api_export import keras_rs_export
2
2
 
3
3
  # Unique source of truth for the version number.
4
- __version__ = "0.0.1.dev2025043003"
4
+ __version__ = "0.0.1.dev2025050103"
5
5
 
6
6
 
7
7
  @keras_rs_export("keras_rs.version")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: keras-rs-nightly
3
- Version: 0.0.1.dev2025043003
3
+ Version: 0.0.1.dev2025050103
4
4
  Summary: Multi-backend recommender systems with Keras 3.
5
5
  Author-email: Keras team <keras-users@googlegroups.com>
6
6
  License: Apache License 2.0