keras-rs-nightly 0.3.1.dev202510100326__tar.gz → 0.3.1.dev202511120334__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of keras-rs-nightly might be problematic.

Files changed (63)
  1. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/PKG-INFO +1 -1
  2. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/api/losses/__init__.py +1 -0
  3. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/jax/distributed_embedding.py +62 -21
  4. keras_rs_nightly-0.3.1.dev202511120334/keras_rs/src/layers/embedding/jax/embedding_utils.py +244 -0
  5. keras_rs_nightly-0.3.1.dev202511120334/keras_rs/src/losses/list_mle_loss.py +212 -0
  6. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/ranking_metrics_utils.py +19 -0
  7. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/version.py +1 -1
  8. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs_nightly.egg-info/PKG-INFO +1 -1
  9. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs_nightly.egg-info/SOURCES.txt +1 -0
  10. keras_rs_nightly-0.3.1.dev202510100326/keras_rs/src/layers/embedding/jax/embedding_utils.py +0 -535
  11. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/README.md +0 -0
  12. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/api/__init__.py +0 -0
  13. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/api/layers/__init__.py +0 -0
  14. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/api/metrics/__init__.py +0 -0
  15. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/__init__.py +0 -0
  16. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/api_export.py +0 -0
  17. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/__init__.py +0 -0
  18. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/__init__.py +0 -0
  19. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/base_distributed_embedding.py +0 -0
  20. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/distributed_embedding.py +0 -0
  21. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/distributed_embedding_config.py +0 -0
  22. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/embed_reduce.py +0 -0
  23. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  24. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/jax/checkpoint_utils.py +0 -0
  25. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/jax/config_conversion.py +0 -0
  26. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/jax/embedding_lookup.py +0 -0
  27. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  28. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/tensorflow/config_conversion.py +0 -0
  29. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +0 -0
  30. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/feature_interaction/__init__.py +0 -0
  31. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/feature_interaction/dot_interaction.py +0 -0
  32. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/feature_interaction/feature_cross.py +0 -0
  33. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/__init__.py +0 -0
  34. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/brute_force_retrieval.py +0 -0
  35. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/hard_negative_mining.py +0 -0
  36. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/remove_accidental_hits.py +0 -0
  37. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/retrieval.py +0 -0
  38. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/sampling_probability_correction.py +0 -0
  39. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/__init__.py +0 -0
  40. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_hinge_loss.py +0 -0
  41. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_logistic_loss.py +0 -0
  42. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_loss.py +0 -0
  43. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_loss_utils.py +0 -0
  44. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_mean_squared_error.py +0 -0
  45. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_soft_zero_one_loss.py +0 -0
  46. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/__init__.py +0 -0
  47. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/dcg.py +0 -0
  48. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/mean_average_precision.py +0 -0
  49. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/mean_reciprocal_rank.py +0 -0
  50. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/ndcg.py +0 -0
  51. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/precision_at_k.py +0 -0
  52. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/ranking_metric.py +0 -0
  53. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/recall_at_k.py +0 -0
  54. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/utils.py +0 -0
  55. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/types.py +0 -0
  56. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/utils/__init__.py +0 -0
  57. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/utils/doc_string_utils.py +0 -0
  58. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/utils/keras_utils.py +0 -0
  59. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs_nightly.egg-info/dependency_links.txt +0 -0
  60. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs_nightly.egg-info/requires.txt +0 -0
  61. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs_nightly.egg-info/top_level.txt +0 -0
  62. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/pyproject.toml +0 -0
  63. {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/setup.cfg +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: keras-rs-nightly
-Version: 0.3.1.dev202510100326
+Version: 0.3.1.dev202511120334
 Summary: Multi-backend recommender systems with Keras 3.
 Author-email: Keras team <keras-users@googlegroups.com>
 License: Apache License 2.0
keras_rs/api/losses/__init__.py
@@ -4,6 +4,7 @@ This file was autogenerated. Do not edit it by hand,
 since your modifications would be overwritten.
 """
 
+from keras_rs.src.losses.list_mle_loss import ListMLELoss as ListMLELoss
 from keras_rs.src.losses.pairwise_hinge_loss import (
     PairwiseHingeLoss as PairwiseHingeLoss,
 )
keras_rs/src/layers/embedding/jax/distributed_embedding.py
@@ -9,6 +9,7 @@ import keras
 import numpy as np
 from jax import numpy as jnp
 from jax.experimental import layout as jax_layout
+from jax.experimental import multihost_utils
 from jax_tpu_embedding.sparsecore.lib.nn import embedding
 from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
 from jax_tpu_embedding.sparsecore.lib.nn import (
@@ -442,7 +443,50 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
 
         # Collect all stacked tables.
         table_specs = embedding.get_table_specs(feature_specs)
-        table_stacks = embedding_utils.get_table_stacks(table_specs)
+        table_stacks = jte_table_stacking.get_table_stacks(table_specs)
+
+        # Update stacked table stats to max of values across involved tables.
+        max_ids_per_partition = {}
+        max_unique_ids_per_partition = {}
+        required_buffer_size_per_device = {}
+        id_drop_counters = {}
+        for stack_name, stack in table_stacks.items():
+            max_ids_per_partition[stack_name] = np.max(
+                np.asarray(
+                    [s.max_ids_per_partition for s in stack], dtype=np.int32
+                )
+            )
+            max_unique_ids_per_partition[stack_name] = np.max(
+                np.asarray(
+                    [s.max_unique_ids_per_partition for s in stack],
+                    dtype=np.int32,
+                )
+            )
+
+            # Only set the suggested buffer size if set on any individual table.
+            valid_buffer_sizes = [
+                s.suggested_coo_buffer_size_per_device
+                for s in stack
+                if s.suggested_coo_buffer_size_per_device is not None
+            ]
+            if valid_buffer_sizes:
+                required_buffer_size_per_device[stack_name] = np.max(
+                    np.asarray(valid_buffer_sizes, dtype=np.int32)
+                )
+
+            id_drop_counters[stack_name] = 0
+
+        aggregated_stats = embedding.SparseDenseMatmulInputStats(
+            max_ids_per_partition=max_ids_per_partition,
+            max_unique_ids_per_partition=max_unique_ids_per_partition,
+            required_buffer_size_per_sc=required_buffer_size_per_device,
+            id_drop_counters=id_drop_counters,
+        )
+        embedding.update_preprocessing_parameters(
+            feature_specs,
+            aggregated_stats,
+            num_sc_per_device,
+        )
 
         # Create variables for all stacked tables and slot variables.
         with sparsecore_distribution.scope():
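The added block widens each stacked table's preprocessing limits to the maximum across the stack's member tables, so the shared physical table is provisioned for its most demanding member. A minimal sketch of that aggregation rule, using a hypothetical `ToyTableSpec` stand-in rather than the real `jax_tpu_embedding` spec classes:

```python
# Sketch only: `ToyTableSpec` is a hypothetical stand-in for the real
# jax_tpu_embedding table specs; the max-aggregation mirrors the diff above.
from typing import NamedTuple

import numpy as np


class ToyTableSpec(NamedTuple):
    max_ids_per_partition: int
    max_unique_ids_per_partition: int
    suggested_coo_buffer_size_per_device: int | None


table_stacks = {
    "stack_0": [ToyTableSpec(64, 32, None), ToyTableSpec(128, 16, 4096)],
}

for stack_name, stack in table_stacks.items():
    max_ids = np.max([s.max_ids_per_partition for s in stack])
    max_unique = np.max([s.max_unique_ids_per_partition for s in stack])
    # The buffer size is only propagated if some member table sets one.
    sizes = [
        s.suggested_coo_buffer_size_per_device
        for s in stack
        if s.suggested_coo_buffer_size_per_device is not None
    ]
    buffer_size = np.max(sizes) if sizes else None
    print(stack_name, max_ids, max_unique, buffer_size)  # stack_0 128 32 4096
```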
@@ -516,7 +560,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
 
         # Each stacked-table gets a ShardedCooMatrix.
         table_specs = embedding.get_table_specs(self._config.feature_specs)
-        table_stacks = embedding_utils.get_table_stacks(table_specs)
+        table_stacks = jte_table_stacking.get_table_stacks(table_specs)
         stacked_table_specs = {
             stack_name: stack[0].stacked_table_spec
             for stack_name, stack in table_stacks.items()
@@ -600,31 +644,26 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         # underlying stacked tables specs in the feature specs.
 
         # Aggregate stats across all processes/devices via pmax.
-        num_local_cpu_devices = jax.local_device_count("cpu")
-
-        def pmax_aggregate(x: Any) -> Any:
-            if not hasattr(x, "ndim"):
-                x = np.array(x)
-            tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim)))
-            return jax.pmap(
-                lambda y: jax.lax.pmax(y, "all_cpus"),  # type: ignore[no-untyped-call]
-                axis_name="all_cpus",
-                backend="cpu",
-            )(tiled_x)[0]
-
-        full_stats = jax.tree.map(pmax_aggregate, stats)
+        all_stats = multihost_utils.process_allgather(stats)
+        aggregated_stats = jax.tree.map(
+            lambda x: jnp.max(x, axis=0), all_stats
+        )
 
         # Check if stats changed enough to warrant action.
         stacked_table_specs = embedding.get_stacked_table_specs(
             self._config.feature_specs
         )
         changed = any(
-            np.max(full_stats.max_ids_per_partition[stack_name])
+            np.max(aggregated_stats.max_ids_per_partition[stack_name])
            > spec.max_ids_per_partition
-            or np.max(full_stats.max_unique_ids_per_partition[stack_name])
+            or np.max(
+                aggregated_stats.max_unique_ids_per_partition[stack_name]
+            )
            > spec.max_unique_ids_per_partition
            or (
-                np.max(full_stats.required_buffer_size_per_sc[stack_name])
+                np.max(
+                    aggregated_stats.required_buffer_size_per_sc[stack_name]
+                )
                * num_sc_per_device
            )
            > (spec.suggested_coo_buffer_size_per_device or 0)
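The replacement drops the CPU `pmap`/`pmax` round-trip in favor of `multihost_utils.process_allgather`, which gathers every pytree leaf across processes under a new leading axis, after which a plain `max` over axis 0 yields the global maximum. A hedged single-host sketch of the pattern (with one process the gathered axis simply has size 1):

```python
import jax
import jax.numpy as jnp
from jax.experimental import multihost_utils

# Toy per-process stats pytree; in a real multi-host job each process
# would contribute different locally measured values.
stats = {"max_ids": jnp.array([7, 3]), "buffer_size": jnp.array(512)}

# Each leaf gains a leading axis of size jax.process_count()
# (1 when running this sketch on a single host).
all_stats = multihost_utils.process_allgather(stats)

# Reduce the process axis to get the elementwise global maximum.
aggregated = jax.tree.map(lambda x: jnp.max(x, axis=0), all_stats)
print(aggregated["max_ids"])  # [7 3]
```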
@@ -634,7 +673,9 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         # Update configuration and repeat preprocessing if stats changed.
         if changed:
             embedding.update_preprocessing_parameters(
-                self._config.feature_specs, full_stats, num_sc_per_device
+                self._config.feature_specs,
+                aggregated_stats,
+                num_sc_per_device,
             )
 
             # Re-execute preprocessing with consistent input statistics.
@@ -720,7 +761,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         config = self._config
         num_table_shards = config.mesh.devices.size * config.num_sc_per_device
         table_specs = embedding.get_table_specs(config.feature_specs)
-        sharded_tables = embedding_utils.stack_and_shard_tables(
+        sharded_tables = jte_table_stacking.stack_and_shard_tables(
             table_specs,
             tables,
             num_table_shards,
@@ -763,7 +804,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
 
         return typing.cast(
             dict[str, ArrayLike],
-            embedding_utils.unshard_and_unstack_tables(
+            jte_table_stacking.unshard_and_unstack_tables(
                 table_specs, table_variables, num_table_shards
             ),
         )
keras_rs/src/layers/embedding/jax/embedding_utils.py (new file)
@@ -0,0 +1,244 @@
+"""Utility functions for manipulating JAX embedding tables and inputs."""
+
+import collections
+from typing import Any, Mapping, NamedTuple, Sequence, TypeAlias, TypeVar
+
+import jax
+import numpy as np
+from jax_tpu_embedding.sparsecore.lib.nn import embedding
+from jax_tpu_embedding.sparsecore.lib.nn import table_stacking
+from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import FeatureSpec
+
+from keras_rs.src.types import Nested
+
+T = TypeVar("T")
+
+# Any to support tf.Ragged without needing an explicit TF dependency.
+ArrayLike: TypeAlias = jax.Array | np.ndarray | Any  # type: ignore
+Shape: TypeAlias = tuple[int, ...]
+
+
+class FeatureSamples(NamedTuple):
+    tokens: ArrayLike
+    weights: ArrayLike
+
+
+class ShardedCooMatrix(NamedTuple):
+    shard_starts: ArrayLike
+    shard_ends: ArrayLike
+    col_ids: ArrayLike
+    row_ids: ArrayLike
+    values: ArrayLike
+
+
+def convert_to_numpy(
+    ragged_or_dense: np.ndarray[Any, Any] | Sequence[Sequence[Any]] | Any,
+    dtype: Any,
+) -> np.ndarray[Any, Any]:
+    """Converts a ragged or dense list of inputs to a ragged/dense numpy array.
+
+    The output is adjusted to be 2D.
+
+    Args:
+        ragged_or_dense: Input that is either already a numpy array, or a
+            nested sequence.
+        dtype: Numpy dtype of the output array.
+
+    Returns:
+        Corresponding numpy array.
+    """
+    if hasattr(ragged_or_dense, "numpy"):
+        # Support tf.RaggedTensor and other TF input dtypes.
+        if callable(getattr(ragged_or_dense, "numpy")):
+            ragged_or_dense = ragged_or_dense.numpy()
+
+    if isinstance(ragged_or_dense, jax.Array):
+        ragged_or_dense = np.asarray(ragged_or_dense)
+
+    if isinstance(ragged_or_dense, np.ndarray):
+        # Convert 1D to 2D.
+        if ragged_or_dense.dtype != np.ndarray and ragged_or_dense.ndim == 1:
+            return ragged_or_dense.reshape(-1, 1).astype(dtype)
+
+        # If dense, return converted dense type.
+        if ragged_or_dense.dtype != np.ndarray:
+            return ragged_or_dense.astype(dtype)
+
+        # Ragged numpy array.
+        return ragged_or_dense
+
+    # Handle 1D sequence input.
+    if not isinstance(ragged_or_dense[0], collections.abc.Sequence):
+        return np.asarray(ragged_or_dense, dtype=dtype).reshape(-1, 1)
+
+    # Assemble elements into an nd-array.
+    counts = [len(vals) for vals in ragged_or_dense]
+    if all([count == counts[0] for count in counts]):
+        # Dense input.
+        return np.asarray(ragged_or_dense, dtype=dtype)
+    else:
+        # Ragged input, convert to ragged numpy arrays.
+        return np.array(
+            [np.array(row, dtype=dtype) for row in ragged_or_dense],
+            dtype=np.ndarray,
+        )
+
+
+def ones_like(
+    ragged_or_dense: np.ndarray[Any, Any], dtype: Any = None
+) -> np.ndarray[Any, Any]:
+    """Creates an array of ones with the same shape as the input.
+
+    This differs from traditional numpy in that a ragged input will lead to
+    a resulting ragged array of ones, whereas np.ones_like(...) will instead
+    only consider the outer array and return a 1D dense array of ones.
+
+    Args:
+        ragged_or_dense: The ragged or dense input whose shape and data-type
+            define these same attributes of the returned array.
+        dtype: The data-type of the returned array.
+
+    Returns:
+        An array of ones with the same shape as the input, and specified data
+        type.
+    """
+    dtype = dtype or ragged_or_dense.dtype
+    if ragged_or_dense.dtype == np.ndarray:
+        # Ragged.
+        return np.array(
+            [np.ones_like(row, dtype=dtype) for row in ragged_or_dense],
+            dtype=np.ndarray,
+        )
+    else:
+        # Dense.
+        return np.ones_like(ragged_or_dense, dtype=dtype)
+
+
+def create_feature_samples(
+    feature_structure: Nested[T],
+    feature_ids: Nested[ArrayLike | Sequence[int] | Sequence[Sequence[int]]],
+    feature_weights: None
+    | (Nested[ArrayLike | Sequence[float] | Sequence[Sequence[float]]]),
+) -> Nested[FeatureSamples]:
+    """Constructs a collection of sample tuples from provided IDs and weights.
+
+    Args:
+        feature_structure: The nested structure of the inputs (typically
+            `FeatureSpec`s).
+        feature_ids: The feature IDs to use for the samples.
+        feature_weights: The feature weights to use for the samples. Defaults
+            to ones if not provided.
+
+    Returns:
+        A nested collection of `FeatureSamples` corresponding to the input IDs
+        and weights, for use in embedding lookups.
+    """
+    # Create numpy arrays from inputs.
+    feature_ids = jax.tree.map(
+        lambda _, ids: convert_to_numpy(ids, np.int32),
+        feature_structure,
+        feature_ids,
+    )
+
+    if feature_weights is None:
+        # Make ragged or dense ones_like.
+        feature_weights = jax.tree.map(
+            lambda _, ids: ones_like(ids, np.float32),
+            feature_structure,
+            feature_ids,
+        )
+    else:
+        feature_weights = jax.tree.map(
+            lambda _, wgts: convert_to_numpy(wgts, np.float32),
+            feature_structure,
+            feature_weights,
+        )
+
+    # Assemble.
+    def _create_feature_samples(
+        sample_ids: np.ndarray[Any, Any],
+        sample_weights: np.ndarray[Any, Any],
+    ) -> FeatureSamples:
+        return FeatureSamples(sample_ids, sample_weights)
+
+    output: Nested[FeatureSamples] = jax.tree.map(
+        lambda _, sample_ids, sample_weights: _create_feature_samples(
+            sample_ids, sample_weights
+        ),
+        feature_structure,
+        feature_ids,
+        feature_weights,
+    )
+    return output
+
+
+def stack_and_shard_samples(
+    feature_specs: Nested[FeatureSpec],
+    feature_samples: Nested[FeatureSamples],
+    local_device_count: int,
+    global_device_count: int,
+    num_sc_per_device: int,
+    static_buffer_size: int | Mapping[str, int] | None = None,
+) -> tuple[dict[str, ShardedCooMatrix], embedding.SparseDenseMatmulInputStats]:
+    """Prepares input samples for use in embedding lookups.
+
+    Args:
+        feature_specs: Nested collection of feature specifications.
+        feature_samples: Nested collection of feature samples.
+        local_device_count: Number of local JAX devices.
+        global_device_count: Number of global JAX devices.
+        num_sc_per_device: Number of sparsecores per device.
+        static_buffer_size: The static buffer size to use for the samples.
+            Defaults to None, in which case an upper-bound for the buffer size
+            will be automatically determined.
+
+    Returns:
+        The preprocessed inputs, and statistics useful for updating
+        FeatureSpecs based on the provided input data.
+    """
+    del static_buffer_size  # Currently ignored.
+    flat_feature_specs, _ = jax.tree.flatten(feature_specs)
+
+    feature_tokens = []
+    feature_weights = []
+
+    def collect_tokens_and_weights(
+        feature_spec: FeatureSpec, samples: FeatureSamples
+    ) -> None:
+        del feature_spec
+        feature_tokens.append(samples.tokens)
+        feature_weights.append(samples.weights)
+
+    jax.tree.map(collect_tokens_and_weights, feature_specs, feature_samples)
+
+    preprocessed_inputs, stats = embedding.preprocess_sparse_dense_matmul_input(
+        feature_tokens,
+        feature_weights,
+        flat_feature_specs,
+        local_device_count=local_device_count,
+        global_device_count=global_device_count,
+        num_sc_per_device=num_sc_per_device,
+        sharding_strategy="MOD",
+        has_leading_dimension=False,
+        allow_id_dropping=True,
+    )
+
+    out: dict[str, ShardedCooMatrix] = {}
+    tables_names = preprocessed_inputs.lhs_row_pointers.keys()
+    for table_name in tables_names:
+        shard_ends = preprocessed_inputs.lhs_row_pointers[table_name]
+        shard_starts = np.concatenate(
+            [
+                np.asarray([0]),
+                table_stacking._next_largest_multiple(shard_ends[:-1], 8),
+            ]
+        )
+        out[table_name] = ShardedCooMatrix(
+            shard_starts=shard_starts,
+            shard_ends=shard_ends,
+            col_ids=preprocessed_inputs.lhs_embedding_ids[table_name],
+            row_ids=preprocessed_inputs.lhs_sample_ids[table_name],
+            values=preprocessed_inputs.lhs_gains[table_name],
+        )
+
+    return out, stats
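Since `convert_to_numpy` and `ones_like` are defined in full above, their ragged-versus-dense behavior can be exercised directly; a short usage sketch:

```python
import numpy as np

from keras_rs.src.layers.embedding.jax.embedding_utils import (
    convert_to_numpy,
    ones_like,
)

# A dense 1D input is promoted to a 2D column of shape (3, 1).
dense = convert_to_numpy([1, 2, 3], np.int32)
print(dense.shape)  # (3, 1)

# Rows of unequal length become an object array of per-row numpy arrays.
ragged = convert_to_numpy([[1, 2], [3]], np.int32)
print(ragged.dtype)  # object

# Unlike np.ones_like on an object array, this preserves raggedness.
weights = ones_like(ragged, np.float32)
print([w.tolist() for w in weights])  # [[1.0, 1.0], [1.0]]
```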
keras_rs/src/losses/list_mle_loss.py (new file)
@@ -0,0 +1,212 @@
+from typing import Any
+
+import keras
+from keras import ops
+
+from keras_rs.src import types
+from keras_rs.src.api_export import keras_rs_export
+from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
+from keras_rs.src.metrics.utils import standardize_call_inputs_ranks
+
+
+@keras_rs_export("keras_rs.losses.ListMLELoss")
+class ListMLELoss(keras.losses.Loss):
+    """Implements ListMLE (Maximum Likelihood Estimation) loss for ranking.
+
+    ListMLE loss is a listwise ranking loss that maximizes the likelihood of
+    the ground truth ranking. It works by:
+    1. Sorting items by their relevance scores (labels).
+    2. Computing the probability of observing this ranking given the
+       predicted scores.
+    3. Maximizing this likelihood (minimizing the negative log-likelihood).
+
+    The loss is computed as the negative log-likelihood of the ground truth
+    ranking given the predicted scores:
+
+    ```
+    loss = -sum(log(exp(s_i) / sum(exp(s_j) for j >= i)))
+    ```
+
+    where `s_i` is the predicted score for item `i` in the sorted order.
+
+    Args:
+        temperature: Temperature parameter for scaling logits. Higher values
+            make the probability distribution more uniform. Defaults to 1.0.
+        reduction: Type of reduction to apply to the loss. In almost all cases
+            this should be `"sum_over_batch_size"`. Supported options are
+            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
+            `"mean_with_sample_weight"` or `None`. Defaults to
+            `"sum_over_batch_size"`.
+        name: Optional name for the loss instance.
+        dtype: The dtype of the loss's computations. Defaults to `None`.
+
+    Examples:
+    ```python
+    # Basic usage.
+    loss_fn = ListMLELoss()
+
+    # With temperature scaling.
+    loss_fn = ListMLELoss(temperature=0.5)
+
+    # Example with synthetic data.
+    y_true = [[3, 2, 1, 0]]  # Relevance scores.
+    y_pred = [[0.8, 0.6, 0.4, 0.2]]  # Predicted scores.
+    loss = loss_fn(y_true, y_pred)
+    ```
+    """
+
+    def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+
+        if temperature <= 0.0:
+            raise ValueError(
+                f"`temperature` should be a positive float. Received: "
+                f"`temperature` = {temperature}."
+            )
+
+        self.temperature = temperature
+        self._epsilon = 1e-10
+
+    def compute_unreduced_loss(
+        self,
+        labels: types.Tensor,
+        logits: types.Tensor,
+        mask: types.Tensor | None = None,
+    ) -> tuple[types.Tensor, types.Tensor]:
+        """Computes the unreduced ListMLE loss.
+
+        Args:
+            labels: Ground truth relevance scores of shape
+                `[batch_size, list_size]`.
+            logits: Predicted scores of shape `[batch_size, list_size]`.
+            mask: Optional mask of shape `[batch_size, list_size]`.
+
+        Returns:
+            Tuple of `(losses, weights)` where `losses` has shape
+            `[batch_size, 1]` and `weights` has the same shape.
+        """
+        valid_mask = ops.greater_equal(labels, ops.cast(0.0, labels.dtype))
+
+        if mask is not None:
+            valid_mask = ops.logical_and(
+                valid_mask, ops.cast(mask, dtype="bool")
+            )
+
+        num_valid_items = ops.sum(
+            ops.cast(valid_mask, dtype=labels.dtype), axis=1, keepdims=True
+        )
+
+        batch_has_valid_items = ops.greater(num_valid_items, 0.0)
+
+        labels_for_sorting = ops.where(
+            valid_mask, labels, ops.full_like(labels, -1e9)
+        )
+        logits_masked = ops.where(
+            valid_mask, logits, ops.full_like(logits, -1e9)
+        )
+
+        sorted_logits, sorted_valid_mask = sort_by_scores(
+            tensors_to_sort=[logits_masked, valid_mask],
+            scores=labels_for_sorting,
+            mask=None,
+            shuffle_ties=False,
+            seed=None,
+        )
+        sorted_logits = ops.divide(
+            sorted_logits, ops.cast(self.temperature, dtype=sorted_logits.dtype)
+        )
+
+        valid_logits_for_max = ops.where(
+            sorted_valid_mask, sorted_logits, ops.full_like(sorted_logits, -1e9)
+        )
+        raw_max = ops.max(valid_logits_for_max, axis=1, keepdims=True)
+        raw_max = ops.where(
+            batch_has_valid_items, raw_max, ops.zeros_like(raw_max)
+        )
+        sorted_logits = ops.subtract(sorted_logits, raw_max)
+
+        # Set invalid positions to very negative BEFORE exp.
+        sorted_logits = ops.where(
+            sorted_valid_mask, sorted_logits, ops.full_like(sorted_logits, -1e9)
+        )
+        exp_logits = ops.exp(sorted_logits)
+
+        reversed_exp = ops.flip(exp_logits, axis=1)
+        reversed_cumsum = ops.cumsum(reversed_exp, axis=1)
+        cumsum_from_right = ops.flip(reversed_cumsum, axis=1)
+
+        log_normalizers = ops.log(cumsum_from_right + self._epsilon)
+        log_probs = ops.subtract(sorted_logits, log_normalizers)
+
+        log_probs = ops.where(
+            sorted_valid_mask, log_probs, ops.zeros_like(log_probs)
+        )
+
+        negative_log_likelihood = ops.negative(
+            ops.sum(log_probs, axis=1, keepdims=True)
+        )
+
+        negative_log_likelihood = ops.where(
+            batch_has_valid_items,
+            negative_log_likelihood,
+            ops.zeros_like(negative_log_likelihood),
+        )
+
+        weights = ops.ones_like(negative_log_likelihood)
+
+        return negative_log_likelihood, weights
+
+    def call(
+        self,
+        y_true: types.Tensor,
+        y_pred: types.Tensor,
+    ) -> types.Tensor:
+        """Computes the ListMLE loss.
+
+        Args:
+            y_true: tensor or dict. Ground truth values. If tensor, of shape
+                `(list_size)` for unbatched inputs or `(batch_size, list_size)`
+                for batched inputs. If an item has a label of -1, it is ignored
+                in loss computation. If it is a dictionary, it should have two
+                keys: `"labels"` and `"mask"`. `"mask"` can be used to ignore
+                elements in loss computation.
+            y_pred: tensor. The predicted values, of shape `(list_size)` for
+                unbatched inputs or `(batch_size, list_size)` for batched
+                inputs. Should be of the same shape as `y_true`.
+
+        Returns:
+            The loss tensor of shape `[batch_size]`.
+        """
+        mask = None
+        if isinstance(y_true, dict):
+            if "labels" not in y_true:
+                raise ValueError(
+                    '`"labels"` should be present in `y_true`. Received: '
+                    f"`y_true` = {y_true}"
+                )
+
+            mask = y_true.get("mask", None)
+            y_true = y_true["labels"]
+
+        y_true = ops.convert_to_tensor(y_true)
+        y_pred = ops.convert_to_tensor(y_pred)
+        if mask is not None:
+            mask = ops.convert_to_tensor(mask)
+
+        y_true, y_pred, mask, _ = standardize_call_inputs_ranks(
+            y_true, y_pred, mask
+        )
+
+        losses, weights = self.compute_unreduced_loss(
+            labels=y_true, logits=y_pred, mask=mask
+        )
+        losses = ops.multiply(losses, weights)
+        losses = ops.squeeze(losses, axis=-1)
+        return losses
+
+    def get_config(self) -> dict[str, Any]:
+        config: dict[str, Any] = super().get_config()
+        config.update({"temperature": self.temperature})
+        return config
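The docstring formula can be sanity-checked against its own example values with plain NumPy. The sketch below evaluates `-sum(log(exp(s_i) / sum(exp(s_j) for j >= i)))` on the sorted scores; the class itself additionally applies temperature scaling, masking, and a max-subtraction for numerical stability, none of which change this result at `temperature=1.0`:

```python
import numpy as np

# y_true = [3, 2, 1, 0] already ranks the items first-to-last, so sorting
# the predictions by label keeps them in their original order.
s = np.array([0.8, 0.6, 0.4, 0.2])

# log of the suffix sum of exp(s_j) for j >= i, computed right-to-left.
log_normalizers = np.log(np.cumsum(np.exp(s[::-1]))[::-1])
loss = -np.sum(s - log_normalizers)
print(round(loss, 3))  # ~2.621
```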
keras_rs/src/metrics/ranking_metrics_utils.py
@@ -85,6 +85,25 @@ def sort_by_scores(
     else:
         k = ops.minimum(k, max_possible_k)
 
+    # --- Workaround for PyTorch instability ---
+    # Torch's `topk` is not stable with `sorted=True`, unlike JAX and TF.
+    # See:
+    # - https://github.com/pytorch/pytorch/issues/27542
+    # - https://github.com/pytorch/pytorch/issues/88227
+    #
+    # This small "stable offset" ensures deterministic tie-breaking for
+    # equal scores. We can remove this workaround once PyTorch adds a
+    # `stable=True` flag for topk.
+
+    if keras.backend.backend() == "torch" and not shuffle_ties:
+        list_size = ops.shape(scores)[1]
+        indices = ops.arange(list_size)
+        indices = ops.expand_dims(indices, axis=0)
+        indices = ops.broadcast_to(indices, ops.shape(scores))
+        stable_offset = ops.cast(indices, scores.dtype) * 1e-6
+        scores = ops.subtract(scores, stable_offset)
+    # --- End workaround ---
+
     # Shuffle ties randomly, and push masked values to the beginning.
     shuffled_indices = None
     if shuffle_ties or mask is not None:
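The tie-breaking trick can be checked in isolation with PyTorch directly; this is an illustrative sketch of the idea, not the keras-rs code path. Subtracting an index-proportional offset far smaller than any real score gap makes `topk` deterministic for tied scores:

```python
import torch

scores = torch.tensor([[1.0, 1.0, 1.0, 0.5]])  # three tied scores

# The offset grows with position, so earlier indices win ties.
offset = torch.arange(scores.shape[1], dtype=scores.dtype) * 1e-6
_, idx = torch.topk(scores - offset, k=3, dim=1, sorted=True)
print(idx)  # tensor([[0, 1, 2]]) -- deterministic tie order
```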
keras_rs/src/version.py
@@ -1,7 +1,7 @@
 from keras_rs.src.api_export import keras_rs_export
 
 # Unique source of truth for the version number.
-__version__ = "0.3.1.dev202510100326"
+__version__ = "0.3.1.dev202511120334"
 
 
 @keras_rs_export("keras_rs.version")
keras_rs_nightly.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: keras-rs-nightly
-Version: 0.3.1.dev202510100326
+Version: 0.3.1.dev202511120334
 Summary: Multi-backend recommender systems with Keras 3.
 Author-email: Keras team <keras-users@googlegroups.com>
 License: Apache License 2.0
keras_rs_nightly.egg-info/SOURCES.txt
@@ -33,6 +33,7 @@ keras_rs/src/layers/retrieval/remove_accidental_hits.py
 keras_rs/src/layers/retrieval/retrieval.py
 keras_rs/src/layers/retrieval/sampling_probability_correction.py
 keras_rs/src/losses/__init__.py
+keras_rs/src/losses/list_mle_loss.py
 keras_rs/src/losses/pairwise_hinge_loss.py
 keras_rs/src/losses/pairwise_logistic_loss.py
 keras_rs/src/losses/pairwise_loss.py