keras-rs-nightly 0.0.1.dev2025021903__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_rs/__init__.py +9 -28
- keras_rs/layers/__init__.py +37 -0
- keras_rs/losses/__init__.py +19 -0
- keras_rs/metrics/__init__.py +16 -0
- keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
- keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
- keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
- keras_rs/src/layers/embedding/embed_reduce.py +309 -0
- keras_rs/src/layers/embedding/jax/__init__.py +0 -0
- keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
- keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
- keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
- keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
- keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
- keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
- keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
- keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
- keras_rs/src/layers/feature_interaction/__init__.py +0 -0
- keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
- keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
- keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
- keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
- keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
- keras_rs/src/layers/retrieval/retrieval.py +127 -0
- keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
- keras_rs/src/losses/__init__.py +0 -0
- keras_rs/src/losses/list_mle_loss.py +212 -0
- keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
- keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
- keras_rs/src/losses/pairwise_loss.py +165 -0
- keras_rs/src/losses/pairwise_loss_utils.py +39 -0
- keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
- keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
- keras_rs/src/metrics/__init__.py +0 -0
- keras_rs/src/metrics/dcg.py +161 -0
- keras_rs/src/metrics/mean_average_precision.py +130 -0
- keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
- keras_rs/src/metrics/ndcg.py +197 -0
- keras_rs/src/metrics/precision_at_k.py +117 -0
- keras_rs/src/metrics/ranking_metric.py +260 -0
- keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
- keras_rs/src/metrics/recall_at_k.py +108 -0
- keras_rs/src/metrics/utils.py +70 -0
- keras_rs/src/types.py +43 -14
- keras_rs/src/utils/doc_string_utils.py +53 -0
- keras_rs/src/utils/keras_utils.py +52 -3
- keras_rs/src/utils/tpu_test_utils.py +120 -0
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
- keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
- keras_rs/api/__init__.py +0 -9
- keras_rs/api/layers/__init__.py +0 -11
- keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
- /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ keras_rs/src/layers/embedding/jax/embedding_lookup.py
@@ -0,0 +1,276 @@
+"""Defines a differentiable `embedding_lookup` function.
+
+Implementation details for use in JAX models.
+"""
+
+import functools
+from typing import Any, Mapping, TypeAlias
+
+import jax
+import numpy as np
+from jax.experimental import layout as jax_layout
+from jax_tpu_embedding.sparsecore.lib.nn import embedding
+from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
+from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
+
+from keras_rs.src.layers.embedding.jax import embedding_utils
+from keras_rs.src.types import Nested
+
+if jax.__version_info__ >= (0, 8, 0):
+    from jax import shard_map
+else:
+    from jax.experimental.shard_map import shard_map as exp_shard_map
+
+    def shard_map(  # type: ignore[misc]
+        f: Any = None,
+        /,
+        *,
+        out_specs: Any,
+        in_specs: Any,
+        mesh: Any = None,
+        check_vma: bool = True,
+    ) -> Any:
+        return exp_shard_map(
+            f,
+            mesh=mesh,
+            in_specs=in_specs,
+            out_specs=out_specs,
+            check_rep=check_vma,
+        )  # type: ignore[no-untyped-call]
+
+
+ShardedCooMatrix = embedding_utils.ShardedCooMatrix
+ArrayLike: TypeAlias = jax.Array | np.ndarray[Any, Any]
+JaxLayout: TypeAlias = jax.sharding.NamedSharding | jax_layout.Format
+
+
+class EmbeddingLookupConfiguration:
+    """Feature, mesh and sharding configrmation for lookups."""
+
+    mesh: jax.sharding.Mesh
+    feature_specs: embedding.Nested[embedding_spec.FeatureSpec]
+    table_sharding_strategy: str
+    num_sc_per_device: int
+    samples_partition: jax.sharding.PartitionSpec
+    table_partition: jax.sharding.PartitionSpec
+    table_layout: JaxLayout
+
+    def __init__(
+        self,
+        feature_specs: embedding.Nested[embedding_spec.FeatureSpec],
+        mesh: jax.sharding.Mesh | None = None,
+        table_sharding_strategy: str = "MOD",
+        num_sc_per_device: int | None = None,
+        sharding_axis: str = "sparsecore_sharding",
+        samples_partition: jax.sharding.PartitionSpec | None = None,
+        samples_layout: JaxLayout | None = None,
+        table_partition: jax.sharding.PartitionSpec | None = None,
+        table_layout: JaxLayout | None = None,
+    ):
+        self.mesh = mesh or jax.sharding.Mesh(jax.devices(), sharding_axis)
+        self.feature_specs = feature_specs
+        self.table_sharding_strategy = table_sharding_strategy
+        self.num_sc_per_device = (
+            num_sc_per_device
+            if num_sc_per_device is not None
+            else jte_utils.num_sparsecores_per_device()
+        )
+        self.samples_partition = (
+            samples_partition
+            or jax.sharding.PartitionSpec(
+                sharding_axis  # type: ignore[no-untyped-call]
+            )
+        )
+        self.samples_layout = samples_layout or jax.sharding.NamedSharding(
+            self.mesh, self.samples_partition
+        )
+        self.table_partition = table_partition or jax.sharding.PartitionSpec(
+            sharding_axis,
+            None,  # type: ignore[no-untyped-call]
+        )
+        self.table_layout = table_layout or jax.sharding.NamedSharding(
+            self.mesh, self.table_partition
+        )
+
+
+# Embedding lookup function with custom gradient.
+@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
+def embedding_lookup(
+    config: EmbeddingLookupConfiguration,
+    lookups: Mapping[str, ShardedCooMatrix],
+    tables: Nested[jax.Array],
+    step: jax.Array | None = None,
+) -> Nested[jax.Array]:
+    """Embedding lookup function with custom gradient.
+
+    Args:
+        config: Embedding lookup configuration.
+        lookups: Embedding lookup stacked/sharded inputs.
+        tables: Embedding lookup stacked/sharded tables.
+        step: Current training step number.
+    """
+    del step  # Only used in backward pass.
+
+    # Decompose COO matrices.
+    row_pointers_raw = {}
+    embedding_ids_raw = {}
+    sample_ids_raw = {}
+    gains_raw = {}
+    for table_name, coo in lookups.items():
+        row_pointers_raw[table_name] = coo.shard_ends
+        embedding_ids_raw[table_name] = coo.col_ids
+        sample_ids_raw[table_name] = coo.row_ids
+        gains_raw[table_name] = coo.values
+
+    sparse_dense_matmul_input = embedding.SparseDenseMatmulInput(
+        row_pointers_raw,
+        embedding_ids_raw,
+        sample_ids_raw,
+        gains_raw,
+    )
+
+    pd = config.samples_partition
+    pt = config.table_partition
+    sharded_matmul = jax.jit(
+        shard_map(
+            functools.partial(
+                embedding.tpu_sparse_dense_matmul,
+                global_device_count=config.mesh.shape[pd[0]],
+                feature_specs=config.feature_specs,
+                sharding_strategy=config.table_sharding_strategy,
+            ),
+            mesh=config.mesh,
+            in_specs=(pd, pt),
+            out_specs=pd,
+            check_vma=False,
+        ),
+    )
+
+    activations: Nested[jax.Array] = sharded_matmul(
+        sparse_dense_matmul_input,
+        tables,
+    )
+
+    return activations
+
+
+def embedding_lookup_fwd(
+    config: EmbeddingLookupConfiguration,
+    lookups: Mapping[str, ShardedCooMatrix],
+    table: Nested[jax.Array],
+    step: jax.Array | None = None,
+) -> tuple[
+    Nested[jax.Array],
+    tuple[Nested[ShardedCooMatrix], Nested[jax.Array], jax.Array | None],
+]:
+    """Forward pass for embedding lookup."""
+    return embedding_lookup(config, lookups, table, step), (
+        lookups,
+        table,
+        step,
+    )
+
+
+def embedding_lookup_bwd(
+    config: EmbeddingLookupConfiguration,
+    res: tuple[
+        Mapping[str, ShardedCooMatrix],  # Lookups.
+        Mapping[str, Nested[jax.Array]],  # Tables.
+        jax.Array | None,  # Step.
+    ],
+    gradients: Nested[jax.Array],
+) -> tuple[None, Nested[jax.Array], jax.Array | None]:
+    """Backward pass for embedding lookup.
+
+    Args:
+        config: Embedding lookup configuration.
+        res: Tuple of embedding lookup (inputs, tables, step).
+        gradients: Embedding lookup gradients.
+
+    Returns:
+        A tuple of gradients (None, table_grads, step + 1).
+    """
+    lookups, tables, step = res
+
+    # Decompose COO matrices.
+    row_pointers_raw = {}
+    embedding_ids_raw = {}
+    sample_ids_raw = {}
+    gains_raw = {}
+    for table_name, coo in lookups.items():
+        row_pointers_raw[table_name] = coo.shard_ends
+        embedding_ids_raw[table_name] = coo.col_ids
+        sample_ids_raw[table_name] = coo.row_ids
+        gains_raw[table_name] = coo.values
+
+    sparse_dense_matmul_input = embedding.SparseDenseMatmulInput(
+        row_pointers_raw,
+        embedding_ids_raw,
+        sample_ids_raw,
+        gains_raw,
+    )
+
+    pt = config.table_partition
+    pd = config.samples_partition
+    # Replicate step count.
+    preplicate = jax.sharding.PartitionSpec()  # type: ignore[no-untyped-call]
+
+    def grad_func(
+        gradients: Nested[jax.Array],
+        sparse_input: embedding.SparseDenseMatmulInput,
+        tables: Mapping[str, embedding.EmbeddingVariables],
+        step: jax.Array | None,
+    ) -> Mapping[str, embedding.EmbeddingVariables]:
+        output: Mapping[str, embedding.EmbeddingVariables] = (
+            embedding.tpu_sparse_dense_matmul_grad(
+                gradients,
+                sparse_input,
+                tables,
+                feature_specs=config.feature_specs,
+                sharding_strategy=config.table_sharding_strategy,
+                step=step,
+            )
+        )
+        return output
+
+    # activation_layout = jax.sharding.NamedSharding(config.mesh, pd)
+    # step_layout = jax.sharding.NamedSharding(config.mesh, preplicate)
+    sharded_matmul_grad = jax.jit(
+        shard_map(
+            grad_func,
+            mesh=config.mesh,
+            in_specs=(pd, pd, pt, preplicate),
+            out_specs=pt,
+            check_vma=False,
+        ),
+        # in_shardings=(
+        #     activation_layout,
+        #     config.samples_layout,
+        #     config.table_layout,
+        #     step_layout,
+        # ),
+        # out_shardings=(config.table_layout),
+    )
+
+    table_grads = sharded_matmul_grad(
+        gradients,
+        sparse_dense_matmul_input,
+        tables,
+        step,
+    )
+
+    # tpu_sparse_dense_matmul_grad returns a general Mapping (usually a dict).
+    # It may not be the same type as the embedding table (e.g. FrozenDict).
+    # Here we use flatten / unflatten to ensure the types are the same.
+    table_grads = jax.tree.unflatten(
+        jax.tree.structure(tables), jax.tree.leaves(table_grads)
+    )
+
+    return (
+        None,
+        table_grads,
+        step + 1 if step is not None else None,  # Incremented step count.
+    )
+
+
+embedding_lookup.defvjp(embedding_lookup_fwd, embedding_lookup_bwd)
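
Note on usage (not part of the diff itself): `embedding_lookup` above is wrapped in `jax.custom_vjp` with the configuration as a non-differentiable argument, so the backward pass routes gradients only to the embedding tables (via `embedding.tpu_sparse_dense_matmul_grad`) and increments the optional step counter. A rough, untested sketch of how this could be driven from a JAX training step follows; it assumes SparseCore-capable TPU devices and that `feature_specs`, the per-table `lookups` (see `embedding_utils.stack_and_shard_samples` in the next file) and the stacked `tables` have already been built elsewhere.

import jax
import jax.numpy as jnp

from keras_rs.src.layers.embedding.jax import embedding_lookup as lookup_lib

# Assumed to exist already: `feature_specs`, `lookups` and `tables`.
config = lookup_lib.EmbeddingLookupConfiguration(feature_specs)

def loss_fn(tables, lookups, step):
    # Forward pass goes through the custom-VJP lookup.
    activations = lookup_lib.embedding_lookup(config, lookups, tables, step)
    # Illustrative scalar loss over all activation leaves.
    return sum(jnp.sum(act) for act in jax.tree.leaves(activations))

# Differentiating with respect to `tables` (argument 0) triggers
# embedding_lookup_bwd, which produces per-table gradients.
table_grads = jax.grad(loss_fn)(tables, lookups, jnp.asarray(0))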
--- /dev/null
+++ keras_rs/src/layers/embedding/jax/embedding_utils.py
@@ -0,0 +1,217 @@
+"""Utility functions for manipulating JAX embedding tables and inputs."""
+
+import collections
+from typing import Any, Mapping, NamedTuple, Sequence, TypeAlias, TypeVar
+
+import jax
+import numpy as np
+from jax_tpu_embedding.sparsecore.lib.nn import embedding
+from jax_tpu_embedding.sparsecore.lib.nn import table_stacking
+from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import FeatureSpec
+
+from keras_rs.src.types import Nested
+
+T = TypeVar("T")
+
+# Any to support tf.Ragged without needing an explicit TF dependency.
+ArrayLike: TypeAlias = jax.Array | np.ndarray | Any  # type: ignore
+Shape: TypeAlias = tuple[int, ...]
+
+
+class FeatureSamples(NamedTuple):
+    tokens: ArrayLike
+    weights: ArrayLike | None
+
+
+class ShardedCooMatrix(NamedTuple):
+    shard_starts: ArrayLike
+    shard_ends: ArrayLike
+    col_ids: ArrayLike
+    row_ids: ArrayLike
+    values: ArrayLike
+
+
+def convert_to_numpy(
+    ragged_or_dense: np.ndarray[Any, Any] | Sequence[Sequence[Any]] | Any,
+    dtype: Any,
+) -> np.ndarray[Any, Any]:
+    """Converts a ragged or dense list of inputs to a ragged/dense numpy array.
+
+    The output is adjusted to be 2D.
+
+    Args:
+        ragged_or_dense: Input that is either already a numpy array, or nested
+            sequence.
+        dtype: Numpy dtype of output array.
+
+    Returns:
+        Corresponding numpy array.
+    """
+    if hasattr(ragged_or_dense, "numpy"):
+        # Support tf.RaggedTensor and other TF input dtypes.
+        if callable(getattr(ragged_or_dense, "numpy")):
+            ragged_or_dense = ragged_or_dense.numpy()
+
+    if isinstance(ragged_or_dense, jax.Array):
+        ragged_or_dense = np.asarray(ragged_or_dense)
+
+    if isinstance(ragged_or_dense, np.ndarray):
+        # Convert 1D to 2D.
+        if ragged_or_dense.dtype != np.ndarray and ragged_or_dense.ndim == 1:
+            return ragged_or_dense.reshape(-1, 1).astype(dtype)
+
+        # If dense, return converted dense type.
+        if ragged_or_dense.dtype != np.ndarray:
+            return ragged_or_dense.astype(dtype)
+
+        # Ragged numpy array.
+        return ragged_or_dense
+
+    # Handle 1D sequence input.
+    if not isinstance(ragged_or_dense[0], collections.abc.Sequence):
+        return np.asarray(ragged_or_dense, dtype=dtype).reshape(-1, 1)
+
+    # Assemble elements into an nd-array.
+    counts = [len(vals) for vals in ragged_or_dense]
+    if all([count == counts[0] for count in counts]):
+        # Dense input.
+        return np.asarray(ragged_or_dense, dtype=dtype)
+    else:
+        # Ragged input, convert to ragged numpy arrays.
+        return np.array(
+            [np.array(row, dtype=dtype) for row in ragged_or_dense],
+            dtype=np.ndarray,
+        )
+
+
+def create_feature_samples(
+    feature_structure: Nested[T],
+    feature_ids: Nested[ArrayLike | Sequence[int] | Sequence[Sequence[int]]],
+    feature_weights: None
+    | (Nested[ArrayLike | Sequence[float] | Sequence[Sequence[float]]]),
+) -> Nested[FeatureSamples]:
+    """Constructs a collection of sample tuples from provided IDs and weights.
+
+    Args:
+        feature_structure: The nested structure of the inputs (typically
+            `FeatureSpec`s).
+        feature_ids: The feature IDs to use for the samples.
+        feature_weights: The feature weights to use for the samples. Defaults
+            to ones if not provided.
+
+    Returns:
+        A nested collection of `FeatureSamples` corresponding to the input IDs
+        and weights, for use in embedding lookups.
+    """
+    # Create numpy arrays from inputs.
+    feature_ids = jax.tree.map(
+        lambda _, ids: convert_to_numpy(ids, np.int32),
+        feature_structure,
+        feature_ids,
+    )
+
+    if feature_weights is None:
+        return jax.tree.map(  # type: ignore[no-any-return]
+            lambda _, ids: FeatureSamples(ids, None),
+            feature_structure,
+            feature_ids,
+        )
+
+    feature_weights = jax.tree.map(
+        lambda _, wgts: convert_to_numpy(wgts, np.float32),
+        feature_structure,
+        feature_weights,
+    )
+
+    # Assemble.
+    def _create_feature_samples(
+        sample_ids: np.ndarray[Any, Any],
+        sample_weights: np.ndarray[Any, Any],
+    ) -> FeatureSamples:
+        return FeatureSamples(sample_ids, sample_weights)
+
+    output: Nested[FeatureSamples] = jax.tree.map(
+        lambda _, sample_ids, sample_weights: _create_feature_samples(
+            sample_ids, sample_weights
+        ),
+        feature_structure,
+        feature_ids,
+        feature_weights,
+    )
+    return output
+
+
+def stack_and_shard_samples(
+    feature_specs: Nested[FeatureSpec],
+    feature_samples: Nested[FeatureSamples],
+    local_device_count: int,
+    global_device_count: int,
+    num_sc_per_device: int,
+    static_buffer_size: int | Mapping[str, int] | None = None,
+) -> tuple[dict[str, ShardedCooMatrix], embedding.SparseDenseMatmulInputStats]:
+    """Prepares input samples for use in embedding lookups.
+
+    Args:
+        feature_specs: Nested collection of feature specifications.
+        feature_samples: Nested collection of feature samples.
+        local_device_count: Number of local JAX devices.
+        global_device_count: Number of global JAX devices.
+        num_sc_per_device: Number of sparsecores per device.
+        static_buffer_size: The static buffer size to use for the samples.
+            Defaults to None, in which case an upper-bound for the buffer size
+            will be automatically determined.
+
+    Returns:
+        The preprocessed inputs, and statistics useful for updating FeatureSpecs
+        based on the provided input data.
+    """
+    del static_buffer_size  # Currently ignored.
+    flat_feature_specs, _ = jax.tree.flatten(feature_specs)
+
+    feature_tokens = []
+    collected_weights = []
+
+    def collect_tokens_and_weights(
+        feature_spec: FeatureSpec, samples: FeatureSamples
+    ) -> None:
+        del feature_spec
+        feature_tokens.append(samples.tokens)
+        collected_weights.append(samples.weights)
+
+    jax.tree.map(collect_tokens_and_weights, feature_specs, feature_samples)
+
+    feature_weights = (
+        None if all(w is None for w in collected_weights) else collected_weights
+    )
+
+    preprocessed_inputs, stats = embedding.preprocess_sparse_dense_matmul_input(
+        feature_tokens,
+        feature_weights,
+        flat_feature_specs,
+        local_device_count=local_device_count,
+        global_device_count=global_device_count,
+        num_sc_per_device=num_sc_per_device,
+        sharding_strategy="MOD",
+        has_leading_dimension=False,
+        allow_id_dropping=True,
+    )
+
+    out: dict[str, ShardedCooMatrix] = {}
+    tables_names = preprocessed_inputs.lhs_row_pointers.keys()
+    for table_name in tables_names:
+        shard_ends = preprocessed_inputs.lhs_row_pointers[table_name]
+        shard_starts = np.concatenate(
+            [
+                np.asarray([0]),
+                table_stacking._next_largest_multiple(shard_ends[:-1], 8),
+            ]
+        )
+        out[table_name] = ShardedCooMatrix(
+            shard_starts=shard_starts,
+            shard_ends=shard_ends,
+            col_ids=preprocessed_inputs.lhs_embedding_ids[table_name],
+            row_ids=preprocessed_inputs.lhs_sample_ids[table_name],
+            values=preprocessed_inputs.lhs_gains[table_name],
+        )
+
+    return out, stats
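
Note on usage (not part of the diff itself): the utilities above form the input side of the lookup in the previous file. `convert_to_numpy` normalizes dense or ragged inputs (including TF tensors via their `.numpy()` method) into 2D or object-dtype NumPy arrays, `create_feature_samples` pairs IDs with optional weights, and `stack_and_shard_samples` runs the SparseCore preprocessing that yields the per-table `ShardedCooMatrix` inputs plus input statistics. A hedged sketch of that path, assuming `feature_specs` is an existing mapping of configured `FeatureSpec`s and using a placeholder sparsecore count:

import jax

from keras_rs.src.layers.embedding.jax import embedding_utils

# Ragged per-sample IDs for two features; weights default to ones.
feature_ids = {
    "movie_id": [[3, 7, 11], [42]],
    "user_id": [[1], [2]],
}

samples = embedding_utils.create_feature_samples(
    feature_specs, feature_ids, None
)

lookups, stats = embedding_utils.stack_and_shard_samples(
    feature_specs,
    samples,
    local_device_count=jax.local_device_count(),
    global_device_count=jax.device_count(),
    num_sc_per_device=4,  # Placeholder; see jte_utils.num_sparsecores_per_device().
)
# `lookups` maps stacked table names to ShardedCooMatrix tuples and can be
# passed to embedding_lookup(config, lookups, tables); `stats` carries
# statistics useful for updating the FeatureSpecs from the observed data.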