keras-rs-nightly 0.3.1.dev202510170329__tar.gz → 0.3.1.dev202510190335__tar.gz

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release: this version of keras-rs-nightly might be problematic.

Files changed (62)
  1. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/PKG-INFO +1 -1
  2. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/layers/embedding/jax/distributed_embedding.py +4 -4
  3. keras_rs_nightly-0.3.1.dev202510190335/keras_rs/src/layers/embedding/jax/embedding_utils.py +244 -0
  4. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/version.py +1 -1
  5. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs_nightly.egg-info/PKG-INFO +1 -1
  6. keras_rs_nightly-0.3.1.dev202510170329/keras_rs/src/layers/embedding/jax/embedding_utils.py +0 -535
  7. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/README.md +0 -0
  8. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/api/__init__.py +0 -0
  9. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/api/layers/__init__.py +0 -0
  10. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/api/losses/__init__.py +0 -0
  11. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/api/metrics/__init__.py +0 -0
  12. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/__init__.py +0 -0
  13. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/api_export.py +0 -0
  14. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/layers/__init__.py +0 -0
  15. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/layers/embedding/__init__.py +0 -0
  16. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/layers/embedding/base_distributed_embedding.py +0 -0
  17. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/layers/embedding/distributed_embedding.py +0 -0
  18. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/layers/embedding/distributed_embedding_config.py +0 -0
  19. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/layers/embedding/embed_reduce.py +0 -0
  20. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  21. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/layers/embedding/jax/checkpoint_utils.py +0 -0
  22. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/layers/embedding/jax/config_conversion.py +0 -0
  23. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/layers/embedding/jax/embedding_lookup.py +0 -0
  24. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  25. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/layers/embedding/tensorflow/config_conversion.py +0 -0
  26. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +0 -0
  27. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/layers/feature_interaction/__init__.py +0 -0
  28. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/layers/feature_interaction/dot_interaction.py +0 -0
  29. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/layers/feature_interaction/feature_cross.py +0 -0
  30. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/layers/retrieval/__init__.py +0 -0
  31. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/layers/retrieval/brute_force_retrieval.py +0 -0
  32. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/layers/retrieval/hard_negative_mining.py +0 -0
  33. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/layers/retrieval/remove_accidental_hits.py +0 -0
  34. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/layers/retrieval/retrieval.py +0 -0
  35. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/layers/retrieval/sampling_probability_correction.py +0 -0
  36. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/losses/__init__.py +0 -0
  37. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/losses/pairwise_hinge_loss.py +0 -0
  38. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/losses/pairwise_logistic_loss.py +0 -0
  39. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/losses/pairwise_loss.py +0 -0
  40. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/losses/pairwise_loss_utils.py +0 -0
  41. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/losses/pairwise_mean_squared_error.py +0 -0
  42. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/losses/pairwise_soft_zero_one_loss.py +0 -0
  43. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/metrics/__init__.py +0 -0
  44. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/metrics/dcg.py +0 -0
  45. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/metrics/mean_average_precision.py +0 -0
  46. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/metrics/mean_reciprocal_rank.py +0 -0
  47. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/metrics/ndcg.py +0 -0
  48. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/metrics/precision_at_k.py +0 -0
  49. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/metrics/ranking_metric.py +0 -0
  50. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/metrics/ranking_metrics_utils.py +0 -0
  51. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/metrics/recall_at_k.py +0 -0
  52. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/metrics/utils.py +0 -0
  53. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/types.py +0 -0
  54. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/utils/__init__.py +0 -0
  55. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/utils/doc_string_utils.py +0 -0
  56. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs/src/utils/keras_utils.py +0 -0
  57. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs_nightly.egg-info/SOURCES.txt +0 -0
  58. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs_nightly.egg-info/dependency_links.txt +0 -0
  59. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs_nightly.egg-info/requires.txt +0 -0
  60. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/keras_rs_nightly.egg-info/top_level.txt +0 -0
  61. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/pyproject.toml +0 -0
  62. {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510190335}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: keras-rs-nightly
- Version: 0.3.1.dev202510170329
+ Version: 0.3.1.dev202510190335
  Summary: Multi-backend recommender systems with Keras 3.
  Author-email: Keras team <keras-users@googlegroups.com>
  License: Apache License 2.0
@@ -442,7 +442,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):

  # Collect all stacked tables.
  table_specs = embedding.get_table_specs(feature_specs)
- table_stacks = embedding_utils.get_table_stacks(table_specs)
+ table_stacks = jte_table_stacking.get_table_stacks(table_specs)

  # Create variables for all stacked tables and slot variables.
  with sparsecore_distribution.scope():
@@ -516,7 +516,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):

  # Each stacked-table gets a ShardedCooMatrix.
  table_specs = embedding.get_table_specs(self._config.feature_specs)
- table_stacks = embedding_utils.get_table_stacks(table_specs)
+ table_stacks = jte_table_stacking.get_table_stacks(table_specs)
  stacked_table_specs = {
  stack_name: stack[0].stacked_table_spec
  for stack_name, stack in table_stacks.items()
@@ -720,7 +720,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
  config = self._config
  num_table_shards = config.mesh.devices.size * config.num_sc_per_device
  table_specs = embedding.get_table_specs(config.feature_specs)
- sharded_tables = embedding_utils.stack_and_shard_tables(
+ sharded_tables = jte_table_stacking.stack_and_shard_tables(
  table_specs,
  tables,
  num_table_shards,
@@ -763,7 +763,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):

  return typing.cast(
  dict[str, ArrayLike],
- embedding_utils.unshard_and_unstack_tables(
+ jte_table_stacking.unshard_and_unstack_tables(
  table_specs, table_variables, num_table_shards
  ),
  )
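
The four hunks above rename call sites from `embedding_utils.*` to `jte_table_stacking.*`; the import hunk itself is not shown in this diff. A minimal sketch of the implied import change, assuming the new alias points at the upstream `table_stacking` module that the rewritten embedding_utils.py imports (the exact lines are an assumption, not part of the diff):

    # Before: the stacking helpers lived in keras_rs's own embedding_utils.
    from keras_rs.src.layers.embedding.jax import embedding_utils

    # After (presumed): the helpers come from jax_tpu_embedding instead,
    # imported under the alias used by the updated call sites.
    from jax_tpu_embedding.sparsecore.lib.nn import table_stacking as jte_table_stacking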
@@ -0,0 +1,244 @@
+ """Utility functions for manipulating JAX embedding tables and inputs."""
+
+ import collections
+ from typing import Any, Mapping, NamedTuple, Sequence, TypeAlias, TypeVar
+
+ import jax
+ import numpy as np
+ from jax_tpu_embedding.sparsecore.lib.nn import embedding
+ from jax_tpu_embedding.sparsecore.lib.nn import table_stacking
+ from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import FeatureSpec
+
+ from keras_rs.src.types import Nested
+
+ T = TypeVar("T")
+
+ # Any to support tf.Ragged without needing an explicit TF dependency.
+ ArrayLike: TypeAlias = jax.Array | np.ndarray | Any # type: ignore
+ Shape: TypeAlias = tuple[int, ...]
+
+
+ class FeatureSamples(NamedTuple):
+ tokens: ArrayLike
+ weights: ArrayLike
+
+
+ class ShardedCooMatrix(NamedTuple):
+ shard_starts: ArrayLike
+ shard_ends: ArrayLike
+ col_ids: ArrayLike
+ row_ids: ArrayLike
+ values: ArrayLike
+
+
+ def convert_to_numpy(
+ ragged_or_dense: np.ndarray[Any, Any] | Sequence[Sequence[Any]] | Any,
+ dtype: Any,
+ ) -> np.ndarray[Any, Any]:
+ """Converts a ragged or dense list of inputs to a ragged/dense numpy array.
+
+ The output is adjusted to be 2D.
+
+ Args:
+ ragged_or_dense: Input that is either already a numpy array, or nested
+ sequence.
+ dtype: Numpy dtype of output array.
+
+ Returns:
+ Corresponding numpy array.
+ """
+ if hasattr(ragged_or_dense, "numpy"):
+ # Support tf.RaggedTensor and other TF input dtypes.
+ if callable(getattr(ragged_or_dense, "numpy")):
+ ragged_or_dense = ragged_or_dense.numpy()
+
+ if isinstance(ragged_or_dense, jax.Array):
+ ragged_or_dense = np.asarray(ragged_or_dense)
+
+ if isinstance(ragged_or_dense, np.ndarray):
+ # Convert 1D to 2D.
+ if ragged_or_dense.dtype != np.ndarray and ragged_or_dense.ndim == 1:
+ return ragged_or_dense.reshape(-1, 1).astype(dtype)
+
+ # If dense, return converted dense type.
+ if ragged_or_dense.dtype != np.ndarray:
+ return ragged_or_dense.astype(dtype)
+
+ # Ragged numpy array.
+ return ragged_or_dense
+
+ # Handle 1D sequence input.
+ if not isinstance(ragged_or_dense[0], collections.abc.Sequence):
+ return np.asarray(ragged_or_dense, dtype=dtype).reshape(-1, 1)
+
+ # Assemble elements into an nd-array.
+ counts = [len(vals) for vals in ragged_or_dense]
+ if all([count == counts[0] for count in counts]):
+ # Dense input.
+ return np.asarray(ragged_or_dense, dtype=dtype)
+ else:
+ # Ragged input, convert to ragged numpy arrays.
+ return np.array(
+ [np.array(row, dtype=dtype) for row in ragged_or_dense],
+ dtype=np.ndarray,
+ )
+
+
+ def ones_like(
+ ragged_or_dense: np.ndarray[Any, Any], dtype: Any = None
+ ) -> np.ndarray[Any, Any]:
+ """Creates an array of ones the same as the input.
+
+ This differs from traditional numpy in that a ragged input will lead to
+ a resulting ragged array of ones, whereas np.ones_like(...) will instead
+ only consider the outer array and return a 1D dense array of ones.
+
+ Args:
+ ragged_or_dense: The ragged or dense input whose shape and data-type
+ define these same attributes of the returned array.
+ dtype: The data-type of the returned array.
+
+ Returns:
+ An array of ones with the same shape as the input, and specified data
+ type.
+ """
+ dtype = dtype or ragged_or_dense.dtype
+ if ragged_or_dense.dtype == np.ndarray:
+ # Ragged.
+ return np.array(
+ [np.ones_like(row, dtype=dtype) for row in ragged_or_dense],
+ dtype=np.ndarray,
+ )
+ else:
+ # Dense.
+ return np.ones_like(ragged_or_dense, dtype=dtype)
+
+
+ def create_feature_samples(
+ feature_structure: Nested[T],
+ feature_ids: Nested[ArrayLike | Sequence[int] | Sequence[Sequence[int]]],
+ feature_weights: None
+ | (Nested[ArrayLike | Sequence[float] | Sequence[Sequence[float]]]),
+ ) -> Nested[FeatureSamples]:
+ """Constructs a collection of sample tuples from provided IDs and weights.
+
+ Args:
+ feature_structure: The nested structure of the inputs (typically
+ `FeatureSpec`s).
+ feature_ids: The feature IDs to use for the samples.
+ feature_weights: The feature weights to use for the samples. Defaults
+ to ones if not provided.
+
+ Returns:
+ A nested collection of `FeatureSamples` corresponding to the input IDs
+ and weights, for use in embedding lookups.
+ """
+ # Create numpy arrays from inputs.
+ feature_ids = jax.tree.map(
+ lambda _, ids: convert_to_numpy(ids, np.int32),
+ feature_structure,
+ feature_ids,
+ )
+
+ if feature_weights is None:
+ # Make ragged or dense ones_like.
+ feature_weights = jax.tree.map(
+ lambda _, ids: ones_like(ids, np.float32),
+ feature_structure,
+ feature_ids,
+ )
+ else:
+ feature_weights = jax.tree.map(
+ lambda _, wgts: convert_to_numpy(wgts, np.float32),
+ feature_structure,
+ feature_weights,
+ )
+
+ # Assemble.
+ def _create_feature_samples(
+ sample_ids: np.ndarray[Any, Any],
+ sample_weights: np.ndarray[Any, Any],
+ ) -> FeatureSamples:
+ return FeatureSamples(sample_ids, sample_weights)
+
+ output: Nested[FeatureSamples] = jax.tree.map(
+ lambda _, sample_ids, sample_weights: _create_feature_samples(
+ sample_ids, sample_weights
+ ),
+ feature_structure,
+ feature_ids,
+ feature_weights,
+ )
+ return output
+
+
+ def stack_and_shard_samples(
+ feature_specs: Nested[FeatureSpec],
+ feature_samples: Nested[FeatureSamples],
+ local_device_count: int,
+ global_device_count: int,
+ num_sc_per_device: int,
+ static_buffer_size: int | Mapping[str, int] | None = None,
+ ) -> tuple[dict[str, ShardedCooMatrix], embedding.SparseDenseMatmulInputStats]:
+ """Prepares input samples for use in embedding lookups.
+
+ Args:
+ feature_specs: Nested collection of feature specifications.
+ feature_samples: Nested collection of feature samples.
+ local_device_count: Number of local JAX devices.
+ global_device_count: Number of global JAX devices.
+ num_sc_per_device: Number of sparsecores per device.
+ static_buffer_size: The static buffer size to use for the samples.
+ Defaults to None, in which case an upper-bound for the buffer size
+ will be automatically determined.
+
+ Returns:
+ The preprocessed inputs, and statistics useful for updating FeatureSpecs
+ based on the provided input data.
+ """
+ del static_buffer_size # Currently ignored.
+ flat_feature_specs, _ = jax.tree.flatten(feature_specs)
+
+ feature_tokens = []
+ feature_weights = []
+
+ def collect_tokens_and_weights(
+ feature_spec: FeatureSpec, samples: FeatureSamples
+ ) -> None:
+ del feature_spec
+ feature_tokens.append(samples.tokens)
+ feature_weights.append(samples.weights)
+
+ jax.tree.map(collect_tokens_and_weights, feature_specs, feature_samples)
+
+ preprocessed_inputs, stats = embedding.preprocess_sparse_dense_matmul_input(
+ feature_tokens,
+ feature_weights,
+ flat_feature_specs,
+ local_device_count=local_device_count,
+ global_device_count=global_device_count,
+ num_sc_per_device=num_sc_per_device,
+ sharding_strategy="MOD",
+ has_leading_dimension=False,
+ allow_id_dropping=True,
+ )
+
+ out: dict[str, ShardedCooMatrix] = {}
+ tables_names = preprocessed_inputs.lhs_row_pointers.keys()
+ for table_name in tables_names:
+ shard_ends = preprocessed_inputs.lhs_row_pointers[table_name]
+ shard_starts = np.concatenate(
+ [
+ np.asarray([0]),
+ table_stacking._next_largest_multiple(shard_ends[:-1], 8),
+ ]
+ )
+ out[table_name] = ShardedCooMatrix(
+ shard_starts=shard_starts,
+ shard_ends=shard_ends,
+ col_ids=preprocessed_inputs.lhs_embedding_ids[table_name],
+ row_ids=preprocessed_inputs.lhs_sample_ids[table_name],
+ values=preprocessed_inputs.lhs_gains[table_name],
+ )
+
+ return out, stats
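
For orientation, a minimal usage sketch of the ragged-input helpers retained in the rewritten embedding_utils.py. The feature name and the `0` structure sentinel are illustrative only; in the layer itself, `feature_structure` is the nested collection of `FeatureSpec`s:

    import numpy as np

    structure = {"movie_id": 0}  # any non-None leaf marks one feature
    ids = {"movie_id": [[3, 7, 7], [1]]}  # ragged batch of 2 samples
    samples = create_feature_samples(structure, ids, None)

    # samples["movie_id"].tokens is a ragged np.ndarray (rows are np.ndarray),
    # and samples["movie_id"].weights is a matching ragged array of float32 ones.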
@@ -1,7 +1,7 @@
  from keras_rs.src.api_export import keras_rs_export

  # Unique source of truth for the version number.
- __version__ = "0.3.1.dev202510170329"
+ __version__ = "0.3.1.dev202510190335"


  @keras_rs_export("keras_rs.version")
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: keras-rs-nightly
- Version: 0.3.1.dev202510170329
+ Version: 0.3.1.dev202510190335
  Summary: Multi-backend recommender systems with Keras 3.
  Author-email: Keras team <keras-users@googlegroups.com>
  License: Apache License 2.0
@@ -1,535 +0,0 @@
- """Utility functions for manipulating JAX embedding tables and inputs."""
-
- import collections
- import typing
- from typing import Any, Mapping, NamedTuple, Sequence, TypeAlias, TypeVar
-
- import jax
- import numpy as np
- from jax import numpy as jnp
- from jax_tpu_embedding.sparsecore.lib.nn import embedding
- from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import FeatureSpec
- from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import StackedTableSpec
- from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import TableSpec
-
- from keras_rs.src.types import Nested
-
- T = TypeVar("T")
-
- # Any to support tf.Ragged without needing an explicit TF dependency.
- ArrayLike: TypeAlias = jax.Array | np.ndarray | Any # type: ignore
- Shape: TypeAlias = tuple[int, ...]
-
-
- class FeatureSamples(NamedTuple):
- tokens: ArrayLike
- weights: ArrayLike
-
-
- class ShardedCooMatrix(NamedTuple):
- shard_starts: ArrayLike
- shard_ends: ArrayLike
- col_ids: ArrayLike
- row_ids: ArrayLike
- values: ArrayLike
-
-
- def _round_up_to_multiple(value: int, multiple: int) -> int:
- return ((value + multiple - 1) // multiple) * multiple
-
-
- def _default_stacked_table_spec(
- table_spec: TableSpec, num_shards: int, batch_size: int
- ) -> StackedTableSpec:
- return StackedTableSpec(
- stack_name=table_spec.name,
- stack_vocab_size=_round_up_to_multiple(
- table_spec.vocabulary_size, 8 * num_shards
- ),
- stack_embedding_dim=_round_up_to_multiple(table_spec.embedding_dim, 8),
- optimizer=table_spec.optimizer,
- combiner=table_spec.combiner,
- total_sample_count=batch_size,
- max_ids_per_partition=table_spec.max_ids_per_partition,
- max_unique_ids_per_partition=table_spec.max_unique_ids_per_partition,
- )
-
-
- def _get_stacked_table_spec(
- table_spec: TableSpec, num_shards: int, batch_size: int = 0
- ) -> StackedTableSpec:
- return table_spec.stacked_table_spec or _default_stacked_table_spec(
- table_spec, num_shards, batch_size
- )
-
-
- def pad_table(
- table_spec: TableSpec,
- table_values: jax.Array,
- num_shards: int,
- pad_value: jnp.float32 = jnp.nan,
- ) -> jax.Array:
- """Adds appropriate padding to a table to prepare for stacking.
-
- Args:
- table_spec: Table specification describing the table to pad.
- table_values: Table values array to pad.
- num_shards: Number of shards in the table (typically
- `global_device_count * num_sc_per_device`).
- pad_value: Value to use for padding.
-
- Returns:
- Padded table values.
- """
- vocabulary_size = table_spec.vocabulary_size
- embedding_dim = table_spec.embedding_dim
- padded_vocabulary_size = _round_up_to_multiple(
- vocabulary_size, 8 * num_shards
- )
- stack_embedding_dim = _get_stacked_table_spec(
- table_spec, num_shards
- ).stack_embedding_dim
- return jnp.pad(
- table_values,
- (
- (0, padded_vocabulary_size - vocabulary_size),
- (0, stack_embedding_dim - embedding_dim),
- ),
- constant_values=pad_value,
- )
-
-
- def _stack_and_shard_table(
- stacked_table: jax.Array,
- table_spec: TableSpec,
- table: jax.Array,
- num_shards: int,
- pad_value: jnp.float32,
- ) -> jax.Array:
- """Stacks and shards a single table for use in sparsecore lookups."""
- padded_values = pad_table(table_spec, table, num_shards, pad_value)
- sharded_padded_vocabulary_size = padded_values.shape[0] // num_shards
- stack_embedding_dim = stacked_table.shape[-1]
-
- # Mod-shard vocabulary across devices.
- sharded_values = jnp.swapaxes(
- padded_values.reshape(-1, num_shards, stack_embedding_dim),
- 0,
- 1,
- )
-
- # Rotate shards.
- setting_in_stack = table_spec.setting_in_stack
- rotated_values = jnp.roll(
- sharded_values, setting_in_stack.shard_rotation, axis=0
- )
-
- # Insert table into the stack.
- table_row = setting_in_stack.row_offset_in_shard
- stacked_table = stacked_table.at[
- :, table_row : (table_row + sharded_padded_vocabulary_size), :
- ].set(rotated_values)
-
- return stacked_table
-
-
- def stack_and_shard_tables(
- table_specs: Nested[TableSpec],
- tables: Nested[ArrayLike],
- num_shards: int,
- pad_value: jnp.float32 = jnp.nan,
- ) -> dict[str, Nested[jax.Array]]:
- """Stacks and shards tables for use in sparsecore lookups.
-
- Args:
- table_specs: Nested collection of unstacked table specifications.
- tables: Table values corresponding to the table_specs.
- num_shards: Number of shards in the table (typically
- `global_device_count * num_sc_per_device`).
- pad_value: Value to use for padding.
-
- Returns:
- A mapping of stacked table names to stacked table values.
- """
-
- # Gather stacked table information.
- stacked_table_map: dict[
- str,
- tuple[StackedTableSpec, list[TableSpec]],
- ] = {}
-
- def collect_stacked_tables(table_spec: TableSpec) -> None:
- stacked_table_spec = _get_stacked_table_spec(table_spec, num_shards)
- stacked_table_name = stacked_table_spec.stack_name
- if stacked_table_name not in stacked_table_map:
- stacked_table_map[stacked_table_name] = (stacked_table_spec, [])
- stacked_table_map[stacked_table_name][1].append(table_spec)
-
- _ = jax.tree.map(collect_stacked_tables, table_specs)
-
- table_map: dict[str, Nested[jax.Array]] = {}
-
- def collect_tables(table_spec: TableSpec, table: Nested[jax.Array]) -> None:
- table_map[table_spec.name] = table
-
- _ = jax.tree.map(collect_tables, table_specs, tables)
-
- stacked_tables: dict[str, Nested[jax.Array]] = {}
- for (
- stacked_table_spec,
- table_specs,
- ) in stacked_table_map.values():
- stack_vocab_size = stacked_table_spec.stack_vocab_size
- sharded_vocab_size = stack_vocab_size // num_shards
- stack_embedding_dim = stacked_table_spec.stack_embedding_dim
-
- # Allocate initial buffer. The stacked table will be divided among
- # shards by splitting the vocabulary dimension:
- # [ v, e ] -> [s, v/s, e]
- stacked_table_tree = jax.tree.map(
- lambda _: jnp.zeros(
- # pylint: disable-next=cell-var-from-loop, used only in loop body.
- shape=(num_shards, sharded_vocab_size, stack_embedding_dim),
- dtype=jnp.float32,
- ),
- table_map[table_specs[0].name],
- )
-
- for table_spec in table_specs:
- table_tree = table_map[table_spec.name]
- stacked_table_tree = jax.tree.map(
- lambda stacked_table, table: _stack_and_shard_table(
- # pylint: disable-next=cell-var-from-loop, used only in loop body.
- stacked_table,
- # pylint: disable-next=cell-var-from-loop, used only in loop body.
- table_spec,
- table,
- num_shards,
- pad_value,
- ),
- stacked_table_tree,
- table_tree,
- )
-
- stacked_tables[stacked_table_spec.stack_name] = stacked_table_tree
-
- return stacked_tables
-
-
- def _unshard_and_unstack_table(
- table_spec: TableSpec,
- stacked_table_tree: Nested[jax.Array],
- num_shards: int,
- ) -> Nested[jax.Array]:
- """Unshards and unstacks a single table."""
- vocabulary_size = table_spec.vocabulary_size
- embedding_dim = table_spec.embedding_dim
-
- def _unshard_and_unstack_single_table(
- table_spec: TableSpec, stacked_table: jax.Array
- ) -> jax.Array:
- stack_embedding_dim = stacked_table.shape[-1]
-
- # Maybe re-shape in case it was flattened.
- stacked_table = stacked_table.reshape(
- num_shards, -1, stack_embedding_dim
- )
- sharded_vocabulary_size = (
- _round_up_to_multiple(vocabulary_size, 8 * num_shards) // num_shards
- )
-
- # Extract padded values from the stacked table.
- setting_in_stack = table_spec.setting_in_stack
- row = setting_in_stack.row_offset_in_shard
- padded_values = stacked_table[
- :, row : (row + sharded_vocabulary_size), :
- ]
-
- # Un-rotate shards.
- padded_values = jnp.roll(
- padded_values, -setting_in_stack.shard_rotation, axis=0
- )
-
- # Un-mod-shard.
- padded_values = jnp.swapaxes(padded_values, 0, 1).reshape(
- -1, stack_embedding_dim
- )
-
- # Un-pad.
- return padded_values[:vocabulary_size, :embedding_dim]
-
- output: Nested[jax.Array] = jax.tree.map(
- lambda stacked_table: _unshard_and_unstack_single_table(
- table_spec, stacked_table
- ),
- stacked_table_tree,
- )
- return output
-
-
- def unshard_and_unstack_tables(
- table_specs: Nested[TableSpec],
- stacked_tables: Mapping[str, Nested[jax.Array]],
- num_shards: int,
- ) -> Nested[jax.Array]:
- """Unshards and unstacks a collection of tables.
-
- Args:
- table_specs: Nested collection of unstacked table specifications.
- stacked_tables: Mapping of stacked table names to stacked table values.
- num_shards: Number of shards in the table (typically
- `global_device_count * num_sc_per_device`).
-
- Returns:
- A mapping of table names to unstacked table values.
- """
- output: Nested[jax.Array] = jax.tree.map(
- lambda table_spec: _unshard_and_unstack_table(
- table_spec,
- stacked_tables[
- _get_stacked_table_spec(table_spec, num_shards=1).stack_name
- ],
- num_shards,
- ),
- table_specs,
- )
- return output
-
-
- def get_table_stacks(
- table_specs: Nested[TableSpec],
- ) -> dict[str, list[TableSpec]]:
- """Extracts lists of tables that are stacked together.
-
- Args:
- table_specs: Nested collection of table specifications.
-
- Returns:
- A mapping of stacked table names to lists of table specifications for
- each stack.
- """
- stacked_table_specs: dict[str, list[TableSpec]] = collections.defaultdict(
- list
- )
- flat_table_specs, _ = jax.tree.flatten(table_specs)
- for table_spec in flat_table_specs:
- table_spec = typing.cast(TableSpec, table_spec)
- stacked_table_spec = table_spec.stacked_table_spec
- if stacked_table_spec is not None:
- stacked_table_specs[stacked_table_spec.stack_name].append(
- table_spec
- )
- else:
- stacked_table_specs[table_spec.name].append(table_spec)
-
- return stacked_table_specs
-
-
- def convert_to_numpy(
- ragged_or_dense: np.ndarray[Any, Any] | Sequence[Sequence[Any]] | Any,
- dtype: Any,
- ) -> np.ndarray[Any, Any]:
- """Converts a ragged or dense list of inputs to a ragged/dense numpy array.
-
- The output is adjusted to be 2D.
-
- Args:
- ragged_or_dense: Input that is either already a numpy array, or nested
- sequence.
- dtype: Numpy dtype of output array.
-
- Returns:
- Corresponding numpy array.
- """
- if hasattr(ragged_or_dense, "numpy"):
- # Support tf.RaggedTensor and other TF input dtypes.
- if callable(getattr(ragged_or_dense, "numpy")):
- ragged_or_dense = ragged_or_dense.numpy()
-
- if isinstance(ragged_or_dense, jax.Array):
- ragged_or_dense = np.asarray(ragged_or_dense)
-
- if isinstance(ragged_or_dense, np.ndarray):
- # Convert 1D to 2D.
- if ragged_or_dense.dtype != np.ndarray and ragged_or_dense.ndim == 1:
- return ragged_or_dense.reshape(-1, 1).astype(dtype)
-
- # If dense, return converted dense type.
- if ragged_or_dense.dtype != np.ndarray:
- return ragged_or_dense.astype(dtype)
-
- # Ragged numpy array.
- return ragged_or_dense
-
- # Handle 1D sequence input.
- if not isinstance(ragged_or_dense[0], collections.abc.Sequence):
- return np.asarray(ragged_or_dense, dtype=dtype).reshape(-1, 1)
-
- # Assemble elements into an nd-array.
- counts = [len(vals) for vals in ragged_or_dense]
- if all([count == counts[0] for count in counts]):
- # Dense input.
- return np.asarray(ragged_or_dense, dtype=dtype)
- else:
- # Ragged input, convert to ragged numpy arrays.
- return np.array(
- [np.array(row, dtype=dtype) for row in ragged_or_dense],
- dtype=np.ndarray,
- )
-
-
- def ones_like(
- ragged_or_dense: np.ndarray[Any, Any], dtype: Any = None
- ) -> np.ndarray[Any, Any]:
- """Creates an array of ones the same as the input.
-
- This differs from traditional numpy in that a ragged input will lead to
- a resulting ragged array of ones, whereas np.ones_like(...) will instead
- only consider the outer array and return a 1D dense array of ones.
-
- Args:
- ragged_or_dense: The ragged or dense input whose shape and data-type
- define these same attributes of the returned array.
- dtype: The data-type of the returned array.
-
- Returns:
- An array of ones with the same shape as the input, and specified data
- type.
- """
- dtype = dtype or ragged_or_dense.dtype
- if ragged_or_dense.dtype == np.ndarray:
- # Ragged.
- return np.array(
- [np.ones_like(row, dtype=dtype) for row in ragged_or_dense],
- dtype=np.ndarray,
- )
- else:
- # Dense.
- return np.ones_like(ragged_or_dense, dtype=dtype)
-
-
- def create_feature_samples(
- feature_structure: Nested[T],
- feature_ids: Nested[ArrayLike | Sequence[int] | Sequence[Sequence[int]]],
- feature_weights: None
- | (Nested[ArrayLike | Sequence[float] | Sequence[Sequence[float]]]),
- ) -> Nested[FeatureSamples]:
- """Constructs a collection of sample tuples from provided IDs and weights.
-
- Args:
- feature_structure: The nested structure of the inputs (typically
- `FeatureSpec`s).
- feature_ids: The feature IDs to use for the samples.
- feature_weights: The feature weights to use for the samples. Defaults
- to ones if not provided.
-
- Returns:
- A nested collection of `FeatureSamples` corresponding to the input IDs
- and weights, for use in embedding lookups.
- """
- # Create numpy arrays from inputs.
- feature_ids = jax.tree.map(
- lambda _, ids: convert_to_numpy(ids, np.int32),
- feature_structure,
- feature_ids,
- )
-
- if feature_weights is None:
- # Make ragged or dense ones_like.
- feature_weights = jax.tree.map(
- lambda _, ids: ones_like(ids, np.float32),
- feature_structure,
- feature_ids,
- )
- else:
- feature_weights = jax.tree.map(
- lambda _, wgts: convert_to_numpy(wgts, np.float32),
- feature_structure,
- feature_weights,
- )
-
- # Assemble.
- def _create_feature_samples(
- sample_ids: np.ndarray[Any, Any],
- sample_weights: np.ndarray[Any, Any],
- ) -> FeatureSamples:
- return FeatureSamples(sample_ids, sample_weights)
-
- output: Nested[FeatureSamples] = jax.tree.map(
- lambda _, sample_ids, sample_weights: _create_feature_samples(
- sample_ids, sample_weights
- ),
- feature_structure,
- feature_ids,
- feature_weights,
- )
- return output
-
-
- def stack_and_shard_samples(
- feature_specs: Nested[FeatureSpec],
- feature_samples: Nested[FeatureSamples],
- local_device_count: int,
- global_device_count: int,
- num_sc_per_device: int,
- static_buffer_size: int | Mapping[str, int] | None = None,
- ) -> tuple[dict[str, ShardedCooMatrix], embedding.SparseDenseMatmulInputStats]:
- """Prepares input samples for use in embedding lookups.
-
- Args:
- feature_specs: Nested collection of feature specifications.
- feature_samples: Nested collection of feature samples.
- local_device_count: Number of local JAX devices.
- global_device_count: Number of global JAX devices.
- num_sc_per_device: Number of sparsecores per device.
- static_buffer_size: The static buffer size to use for the samples.
- Defaults to None, in which case an upper-bound for the buffer size
- will be automatically determined.
-
- Returns:
- The preprocessed inputs, and statistics useful for updating FeatureSpecs
- based on the provided input data.
- """
- del static_buffer_size # Currently ignored.
- flat_feature_specs, _ = jax.tree.flatten(feature_specs)
-
- feature_tokens = []
- feature_weights = []
-
- def collect_tokens_and_weights(
- feature_spec: FeatureSpec, samples: FeatureSamples
- ) -> None:
- del feature_spec
- feature_tokens.append(samples.tokens)
- feature_weights.append(samples.weights)
-
- jax.tree.map(collect_tokens_and_weights, feature_specs, feature_samples)
-
- preprocessed_inputs, stats = embedding.preprocess_sparse_dense_matmul_input(
- feature_tokens,
- feature_weights,
- flat_feature_specs,
- local_device_count=local_device_count,
- global_device_count=global_device_count,
- num_sc_per_device=num_sc_per_device,
- sharding_strategy="MOD",
- has_leading_dimension=False,
- allow_id_dropping=True,
- )
-
- out: dict[str, ShardedCooMatrix] = {}
- tables_names = preprocessed_inputs.lhs_row_pointers.keys()
- for table_name in tables_names:
- shard_ends = preprocessed_inputs.lhs_row_pointers[table_name]
- shard_starts = np.concatenate(
- [np.asarray([0]), _round_up_to_multiple(shard_ends[:-1], 8)]
- )
- out[table_name] = ShardedCooMatrix(
- shard_starts=shard_starts,
- shard_ends=shard_ends,
- col_ids=preprocessed_inputs.lhs_embedding_ids[table_name],
- row_ids=preprocessed_inputs.lhs_sample_ids[table_name],
- values=preprocessed_inputs.lhs_gains[table_name],
- )
-
- return out, stats
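
The helpers deleted above mod-sharded each padded table over the vocabulary axis before stacking (the renamed call sites earlier in this diff suggest equivalents are now reached through the `jte_table_stacking` alias). A standalone numpy sketch of just that mod-sharding reshape, with illustrative shapes:

    import numpy as np

    num_shards, vocab, dim = 4, 16, 8  # vocab already padded to a multiple of num_shards
    table = np.arange(vocab * dim, dtype=np.float32).reshape(vocab, dim)

    # Mod-shard: row i of the table lands on shard i % num_shards.
    # [v, e] -> [v/s, s, e] -> [s, v/s, e]
    sharded = np.swapaxes(table.reshape(-1, num_shards, dim), 0, 1)

    assert sharded.shape == (num_shards, vocab // num_shards, dim)
    assert (sharded[1][0] == table[1]).all()  # shard 1 begins with row 1
    assert (sharded[1][1] == table[5]).all()  # then row 1 + num_shards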