keras-rs-nightly 0.0.1.dev2025043003__py3-none-any.whl → 0.0.1.dev2025050103__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of keras-rs-nightly might be problematic. Click here for more details.

@@ -25,7 +25,7 @@ class MeanAveragePrecision(RankingMetric):
25
25
  ) -> types.Tensor:
26
26
  relevance = ops.cast(
27
27
  ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
28
- dtype="float32",
28
+ dtype=y_pred.dtype,
29
29
  )
30
30
  sorted_relevance, sorted_weights = sort_by_scores(
31
31
  tensors_to_sort=[relevance, sample_weight],
@@ -44,13 +44,13 @@ class MeanReciprocalRank(RankingMetric):
44
44
  ops.greater_equal(
45
45
  sorted_y_true, ops.cast(1, dtype=sorted_y_true.dtype)
46
46
  ),
47
- dtype="float32",
47
+ dtype=y_pred.dtype,
48
48
  )
49
49
 
50
50
  # `reciprocal_rank = [1, 0.5, 0.33]`
51
51
  reciprocal_rank = ops.divide(
52
- ops.cast(1, dtype="float32"),
53
- ops.arange(1, list_length + 1, dtype="float32"),
52
+ ops.cast(1, dtype=y_pred.dtype),
53
+ ops.arange(1, list_length + 1, dtype=y_pred.dtype),
54
54
  )
55
55
 
56
56
  # `mrr` should be of shape `(batch_size, 1)`.
@@ -64,7 +64,7 @@ class MeanReciprocalRank(RankingMetric):
64
64
  # Get weights.
65
65
  overall_relevance = ops.cast(
66
66
  ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
67
- dtype="float32",
67
+ dtype=y_pred.dtype,
68
68
  )
69
69
  per_list_weights = get_list_weights(
70
70
  weights=sample_weight, relevance=overall_relevance
@@ -40,7 +40,7 @@ class PrecisionAtK(RankingMetric):
40
40
  ops.greater_equal(
41
41
  sorted_y_true, ops.cast(1, dtype=sorted_y_true.dtype)
42
42
  ),
43
- dtype="float32",
43
+ dtype=y_pred.dtype,
44
44
  )
45
45
  list_length = ops.shape(sorted_y_true)[1]
46
46
  # TODO: We do not do this for MRR, and the other metrics. Do we need to
@@ -52,13 +52,13 @@ class PrecisionAtK(RankingMetric):
52
52
 
53
53
  per_list_precision = ops.divide_no_nan(
54
54
  ops.sum(relevance, axis=1, keepdims=True),
55
- ops.cast(valid_list_length, dtype="float32"),
55
+ ops.cast(valid_list_length, dtype=y_pred.dtype),
56
56
  )
57
57
 
58
58
  # Get weights.
59
59
  overall_relevance = ops.cast(
60
60
  ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
61
- dtype="float32",
61
+ dtype=y_pred.dtype,
62
62
  )
