keras-rs-nightly 0.0.1.dev2025021903__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. keras_rs/__init__.py +9 -28
  2. keras_rs/layers/__init__.py +37 -0
  3. keras_rs/losses/__init__.py +19 -0
  4. keras_rs/metrics/__init__.py +16 -0
  5. keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
  6. keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
  7. keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
  8. keras_rs/src/layers/embedding/embed_reduce.py +309 -0
  9. keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  10. keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
  11. keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
  12. keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
  13. keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
  14. keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
  15. keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  16. keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
  17. keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
  18. keras_rs/src/layers/feature_interaction/__init__.py +0 -0
  19. keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
  20. keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
  21. keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
  22. keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
  23. keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
  24. keras_rs/src/layers/retrieval/retrieval.py +127 -0
  25. keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
  26. keras_rs/src/losses/__init__.py +0 -0
  27. keras_rs/src/losses/list_mle_loss.py +212 -0
  28. keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
  29. keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
  30. keras_rs/src/losses/pairwise_loss.py +165 -0
  31. keras_rs/src/losses/pairwise_loss_utils.py +39 -0
  32. keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
  33. keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
  34. keras_rs/src/metrics/__init__.py +0 -0
  35. keras_rs/src/metrics/dcg.py +161 -0
  36. keras_rs/src/metrics/mean_average_precision.py +130 -0
  37. keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
  38. keras_rs/src/metrics/ndcg.py +197 -0
  39. keras_rs/src/metrics/precision_at_k.py +117 -0
  40. keras_rs/src/metrics/ranking_metric.py +260 -0
  41. keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
  42. keras_rs/src/metrics/recall_at_k.py +108 -0
  43. keras_rs/src/metrics/utils.py +70 -0
  44. keras_rs/src/types.py +43 -14
  45. keras_rs/src/utils/doc_string_utils.py +53 -0
  46. keras_rs/src/utils/keras_utils.py +52 -3
  47. keras_rs/src/utils/tpu_test_utils.py +120 -0
  48. keras_rs/src/version.py +1 -1
  49. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
  50. keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
  51. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
  52. keras_rs/api/__init__.py +0 -9
  53. keras_rs/api/layers/__init__.py +0 -11
  54. keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
  55. /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
  56. {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
keras_rs/__init__.py CHANGED
@@ -1,30 +1,11 @@
1
- import os
1
+ """DO NOT EDIT.
2
2
 
3
- # Import everything from /api/ into keras_rs.
4
- from keras_rs.api import * # noqa: F403
3
+ This file was autogenerated. Do not edit it by hand,
4
+ since your modifications would be overwritten.
5
+ """
5
6
 
6
- # Import * ignores names starting with "_", and `__version__` comes from
7
- # `version` anyway.
8
- from keras_rs.src.version import __version__
9
-
10
- # Add everything in /api/ to the module search path.
11
- __path__.append(os.path.join(os.path.dirname(__file__), "api")) # noqa: F405
12
-
13
- # Don't pollute namespace.
14
- del os
15
-
16
-
17
- # Never autocomplete `.src` or `.api` on an imported keras_rs object.
18
- def __dir__() -> list[str]:
19
- keys = dict.fromkeys((globals().keys()))
20
- keys.pop("src")
21
- keys.pop("api")
22
- return list(keys)
23
-
24
-
25
- # Don't import `.src` or `.api` during `from keras_rs import *`.
26
- __all__ = [
27
- name
28
- for name in globals().keys()
29
- if not (name.startswith("_") or name in ("src", "api"))
30
- ]
7
+ from keras_rs import layers as layers
8
+ from keras_rs import losses as losses
9
+ from keras_rs import metrics as metrics
10
+ from keras_rs.src.version import __version__ as __version__
11
+ from keras_rs.src.version import version as version
keras_rs/layers/__init__.py ADDED
@@ -0,0 +1,37 @@
1
+ """DO NOT EDIT.
2
+
3
+ This file was autogenerated. Do not edit it by hand,
4
+ since your modifications would be overwritten.
5
+ """
6
+
7
+ from keras_rs.src.layers.embedding.distributed_embedding import (
8
+ DistributedEmbedding as DistributedEmbedding,
9
+ )
10
+ from keras_rs.src.layers.embedding.distributed_embedding_config import (
11
+ FeatureConfig as FeatureConfig,
12
+ )
13
+ from keras_rs.src.layers.embedding.distributed_embedding_config import (
14
+ TableConfig as TableConfig,
15
+ )
16
+ from keras_rs.src.layers.embedding.embed_reduce import (
17
+ EmbedReduce as EmbedReduce,
18
+ )
19
+ from keras_rs.src.layers.feature_interaction.dot_interaction import (
20
+ DotInteraction as DotInteraction,
21
+ )
22
+ from keras_rs.src.layers.feature_interaction.feature_cross import (
23
+ FeatureCross as FeatureCross,
24
+ )
25
+ from keras_rs.src.layers.retrieval.brute_force_retrieval import (
26
+ BruteForceRetrieval as BruteForceRetrieval,
27
+ )
28
+ from keras_rs.src.layers.retrieval.hard_negative_mining import (
29
+ HardNegativeMining as HardNegativeMining,
30
+ )
31
+ from keras_rs.src.layers.retrieval.remove_accidental_hits import (
32
+ RemoveAccidentalHits as RemoveAccidentalHits,
33
+ )
34
+ from keras_rs.src.layers.retrieval.retrieval import Retrieval as Retrieval
35
+ from keras_rs.src.layers.retrieval.sampling_probability_correction import (
36
+ SamplingProbabilityCorrection as SamplingProbabilityCorrection,
37
+ )
keras_rs/losses/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ """DO NOT EDIT.
2
+
3
+ This file was autogenerated. Do not edit it by hand,
4
+ since your modifications would be overwritten.
5
+ """
6
+
7
+ from keras_rs.src.losses.list_mle_loss import ListMLELoss as ListMLELoss
8
+ from keras_rs.src.losses.pairwise_hinge_loss import (
9
+ PairwiseHingeLoss as PairwiseHingeLoss,
10
+ )
11
+ from keras_rs.src.losses.pairwise_logistic_loss import (
12
+ PairwiseLogisticLoss as PairwiseLogisticLoss,
13
+ )
14
+ from keras_rs.src.losses.pairwise_mean_squared_error import (
15
+ PairwiseMeanSquaredError as PairwiseMeanSquaredError,
16
+ )
17
+ from keras_rs.src.losses.pairwise_soft_zero_one_loss import (
18
+ PairwiseSoftZeroOneLoss as PairwiseSoftZeroOneLoss,
19
+ )
keras_rs/metrics/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ """DO NOT EDIT.
2
+
3
+ This file was autogenerated. Do not edit it by hand,
4
+ since your modifications would be overwritten.
5
+ """
6
+
7
+ from keras_rs.src.metrics.dcg import DCG as DCG
8
+ from keras_rs.src.metrics.mean_average_precision import (
9
+ MeanAveragePrecision as MeanAveragePrecision,
10
+ )
11
+ from keras_rs.src.metrics.mean_reciprocal_rank import (
12
+ MeanReciprocalRank as MeanReciprocalRank,
13
+ )
14
+ from keras_rs.src.metrics.ndcg import NDCG as NDCG
15
+ from keras_rs.src.metrics.precision_at_k import PrecisionAtK as PrecisionAtK
16
+ from keras_rs.src.metrics.recall_at_k import RecallAtK as RecallAtK