keras-rs-nightly 0.3.1.dev202510100326__tar.gz → 0.3.1.dev202511120334__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of keras-rs-nightly might be problematic. Click here for more details.
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/PKG-INFO +1 -1
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/api/losses/__init__.py +1 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/jax/distributed_embedding.py +62 -21
- keras_rs_nightly-0.3.1.dev202511120334/keras_rs/src/layers/embedding/jax/embedding_utils.py +244 -0
- keras_rs_nightly-0.3.1.dev202511120334/keras_rs/src/losses/list_mle_loss.py +212 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/ranking_metrics_utils.py +19 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs_nightly.egg-info/PKG-INFO +1 -1
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs_nightly.egg-info/SOURCES.txt +1 -0
- keras_rs_nightly-0.3.1.dev202510100326/keras_rs/src/layers/embedding/jax/embedding_utils.py +0 -535
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/README.md +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/api/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/api/layers/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/api/metrics/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/api_export.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/base_distributed_embedding.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/distributed_embedding.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/distributed_embedding_config.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/embed_reduce.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/jax/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/jax/checkpoint_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/jax/config_conversion.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/jax/embedding_lookup.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/tensorflow/config_conversion.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/feature_interaction/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/feature_interaction/dot_interaction.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/feature_interaction/feature_cross.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/brute_force_retrieval.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/hard_negative_mining.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/remove_accidental_hits.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/retrieval.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/sampling_probability_correction.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_hinge_loss.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_logistic_loss.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_loss.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_loss_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_mean_squared_error.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_soft_zero_one_loss.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/dcg.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/mean_average_precision.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/mean_reciprocal_rank.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/ndcg.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/precision_at_k.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/ranking_metric.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/recall_at_k.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/types.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/utils/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/utils/doc_string_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/utils/keras_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs_nightly.egg-info/dependency_links.txt +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs_nightly.egg-info/requires.txt +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs_nightly.egg-info/top_level.txt +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/pyproject.toml +0 -0
- {keras_rs_nightly-0.3.1.dev202510100326 → keras_rs_nightly-0.3.1.dev202511120334}/setup.cfg +0 -0
|
@@ -4,6 +4,7 @@ This file was autogenerated. Do not edit it by hand,
|
|
|
4
4
|
since your modifications would be overwritten.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
+
from keras_rs.src.losses.list_mle_loss import ListMLELoss as ListMLELoss
|
|
7
8
|
from keras_rs.src.losses.pairwise_hinge_loss import (
|
|
8
9
|
PairwiseHingeLoss as PairwiseHingeLoss,
|
|
9
10
|
)
|
|
@@ -9,6 +9,7 @@ import keras
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
from jax import numpy as jnp
|
|
11
11
|
from jax.experimental import layout as jax_layout
|
|
12
|
+
from jax.experimental import multihost_utils
|
|
12
13
|
from jax_tpu_embedding.sparsecore.lib.nn import embedding
|
|
13
14
|
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
|
|
14
15
|
from jax_tpu_embedding.sparsecore.lib.nn import (
|
|
@@ -442,7 +443,50 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
|
|
|
442
443
|
|
|
443
444
|
# Collect all stacked tables.
|
|
444
445
|
table_specs = embedding.get_table_specs(feature_specs)
|
|
445
|
-
table_stacks =
|
|
446
|
+
table_stacks = jte_table_stacking.get_table_stacks(table_specs)
|
|
447
|
+
|
|
448
|
+
# Update stacked table stats to max of values across involved tables.
|
|
449
|
+
max_ids_per_partition = {}
|
|
450
|
+
max_unique_ids_per_partition = {}
|
|
451
|
+
required_buffer_size_per_device = {}
|
|
452
|
+
id_drop_counters = {}
|
|
453
|
+
for stack_name, stack in table_stacks.items():
|
|
454
|
+
max_ids_per_partition[stack_name] = np.max(
|
|
455
|
+
np.asarray(
|
|
456
|
+
[s.max_ids_per_partition for s in stack], dtype=np.int32
|
|
457
|
+
)
|
|
458
|
+
)
|
|
459
|
+
max_unique_ids_per_partition[stack_name] = np.max(
|
|
460
|
+
np.asarray(
|
|
461
|
+
[s.max_unique_ids_per_partition for s in stack],
|
|
462
|
+
dtype=np.int32,
|
|
463
|
+
)
|
|
464
|
+
)
|
|
465
|
+
|
|
466
|
+
# Only set the suggested buffer size if set on any individual table.
|
|
467
|
+
valid_buffer_sizes = [
|
|
468
|
+
s.suggested_coo_buffer_size_per_device
|
|
469
|
+
for s in stack
|
|
470
|
+
if s.suggested_coo_buffer_size_per_device is not None
|
|
471
|
+
]
|
|
472
|
+
if valid_buffer_sizes:
|
|
473
|
+
required_buffer_size_per_device[stack_name] = np.max(
|
|
474
|
+
np.asarray(valid_buffer_sizes, dtype=np.int32)
|
|
475
|
+
)
|
|
476
|
+
|
|
477
|
+
id_drop_counters[stack_name] = 0
|
|
478
|
+
|
|
479
|
+
aggregated_stats = embedding.SparseDenseMatmulInputStats(
|
|
480
|
+
max_ids_per_partition=max_ids_per_partition,
|
|
481
|
+
max_unique_ids_per_partition=max_unique_ids_per_partition,
|
|
482
|
+
required_buffer_size_per_sc=required_buffer_size_per_device,
|
|
483
|
+
id_drop_counters=id_drop_counters,
|
|
484
|
+
)
|
|
485
|
+
embedding.update_preprocessing_parameters(
|
|
486
|
+
feature_specs,
|
|
487
|
+
aggregated_stats,
|
|
488
|
+
num_sc_per_device,
|
|
489
|
+
)
|
|
446
490
|
|
|
447
491
|
# Create variables for all stacked tables and slot variables.
|
|
448
492
|
with sparsecore_distribution.scope():
|
|
@@ -516,7 +560,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
|
|
|
516
560
|
|
|
517
561
|
# Each stacked-table gets a ShardedCooMatrix.
|
|
518
562
|
table_specs = embedding.get_table_specs(self._config.feature_specs)
|
|
519
|
-
table_stacks =
|
|
563
|
+
table_stacks = jte_table_stacking.get_table_stacks(table_specs)
|
|
520
564
|
stacked_table_specs = {
|
|
521
565
|
stack_name: stack[0].stacked_table_spec
|
|
522
566
|
for stack_name, stack in table_stacks.items()
|
|
@@ -600,31 +644,26 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
|
|
|
600
644
|
# underlying stacked tables specs in the feature specs.
|
|
601
645
|
|
|
602
646
|
# Aggregate stats across all processes/devices via pmax.
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
x = np.array(x)
|
|
608
|
-
tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim)))
|
|
609
|
-
return jax.pmap(
|
|
610
|
-
lambda y: jax.lax.pmax(y, "all_cpus"), # type: ignore[no-untyped-call]
|
|
611
|
-
axis_name="all_cpus",
|
|
612
|
-
backend="cpu",
|
|
613
|
-
)(tiled_x)[0]
|
|
614
|
-
|
|
615
|
-
full_stats = jax.tree.map(pmax_aggregate, stats)
|
|
647
|
+
all_stats = multihost_utils.process_allgather(stats)
|
|
648
|
+
aggregated_stats = jax.tree.map(
|
|
649
|
+
lambda x: jnp.max(x, axis=0), all_stats
|
|
650
|
+
)
|
|
616
651
|
|
|
617
652
|
# Check if stats changed enough to warrant action.
|
|
618
653
|
stacked_table_specs = embedding.get_stacked_table_specs(
|
|
619
654
|
self._config.feature_specs
|
|
620
655
|
)
|
|
621
656
|
changed = any(
|
|
622
|
-
np.max(
|
|
657
|
+
np.max(aggregated_stats.max_ids_per_partition[stack_name])
|
|
623
658
|
> spec.max_ids_per_partition
|
|
624
|
-
or np.max(
|
|
659
|
+
or np.max(
|
|
660
|
+
aggregated_stats.max_unique_ids_per_partition[stack_name]
|
|
661
|
+
)
|
|
625
662
|
> spec.max_unique_ids_per_partition
|
|
626
663
|
or (
|
|
627
|
-
np.max(
|
|
664
|
+
np.max(
|
|
665
|
+
aggregated_stats.required_buffer_size_per_sc[stack_name]
|
|
666
|
+
)
|
|
628
667
|
* num_sc_per_device
|
|
629
668
|
)
|
|
630
669
|
> (spec.suggested_coo_buffer_size_per_device or 0)
|
|
@@ -634,7 +673,9 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
|
|
|
634
673
|
# Update configuration and repeat preprocessing if stats changed.
|
|
635
674
|
if changed:
|
|
636
675
|
embedding.update_preprocessing_parameters(
|
|
637
|
-
self._config.feature_specs,
|
|
676
|
+
self._config.feature_specs,
|
|
677
|
+
aggregated_stats,
|
|
678
|
+
num_sc_per_device,
|
|
638
679
|
)
|
|
639
680
|
|
|
640
681
|
# Re-execute preprocessing with consistent input statistics.
|
|
@@ -720,7 +761,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
|
|
|
720
761
|
config = self._config
|
|
721
762
|
num_table_shards = config.mesh.devices.size * config.num_sc_per_device
|
|
722
763
|
table_specs = embedding.get_table_specs(config.feature_specs)
|
|
723
|
-
sharded_tables =
|
|
764
|
+
sharded_tables = jte_table_stacking.stack_and_shard_tables(
|
|
724
765
|
table_specs,
|
|
725
766
|
tables,
|
|
726
767
|
num_table_shards,
|
|
@@ -763,7 +804,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
|
|
|
763
804
|
|
|
764
805
|
return typing.cast(
|
|
765
806
|
dict[str, ArrayLike],
|
|
766
|
-
|
|
807
|
+
jte_table_stacking.unshard_and_unstack_tables(
|
|
767
808
|
table_specs, table_variables, num_table_shards
|
|
768
809
|
),
|
|
769
810
|
)
|
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
"""Utility functions for manipulating JAX embedding tables and inputs."""
|
|
2
|
+
|
|
3
|
+
import collections
|
|
4
|
+
from typing import Any, Mapping, NamedTuple, Sequence, TypeAlias, TypeVar
|
|
5
|
+
|
|
6
|
+
import jax
|
|
7
|
+
import numpy as np
|
|
8
|
+
from jax_tpu_embedding.sparsecore.lib.nn import embedding
|
|
9
|
+
from jax_tpu_embedding.sparsecore.lib.nn import table_stacking
|
|
10
|
+
from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import FeatureSpec
|
|
11
|
+
|
|
12
|
+
from keras_rs.src.types import Nested
|
|
13
|
+
|
|
14
|
+
# Type variable for the leaf type of a nested feature structure
# (used in the `create_feature_samples` signature).
T = TypeVar("T")

# Any to support tf.Ragged without needing an explicit TF dependency.
ArrayLike: TypeAlias = jax.Array | np.ndarray | Any  # type: ignore
# Array shape expressed as a tuple of dimension sizes.
Shape: TypeAlias = tuple[int, ...]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class FeatureSamples(NamedTuple):
    """Embedding-lookup inputs (IDs plus weights) for a single feature.

    Instances are produced by `create_feature_samples` and consumed by
    `stack_and_shard_samples`.
    """

    # IDs to look up. `create_feature_samples` stores them as int32 numpy
    # arrays: dense 2D, or a ragged object array of per-row arrays.
    tokens: ArrayLike
    # Weight for each ID. `create_feature_samples` stores them as float32
    # numpy arrays, expected to be laid out like `tokens`.
    weights: ArrayLike
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ShardedCooMatrix(NamedTuple):
    """Sharded COO (coordinate-format) input for one stacked table.

    Built by `stack_and_shard_samples` from the outputs of
    `embedding.preprocess_sparse_dense_matmul_input`.
    """

    # Offset at which each shard's entries begin: 0 for the first shard,
    # then the previous shard's end passed through
    # `table_stacking._next_largest_multiple(..., 8)`.
    shard_starts: ArrayLike
    # Offset at which each shard's entries end (`lhs_row_pointers`).
    shard_ends: ArrayLike
    # Embedding (column) IDs of the entries (`lhs_embedding_ids`).
    col_ids: ArrayLike
    # Sample (row) IDs of the entries (`lhs_sample_ids`).
    row_ids: ArrayLike
    # Per-entry gains (`lhs_gains`).
    values: ArrayLike
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def convert_to_numpy(
|
|
35
|
+
ragged_or_dense: np.ndarray[Any, Any] | Sequence[Sequence[Any]] | Any,
|
|
36
|
+
dtype: Any,
|
|
37
|
+
) -> np.ndarray[Any, Any]:
|
|
38
|
+
"""Converts a ragged or dense list of inputs to a ragged/dense numpy array.
|
|
39
|
+
|
|
40
|
+
The output is adjusted to be 2D.
|
|
41
|
+
|
|
42
|
+
Args:
|
|
43
|
+
ragged_or_dense: Input that is either already a numpy array, or nested
|
|
44
|
+
sequence.
|
|
45
|
+
dtype: Numpy dtype of output array.
|
|
46
|
+
|
|
47
|
+
Returns:
|
|
48
|
+
Corresponding numpy array.
|
|
49
|
+
"""
|
|
50
|
+
if hasattr(ragged_or_dense, "numpy"):
|
|
51
|
+
# Support tf.RaggedTensor and other TF input dtypes.
|
|
52
|
+
if callable(getattr(ragged_or_dense, "numpy")):
|
|
53
|
+
ragged_or_dense = ragged_or_dense.numpy()
|
|
54
|
+
|
|
55
|
+
if isinstance(ragged_or_dense, jax.Array):
|
|
56
|
+
ragged_or_dense = np.asarray(ragged_or_dense)
|
|
57
|
+
|
|
58
|
+
if isinstance(ragged_or_dense, np.ndarray):
|
|
59
|
+
# Convert 1D to 2D.
|
|
60
|
+
if ragged_or_dense.dtype != np.ndarray and ragged_or_dense.ndim == 1:
|
|
61
|
+
return ragged_or_dense.reshape(-1, 1).astype(dtype)
|
|
62
|
+
|
|
63
|
+
# If dense, return converted dense type.
|
|
64
|
+
if ragged_or_dense.dtype != np.ndarray:
|
|
65
|
+
return ragged_or_dense.astype(dtype)
|
|
66
|
+
|
|
67
|
+
# Ragged numpy array.
|
|
68
|
+
return ragged_or_dense
|
|
69
|
+
|
|
70
|
+
# Handle 1D sequence input.
|
|
71
|
+
if not isinstance(ragged_or_dense[0], collections.abc.Sequence):
|
|
72
|
+
return np.asarray(ragged_or_dense, dtype=dtype).reshape(-1, 1)
|
|
73
|
+
|
|
74
|
+
# Assemble elements into an nd-array.
|
|
75
|
+
counts = [len(vals) for vals in ragged_or_dense]
|
|
76
|
+
if all([count == counts[0] for count in counts]):
|
|
77
|
+
# Dense input.
|
|
78
|
+
return np.asarray(ragged_or_dense, dtype=dtype)
|
|
79
|
+
else:
|
|
80
|
+
# Ragged input, convert to ragged numpy arrays.
|
|
81
|
+
return np.array(
|
|
82
|
+
[np.array(row, dtype=dtype) for row in ragged_or_dense],
|
|
83
|
+
dtype=np.ndarray,
|
|
84
|
+
)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def ones_like(
    ragged_or_dense: np.ndarray[Any, Any], dtype: Any = None
) -> np.ndarray[Any, Any]:
    """Creates an array of ones shaped like the input.

    This differs from traditional numpy in that a ragged input will lead to
    a resulting ragged array of ones, whereas np.ones_like(...) will instead
    only consider the outer array and return a 1D dense array of ones.

    Args:
        ragged_or_dense: The ragged (object-dtype array of rows) or dense
            input whose shape and data-type define these same attributes of
            the returned array.
        dtype: The data-type of the returned array (of each row, for ragged
            input). Defaults to `None`, which keeps the input's dtype; for
            ragged input each row keeps its own dtype.

    Returns:
        An array of ones with the same shape as the input, and specified data
        type.
    """
    if ragged_or_dense.dtype != object:
        # Dense: `np.ones_like` already keeps the input dtype when `dtype`
        # is None, so no truthiness-based fallback is needed.
        return np.ones_like(ragged_or_dense, dtype=dtype)

    # Ragged: fill a pre-allocated object array element by element. Building
    # it with `np.array([...], dtype=object)` would collapse rows that happen
    # to share a length into a 2D object array of scalars.
    ones = np.empty(ragged_or_dense.shape, dtype=object)
    for index, row in np.ndenumerate(ragged_or_dense):
        ones[index] = np.ones_like(row, dtype=dtype)
    return ones
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def create_feature_samples(
    feature_structure: Nested[T],
    feature_ids: Nested[ArrayLike | Sequence[int] | Sequence[Sequence[int]]],
    feature_weights: None
    | (Nested[ArrayLike | Sequence[float] | Sequence[Sequence[float]]]),
) -> Nested[FeatureSamples]:
    """Pairs feature IDs with weights into a nested collection of samples.

    The leaves of `feature_structure` (typically `FeatureSpec`s) decide where
    each leaf of `feature_ids` / `feature_weights` starts; everything below a
    leaf position is handed to `convert_to_numpy` as a single input.

    Args:
        feature_structure: The nested structure of the inputs (typically
            `FeatureSpec`s).
        feature_ids: The feature IDs to use for the samples; converted to
            int32 numpy arrays.
        feature_weights: The feature weights to use for the samples,
            converted to float32 numpy arrays. When `None`, weights are ones
            laid out like the converted IDs (ragged IDs give ragged ones).

    Returns:
        A nested collection of `FeatureSamples`, one per leaf of
        `feature_structure`, for use in embedding lookups.
    """
    ids_as_numpy = jax.tree.map(
        lambda _, ids: convert_to_numpy(ids, np.int32),
        feature_structure,
        feature_ids,
    )

    if feature_weights is not None:
        weights_as_numpy = jax.tree.map(
            lambda _, wgts: convert_to_numpy(wgts, np.float32),
            feature_structure,
            feature_weights,
        )
    else:
        # Default weights: ones matching each converted ID array.
        weights_as_numpy = jax.tree.map(
            lambda _, ids: ones_like(ids, np.float32),
            feature_structure,
            ids_as_numpy,
        )

    samples: Nested[FeatureSamples] = jax.tree.map(
        lambda _, ids, wgts: FeatureSamples(ids, wgts),
        feature_structure,
        ids_as_numpy,
        weights_as_numpy,
    )
    return samples
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def stack_and_shard_samples(
    feature_specs: Nested[FeatureSpec],
    feature_samples: Nested[FeatureSamples],
    local_device_count: int,
    global_device_count: int,
    num_sc_per_device: int,
    static_buffer_size: int | Mapping[str, int] | None = None,
) -> tuple[dict[str, ShardedCooMatrix], embedding.SparseDenseMatmulInputStats]:
    """Prepares input samples for use in embedding lookups.

    Flattens the specs and samples in matching order, runs
    `embedding.preprocess_sparse_dense_matmul_input` (MOD sharding, no leading
    dimension, ID dropping allowed), and repackages the per-table outputs as
    `ShardedCooMatrix` tuples.

    Args:
        feature_specs: Nested collection of feature specifications.
        feature_samples: Nested collection of feature samples, with the same
            structure as `feature_specs` down to its leaves.
        local_device_count: Number of local JAX devices.
        global_device_count: Number of global JAX devices.
        num_sc_per_device: Number of sparsecores per device.
        static_buffer_size: Accepted for API compatibility; currently
            ignored.

    Returns:
        A dict mapping each table name reported by the preprocessing step to
        its `ShardedCooMatrix`, and the statistics returned by preprocessing
        (useful for updating FeatureSpecs based on the provided input data).
    """
    del static_buffer_size  # Currently ignored.

    flat_feature_specs, spec_treedef = jax.tree.flatten(feature_specs)
    # Flatten the samples only down to the leaves of `feature_specs`, so each
    # `FeatureSamples` tuple stays whole and lines up with its spec.
    flat_samples = spec_treedef.flatten_up_to(feature_samples)
    feature_tokens = [sample.tokens for sample in flat_samples]
    feature_weights = [sample.weights for sample in flat_samples]

    preprocessed_inputs, stats = embedding.preprocess_sparse_dense_matmul_input(
        feature_tokens,
        feature_weights,
        flat_feature_specs,
        local_device_count=local_device_count,
        global_device_count=global_device_count,
        num_sc_per_device=num_sc_per_device,
        sharding_strategy="MOD",
        has_leading_dimension=False,
        allow_id_dropping=True,
    )

    def _to_sharded_coo(table_name: str) -> ShardedCooMatrix:
        """Packages the preprocessed arrays of one table."""
        ends = preprocessed_inputs.lhs_row_pointers[table_name]
        # First shard starts at 0; every later shard starts at the previous
        # shard's end passed through `_next_largest_multiple(..., 8)`
        # (presumably rounding up to a multiple of 8 — see jax_tpu_embedding).
        aligned_prev_ends = table_stacking._next_largest_multiple(ends[:-1], 8)
        starts = np.concatenate([np.asarray([0]), aligned_prev_ends])
        return ShardedCooMatrix(
            shard_starts=starts,
            shard_ends=ends,
            col_ids=preprocessed_inputs.lhs_embedding_ids[table_name],
            row_ids=preprocessed_inputs.lhs_sample_ids[table_name],
            values=preprocessed_inputs.lhs_gains[table_name],
        )

    coo_by_table: dict[str, ShardedCooMatrix] = {
        name: _to_sharded_coo(name)
        for name in preprocessed_inputs.lhs_row_pointers.keys()
    }
    return coo_by_table, stats
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
import keras
|
|
4
|
+
from keras import ops
|
|
5
|
+
|
|
6
|
+
from keras_rs.src import types
|
|
7
|
+
from keras_rs.src.api_export import keras_rs_export
|
|
8
|
+
from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
|
|
9
|
+
from keras_rs.src.metrics.utils import standardize_call_inputs_ranks
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@keras_rs_export("keras_rs.losses.ListMLELoss")
class ListMLELoss(keras.losses.Loss):
    """Implements ListMLE (Maximum Likelihood Estimation) loss for ranking.

    ListMLE loss is a listwise ranking loss that maximizes the likelihood of
    the ground truth ranking. It works by:
    1. Sorting items by their relevance scores (labels)
    2. Computing the probability of observing this ranking given the
       predicted scores
    3. Maximizing this likelihood (minimizing negative log-likelihood)

    The loss is computed as the negative log-likelihood of the ground truth
    ranking given the predicted scores:

    ```
    loss = -sum(log(exp(s_i) / sum(exp(s_j) for j >= i)))
    ```

    where s_i is the predicted score (divided by `temperature`) for item i in
    the label-sorted order. Items with a negative label, or masked out, are
    excluded from the sum and from every normalizer; a list with no valid
    items contributes a loss of 0.

    Args:
        temperature: Temperature parameter for scaling logits. Higher values
            make the probability distribution more uniform. Defaults to 1.0.
        reduction: Type of reduction to apply to the loss. In almost all cases
            this should be `"sum_over_batch_size"`. Supported options are
            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. Defaults to
            `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`.

    Examples:
    ```python
    # Basic usage
    loss_fn = ListMLELoss()

    # With temperature scaling
    loss_fn = ListMLELoss(temperature=0.5)

    # Example with synthetic data
    y_true = [[3, 2, 1, 0]]  # Relevance scores
    y_pred = [[0.8, 0.6, 0.4, 0.2]]  # Predicted scores
    loss = loss_fn(y_true, y_pred)
    ```
    """

    def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None:
        super().__init__(**kwargs)

        if temperature <= 0.0:
            raise ValueError(
                f"`temperature` should be a positive float. Received: "
                f"`temperature` = {temperature}."
            )

        self.temperature = temperature
        # Lower bound applied to the suffix sums before taking their log. It
        # only matters where a suffix sum is (near) zero, i.e. at padded
        # positions, so it is kept far below any sum a valid item can have
        # short of float32 underflow.
        self._epsilon = 1e-30

    def compute_unreduced_loss(
        self,
        labels: types.Tensor,
        logits: types.Tensor,
        mask: types.Tensor | None = None,
    ) -> tuple[types.Tensor, types.Tensor]:
        """Compute the unreduced ListMLE loss.

        Args:
            labels: Ground truth relevance scores of
                shape [batch_size, list_size]. Negative labels mark invalid
                items.
            logits: Predicted scores of shape [batch_size, list_size].
            mask: Optional mask of shape [batch_size, list_size]; falsy
                entries mark invalid items.

        Returns:
            Tuple of (losses, weights) where losses has shape [batch_size, 1]
            and weights has the same shape (all ones).
        """
        # An item is valid when its label is non-negative and, if a mask is
        # given, the mask is set.
        valid_mask = ops.greater_equal(labels, ops.cast(0.0, labels.dtype))
        if mask is not None:
            valid_mask = ops.logical_and(
                valid_mask, ops.cast(mask, dtype="bool")
            )

        num_valid_items = ops.sum(
            ops.cast(valid_mask, dtype=labels.dtype), axis=1, keepdims=True
        )
        batch_has_valid_items = ops.greater(num_valid_items, 0.0)

        # Push invalid items to the end of the label-sorted order.
        labels_for_sorting = ops.where(
            valid_mask, labels, ops.full_like(labels, -1e9)
        )
        logits_masked = ops.where(
            valid_mask, logits, ops.full_like(logits, -1e9)
        )

        sorted_logits, sorted_valid_mask = sort_by_scores(
            tensors_to_sort=[logits_masked, valid_mask],
            scores=labels_for_sorting,
            mask=None,
            shuffle_ties=False,
            seed=None,
        )
        sorted_logits = ops.divide(
            sorted_logits, ops.cast(self.temperature, dtype=sorted_logits.dtype)
        )

        # Subtract the per-list max over valid items for a stable `exp`.
        valid_logits_for_max = ops.where(
            sorted_valid_mask, sorted_logits, ops.full_like(sorted_logits, -1e9)
        )
        raw_max = ops.max(valid_logits_for_max, axis=1, keepdims=True)
        raw_max = ops.where(
            batch_has_valid_items, raw_max, ops.zeros_like(raw_max)
        )
        sorted_logits = ops.subtract(sorted_logits, raw_max)

        # Set invalid positions to very negative BEFORE exp so they add
        # nothing to any normalizer.
        sorted_logits = ops.where(
            sorted_valid_mask, sorted_logits, ops.full_like(sorted_logits, -1e9)
        )
        exp_logits = ops.exp(sorted_logits)

        # cumsum_from_right[:, i] = sum over j >= i of exp_logits[:, j].
        cumsum_from_right = ops.flip(
            ops.cumsum(ops.flip(exp_logits, axis=1), axis=1), axis=1
        )

        # Padded positions have a suffix sum of exactly 0, so `log` needs a
        # guard. Clamp to a tiny floor rather than adding an epsilon: an
        # additive 1e-10 dominates every suffix sum below ~1e-10 (a logit
        # ~23 below the list max after temperature scaling) and turns a
        # near-zero log-probability into a large penalty, biasing confident
        # correct rankings.
        log_floor = ops.cast(self._epsilon, cumsum_from_right.dtype)
        log_normalizers = ops.log(ops.maximum(cumsum_from_right, log_floor))
        log_probs = ops.subtract(sorted_logits, log_normalizers)

        # Invalid positions contribute nothing.
        log_probs = ops.where(
            sorted_valid_mask, log_probs, ops.zeros_like(log_probs)
        )

        negative_log_likelihood = ops.negative(
            ops.sum(log_probs, axis=1, keepdims=True)
        )
        # Lists without any valid item get a zero loss.
        negative_log_likelihood = ops.where(
            batch_has_valid_items,
            negative_log_likelihood,
            ops.zeros_like(negative_log_likelihood),
        )

        weights = ops.ones_like(negative_log_likelihood)

        return negative_log_likelihood, weights

    def call(
        self,
        y_true: types.Tensor,
        y_pred: types.Tensor,
    ) -> types.Tensor:
        """Compute the ListMLE loss.

        Args:
            y_true: tensor or dict. Ground truth values. If tensor, of shape
                `(list_size)` for unbatched inputs or `(batch_size, list_size)`
                for batched inputs. If an item has a label of -1, it is ignored
                in loss computation. If it is a dictionary, it should have two
                keys: `"labels"` and `"mask"`. `"mask"` can be used to ignore
                elements in loss computation.
            y_pred: tensor. The predicted values, of shape `(list_size)` for
                unbatched inputs or `(batch_size, list_size)` for batched
                inputs. Should be of the same shape as `y_true`.

        Returns:
            The loss tensor of shape [batch_size].

        Raises:
            ValueError: If `y_true` is a dict without a `"labels"` key.
        """
        mask = None
        if isinstance(y_true, dict):
            if "labels" not in y_true:
                raise ValueError(
                    '`"labels"` should be present in `y_true`. Received: '
                    f"`y_true` = {y_true}"
                )

            mask = y_true.get("mask", None)
            y_true = y_true["labels"]

        y_true = ops.convert_to_tensor(y_true)
        y_pred = ops.convert_to_tensor(y_pred)
        if mask is not None:
            mask = ops.convert_to_tensor(mask)

        y_true, y_pred, mask, _ = standardize_call_inputs_ranks(
            y_true, y_pred, mask
        )

        losses, weights = self.compute_unreduced_loss(
            labels=y_true, logits=y_pred, mask=mask
        )
        losses = ops.multiply(losses, weights)
        losses = ops.squeeze(losses, axis=-1)
        return losses

    def get_config(self) -> dict[str, Any]:
        """Returns the base loss config plus `temperature`."""
        config: dict[str, Any] = super().get_config()
        config.update({"temperature": self.temperature})
        return config
|
|
@@ -85,6 +85,25 @@ def sort_by_scores(
|
|
|
85
85
|
else:
|
|
86
86
|
k = ops.minimum(k, max_possible_k)
|
|
87
87
|
|
|
88
|
+
# --- Work around for PyTorch instability ---
|
|
89
|
+
# Torch's `topk` is not stable with `sorted=True`, unlike JAX and TF.
|
|
90
|
+
# See:
|
|
91
|
+
# - https://github.com/pytorch/pytorch/issues/27542
|
|
92
|
+
# - https://github.com/pytorch/pytorch/issues/88227
|
|
93
|
+
#
|
|
94
|
+
# This small "stable offset" ensures deterministic tie-breaking for
|
|
95
|
+
# equal scores. We can remove this workaround once PyTorch adds a
|
|
96
|
+
# `stable=True` flag for topk.
|
|
97
|
+
|
|
98
|
+
if keras.backend.backend() == "torch" and not shuffle_ties:
|
|
99
|
+
list_size = ops.shape(scores)[1]
|
|
100
|
+
indices = ops.arange(list_size)
|
|
101
|
+
indices = ops.expand_dims(indices, axis=0)
|
|
102
|
+
indices = ops.broadcast_to(indices, ops.shape(scores))
|
|
103
|
+
stable_offset = ops.cast(indices, scores.dtype) * 1e-6
|
|
104
|
+
scores = ops.subtract(scores, stable_offset)
|
|
105
|
+
# --- End FIX ---
|
|
106
|
+
|
|
88
107
|
# Shuffle ties randomly, and push masked values to the beginning.
|
|
89
108
|
shuffled_indices = None
|
|
90
109
|
if shuffle_ties or mask is not None:
|
|
@@ -33,6 +33,7 @@ keras_rs/src/layers/retrieval/remove_accidental_hits.py
|
|
|
33
33
|
keras_rs/src/layers/retrieval/retrieval.py
|
|
34
34
|
keras_rs/src/layers/retrieval/sampling_probability_correction.py
|
|
35
35
|
keras_rs/src/losses/__init__.py
|
|
36
|
+
keras_rs/src/losses/list_mle_loss.py
|
|
36
37
|
keras_rs/src/losses/pairwise_hinge_loss.py
|
|
37
38
|
keras_rs/src/losses/pairwise_logistic_loss.py
|
|
38
39
|
keras_rs/src/losses/pairwise_loss.py
|