63
63
  per_list_weights = get_list_weights(
64
64
  weights=sample_weight, relevance=overall_relevance
@@ -116,6 +116,12 @@ class RankingMetric(keras.metrics.Mean, abc.ABC):
116
116
  if passed_mask is not None:
117
117
  passed_mask = ops.convert_to_tensor(passed_mask)
118
118
 
119
+ # Cast to the correct dtype.
120
+ y_true = ops.cast(y_true, dtype=self.dtype)
121
+ y_pred = ops.cast(y_pred, dtype=self.dtype)
122
+ if sample_weight is not None:
123
+ sample_weight = ops.cast(sample_weight, dtype=self.dtype)
124
+
119
125
  # === Process `sample_weight` ===
120
126
  if sample_weight is None:
121
127
  sample_weight = ops.cast(1, dtype=y_pred.dtype)
@@ -152,7 +158,7 @@ class RankingMetric(keras.metrics.Mean, abc.ABC):
152
158
 
153
159
  # Mask all values less than 0 (since less than 0 implies invalid
154
160
  # labels).
155
- valid_mask = ops.greater_equal(y_true, ops.cast(0.0, y_true.dtype))
161
+ valid_mask = ops.greater_equal(y_true, ops.cast(0, y_true.dtype))
156
162
  if passed_mask is not None:
157
163
  valid_mask = ops.logical_and(valid_mask, passed_mask)
158
164
 
@@ -163,7 +163,7 @@ def get_list_weights(
163
163
  # Identify lists where both weights and relevance sums are non-zero.
164
164
  nonzero_relevance = ops.cast(
165
165
  ops.logical_and(nonzero_weights, nonzero_relevance_condition),
166
- dtype="float32",
166
+ dtype=weights.dtype,
167
167
  )
168
168
  # Count the number of lists with non-zero relevance and non-zero weights.
169
169
  nonzero_relevance_count = ops.sum(nonzero_relevance, axis=0, keepdims=True)
@@ -227,7 +227,7 @@ def compute_dcg(
227
227
  ] = default_rank_discount_fn,
228
228
  ) -> types.Tensor:
229
229
  list_size = ops.shape(y_true)[1]
230
- positions = ops.arange(1, list_size + 1, dtype="float32")
230
+ positions = ops.arange(1, list_size + 1, dtype=y_true.dtype)
231
231
  gain = gain_fn(y_true)
232
232
  discount = rank_discount_fn(positions)
233
233
 
@@ -38,11 +38,11 @@ class RecallAtK(RankingMetric):
38
38
  ops.greater_equal(
39
39
  sorted_y_true, ops.cast(1, dtype=sorted_y_true.dtype)
40
40
  ),
41
- dtype="float32",
41
+ dtype=y_pred.dtype,
42
42
  )
43
43
  overall_relevance = ops.cast(
44
44
  ops.greater_equal(y_true, ops.cast(1, dtype=y_true.dtype)),
45
- dtype="float32",
45
+ dtype=y_pred.dtype,
46
46
  )
47
47
  per_list_recall = ops.divide_no_nan(
48
48
  ops.sum(relevance, axis=1, keepdims=True),
keras_rs/src/version.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from keras_rs.src.api_export import keras_rs_export
2
2
 
3
3
  # Unique source of truth for the version number.
4
- __version__ = "0.0.1.dev2025043003"
4
+ __version__ = "0.0.1.dev2025050103"
5
5
 
6
6
 
7
7
  @keras_rs_export("keras_rs.version")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: keras-rs-nightly
3
- Version: 0.0.1.dev2025043003
3
+ Version: 0.0.1.dev2025050103
4
4
  Summary: Multi-backend recommender systems with Keras 3.
5
5
  Author-email: Keras team <keras-users@googlegroups.com>
6
6
  License: Apache License 2.0
@@ -5,7 +5,7 @@ keras_rs/metrics/__init__.py,sha256=Qxpf6OFooIL9TIn2l3WgOea3HFRG0hq02glPAxtMZ9c,
5
5
  keras_rs/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  keras_rs/src/api_export.py,sha256=RsmG-DvO-cdFeAF9W6LRzms0kvtm-Yp9BAA_d-952zI,510
7
7
  keras_rs/src/types.py,sha256=UyOdgjqrqg_b58opnY8n6gTiDHKVR8z_bmEruehERBk,514
8
- keras_rs/src/version.py,sha256=6DQicfo43WsR2bsg-BdUHiGbBwGhNMF6hKd7NXYSW70,222
8
+ keras_rs/src/version.py,sha256=RINCn1p_Brmx7af_3Abw9rz_hengEGgMGQKOtLcQDrM,222
9
9
  keras_rs/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
10
  keras_rs/src/layers/feature_interaction/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  keras_rs/src/layers/feature_interaction/dot_interaction.py,sha256=bRLz03_8VaYLNG4gbIKCzsSc26shKMmzmwCs8SujezE,8542
@@ -25,18 +25,18 @@ keras_rs/src/losses/pairwise_mean_squared_error.py,sha256=zFiSr2TNyJysgULxj9R_tr
25
25
  keras_rs/src/losses/pairwise_soft_zero_one_loss.py,sha256=YddVtJS8tKEeb0YrqGzEsr-6IDxH4uRjFrYkZDMWpkk,3492
26
26
  keras_rs/src/metrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
27
  keras_rs/src/metrics/dcg.py,sha256=UT5EyStuMeF7kpVguF34u7__Hr0bWfSFqEoyX1F4dtA,5836
28
- keras_rs/src/metrics/mean_average_precision.py,sha256=fRptyVvhCtzg0rXhBaTfLmqo7dKIG7vS75HK0xuDvpg,4629
29
- keras_rs/src/metrics/mean_reciprocal_rank.py,sha256=R_LDAuKLK9buSD6hh3_nm0PksMhISbpuI6fR1MTsFWM,4034
28
+ keras_rs/src/metrics/mean_average_precision.py,sha256=yUub4jGnqwTmxf694Z2ymRjMG_vO2HdyqqvDbcEhdSQ,4632
29
+ keras_rs/src/metrics/mean_reciprocal_rank.py,sha256=vr3ZZjpGYy2N-N7stcIm5elfHe9A4W8uY4HADP8icMw,4046
30
30
  keras_rs/src/metrics/ndcg.py,sha256=OX8vqO5JoBm8I7NDOce0bXwtoGNEK0hGEQT8hYfqJDA,6935
31
- keras_rs/src/metrics/precision_at_k.py,sha256=A1pL5-Yo_DfDzUqAfqbF8TY39yqFgf_Fe1cxz0AsCfE,4029
32
- keras_rs/src/metrics/ranking_metric.py,sha256=GFtOszaDmP4Q1ky3KnyMNXR6OBu09Uk4aEOJyn5-JO4,10439
33
- keras_rs/src/metrics/ranking_metrics_utils.py,sha256=989J8pr6FRsA1HwBeF7SA8uQqjZT2XeCxKfRuMysWnQ,8828
34
- keras_rs/src/metrics/recall_at_k.py,sha256=allUQA6JvPcWXxtGUHXmZ_nOWHAOmuUrIy5s5Nxse-4,3695
31
+ keras_rs/src/metrics/precision_at_k.py,sha256=Dj5R-rT_Yd5hAsk4f-BlNMujfgIdPXnFVGOw9u7BIZQ,4038
32
+ keras_rs/src/metrics/ranking_metric.py,sha256=JYj64q1_W3JWyKYTn4V3emKndC3BOcUz5vfQqPIx-S8,10687
33
+ keras_rs/src/metrics/ranking_metrics_utils.py,sha256=voUgDu3Zd-8FP0DaB1PLbInDSzkV8Zfz_6OZlsVG4VQ,8835
34
+ keras_rs/src/metrics/recall_at_k.py,sha256=ssnQJC42KLN28cGrmzM-qR4M4iPqiQzWM2MfwYMq4ZE,3701
35
35
  keras_rs/src/metrics/utils.py,sha256=6xanTNdwARn4ugzmb7ko2kwAhNhsnR4NhrpS_qW0IKc,2506
36
36
  keras_rs/src/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
37
37
  keras_rs/src/utils/doc_string_utils.py,sha256=CmqomepmaYcvpACpXEXkrJb8DMnvIgmYK-lJ53lYarY,1675
38
38
  keras_rs/src/utils/keras_utils.py,sha256=d28OdQP4GrJk4NIQS4n0KPtCbgOCxVU_vDnnI7ODpOw,1562
39
- keras_rs_nightly-0.0.1.dev2025043003.dist-info/METADATA,sha256=9RvG8sYrJD060w9nUrJ_vIVKwx_M3CzH_f0dquulVjg,5199
40
- keras_rs_nightly-0.0.1.dev2025043003.dist-info/WHEEL,sha256=ooBFpIzZCPdw3uqIQsOo4qqbA4ZRPxHnOH7peeONza0,91
41
- keras_rs_nightly-0.0.1.dev2025043003.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
42
- keras_rs_nightly-0.0.1.dev2025043003.dist-info/RECORD,,
39
+ keras_rs_nightly-0.0.1.dev2025050103.dist-info/METADATA,sha256=3RWdJnM3dY6aK29kbLwiq3dqQMwtCY3YQ7EGx-pajhE,5199
40
+ keras_rs_nightly-0.0.1.dev2025050103.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
41
+ keras_rs_nightly-0.0.1.dev2025050103.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
42
+ keras_rs_nightly-0.0.1.dev2025050103.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.0.1)
2
+ Generator: setuptools (80.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5