keras-rs-nightly 0.0.1.dev2025050103__py3-none-any.whl → 0.2.2.dev202506100336__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of keras-rs-nightly might be problematic. Click here for more details.
- keras_rs/layers/__init__.py +12 -0
- keras_rs/src/layers/embedding/__init__.py +0 -0
- keras_rs/src/layers/embedding/base_distributed_embedding.py +1124 -0
- keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
- keras_rs/src/layers/embedding/distributed_embedding_config.py +129 -0
- keras_rs/src/layers/embedding/embed_reduce.py +309 -0
- keras_rs/src/layers/embedding/jax/__init__.py +0 -0
- keras_rs/src/layers/embedding/jax/config_conversion.py +398 -0
- keras_rs/src/layers/embedding/jax/distributed_embedding.py +892 -0
- keras_rs/src/layers/embedding/jax/embedding_lookup.py +255 -0
- keras_rs/src/layers/embedding/jax/embedding_utils.py +596 -0
- keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
- keras_rs/src/layers/embedding/tensorflow/config_conversion.py +323 -0
- keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +424 -0
- keras_rs/src/layers/feature_interaction/dot_interaction.py +2 -2
- keras_rs/src/layers/feature_interaction/feature_cross.py +14 -16
- keras_rs/src/layers/retrieval/brute_force_retrieval.py +5 -5
- keras_rs/src/layers/retrieval/retrieval.py +4 -4
- keras_rs/src/losses/pairwise_loss.py +2 -2
- keras_rs/src/losses/pairwise_mean_squared_error.py +1 -3
- keras_rs/src/metrics/dcg.py +2 -2
- keras_rs/src/metrics/ndcg.py +2 -2
- keras_rs/src/metrics/ranking_metric.py +4 -4
- keras_rs/src/metrics/ranking_metrics_utils.py +8 -8
- keras_rs/src/metrics/utils.py +2 -4
- keras_rs/src/types.py +43 -14
- keras_rs/src/utils/keras_utils.py +26 -6
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.0.1.dev2025050103.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/METADATA +6 -3
- keras_rs_nightly-0.2.2.dev202506100336.dist-info/RECORD +55 -0
- {keras_rs_nightly-0.0.1.dev2025050103.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/WHEEL +1 -1
- keras_rs_nightly-0.0.1.dev2025050103.dist-info/RECORD +0 -42
- {keras_rs_nightly-0.0.1.dev2025050103.dist-info → keras_rs_nightly-0.2.2.dev202506100336.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,892 @@
|
|
|
1
|
+
"""JAX implementation of the TPU embedding layer."""
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
import typing
|
|
5
|
+
from typing import Any, Mapping, Sequence, Union
|
|
6
|
+
|
|
7
|
+
import jax
|
|
8
|
+
import keras
|
|
9
|
+
import numpy as np
|
|
10
|
+
from jax import numpy as jnp
|
|
11
|
+
from jax.experimental import layout as jax_layout
|
|
12
|
+
from jax_tpu_embedding.sparsecore.lib.nn import embedding
|
|
13
|
+
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
|
|
14
|
+
from jax_tpu_embedding.sparsecore.lib.nn import (
|
|
15
|
+
table_stacking as jte_table_stacking,
|
|
16
|
+
)
|
|
17
|
+
from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
|
|
18
|
+
from keras.src import backend
|
|
19
|
+
|
|
20
|
+
from keras_rs.src import types
|
|
21
|
+
from keras_rs.src.layers.embedding import base_distributed_embedding
|
|
22
|
+
from keras_rs.src.layers.embedding import distributed_embedding_config as config
|
|
23
|
+
from keras_rs.src.layers.embedding.jax import config_conversion
|
|
24
|
+
from keras_rs.src.layers.embedding.jax import (
|
|
25
|
+
embedding_lookup as jte_embedding_lookup,
|
|
26
|
+
)
|
|
27
|
+
from keras_rs.src.layers.embedding.jax import embedding_utils
|
|
28
|
+
from keras_rs.src.types import Nested
|
|
29
|
+
from keras_rs.src.utils import keras_utils
|
|
30
|
+
|
|
31
|
+
ArrayLike = Union[np.ndarray[Any, Any], jax.Array]
|
|
32
|
+
FeatureConfig = config.FeatureConfig
|
|
33
|
+
shard_map = jax.experimental.shard_map.shard_map # type: ignore[attr-defined]
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _get_partition_spec(
    layout: (
        keras.distribution.TensorLayout
        | jax_layout.Layout
        | jax.sharding.NamedSharding
        | jax.sharding.PartitionSpec
    ),
) -> Any:
    """Extracts the partition spec from a layout or sharding.

    Peels wrapper types one level at a time:
    `TensorLayout` -> `Layout` -> `NamedSharding` -> `PartitionSpec`.
    A value already at a deeper level is returned unchanged.
    """
    unwrapped = layout
    if isinstance(unwrapped, keras.distribution.TensorLayout):
        unwrapped = unwrapped.backend_layout
    if isinstance(unwrapped, jax_layout.Layout):
        unwrapped = unwrapped.sharding
    if isinstance(unwrapped, jax.sharding.NamedSharding):
        unwrapped = unwrapped.spec
    return unwrapped
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class ShardedInitializer(keras.initializers.Initializer):
    """Initializer wrapper that produces device-sharded embedding tables.

    When a layout is supplied, the wrapped initializer is jit-compiled with
    that layout as the output sharding, so the table is initialized directly
    on device in its final sharded form. Without a layout, the wrapped
    initializer is invoked as-is.
    """

    def __init__(
        self,
        initializer: keras.initializers.Initializer | str,
        layout: keras.distribution.TensorLayout | None,
    ):
        # Resolve string identifiers (e.g. "zeros") to initializer objects.
        resolved = (
            keras.initializers.get(initializer)
            if isinstance(initializer, str)
            else initializer
        )
        self._initializer = resolved
        self._layout = layout

    def __call__(
        self, shape: types.Shape, dtype: types.DType | None = None
    ) -> jax.Array:
        if self._layout is None:
            # No sharding requested; defer directly to the wrapped
            # initializer.
            result: jax.Array = self._initializer(shape, dtype)
            return result

        # Compile with the target sharding so initialization happens on
        # device; shape/dtype are static so tracing specializes on them.
        jitted = jax.jit(
            self._initializer,
            out_shardings=self._layout.backend_layout,
            static_argnames=["shape", "dtype"],
        )
        result_on_device: jax.Array = jitted(shape, dtype)
        return result_on_device
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class StackedTableInitializer(keras.initializers.Initializer):
    """Initializes a single stacked table from multiple table initializers.

    A stacked table is several logical embedding tables concatenated row-wise
    into one physical array, sharded across devices. Each per-device shard is
    assembled by concatenating one sub-shard per underlying table, in
    `row_offset_in_shard` order.
    """

    def __init__(
        self,
        table_specs: Nested[embedding_spec.TableSpec],
        num_shards: int,
        layout: keras.distribution.TensorLayout,
        seed: int | keras.random.SeedGenerator | jax.Array = 0,
    ):
        # Sort table specs so we can simply concatenate them when assembling the
        # stacked table.
        self._table_specs = sorted(
            keras.tree.flatten(table_specs),
            key=lambda table_spec: (
                table_spec.setting_in_stack.row_offset_in_shard,
            ),
        )
        self._num_shards = num_shards
        self._layout = layout
        # NOTE(review): relies on the private `keras.src` backend API to
        # normalize the seed into a JAX PRNG key — may break across Keras
        # versions; verify on upgrades.
        self._key = keras.src.backend.jax.random.jax_draw_seed(seed)

    def _initialize_shard(
        self,
        keys: jax.Array,
        shape: tuple[int, int],
        dtype: Any,
        num_shards_per_device: int,
    ) -> jax.Array:
        """Initializes one device's portion of the stacked table.

        Args:
            keys: PRNG keys indexed as `keys[sub_shard, table]`.
            shape: Unused; per-table shard shapes come from the table specs.
            dtype: Element type passed through to each table initializer.
            num_shards_per_device: Number of sub-shards this device holds.

        Returns:
            The row-wise concatenation of all per-table sub-shards.
        """
        del shape  # Unused.
        table_shards: list[jax.Array] = []
        # NOTE: the following ignores padding, rotations in shard, and
        # mod-sharding, assuming all initializers are shard-independent.
        for i in range(num_shards_per_device):
            for j, table_spec in enumerate(self._table_specs):
                setting_in_stack = table_spec.setting_in_stack
                # Each table contributes an equal slice of its padded vocab
                # to every shard.
                table_shard_shape = (
                    setting_in_stack.padded_vocab_size // self._num_shards,
                    setting_in_stack.padded_embedding_dim,
                )
                initializer = table_spec.initializer
                table_shards.append(
                    initializer(keys[i, j], table_shard_shape, dtype)
                )

        return jnp.concatenate(table_shards, axis=0)

    def __call__(
        self, shape: types.Shape, dtype: types.DType | None = None
    ) -> jax.Array:
        stacked_table_spec = typing.cast(
            embedding_spec.StackedTableSpec,
            self._table_specs[0].stacked_table_spec,
        )

        # Input shape is governed by the table specs.
        assert shape == (
            stacked_table_spec.stack_vocab_size,
            stacked_table_spec.stack_embedding_dim,
        )

        layout = self._layout
        backend_layout = layout.backend_layout
        backend_mesh = layout.device_mesh.backend_mesh
        num_devices_along_axis = backend_mesh.shape[layout.axes[0]]
        num_shards_per_device = self._num_shards // num_devices_along_axis
        shard_shape = (
            stacked_table_spec.stack_vocab_size // num_devices_along_axis,
            stacked_table_spec.stack_embedding_dim,
        )

        # Build each device's shard locally via shard_map so the full
        # stacked table is never materialized on a single device.
        sharded_initializer = jax.jit(
            shard_map(
                lambda keys: self._initialize_shard(
                    keys, shard_shape, dtype, num_shards_per_device
                ),
                mesh=backend_mesh,
                in_specs=_get_partition_spec(backend_layout),
                out_specs=_get_partition_spec(backend_layout),
            ),
            out_shardings=backend_layout,
        )

        # One independent key per (shard, table) pair.
        keys = jax.random.split(
            self._key, (self._num_shards, len(self._table_specs))
        )
        # Try extracting seeds from the existing table initializers.
        for i, table_spec in enumerate(self._table_specs):
            initializer = table_spec.initializer
            if isinstance(
                initializer, config_conversion.WrappedKerasInitializer
            ):
                initializer_key = initializer.key()
                if initializer_key is not None:
                    # Honor the per-table seed: replace this table's column
                    # of keys with splits of its own key.
                    col = jax.random.split(initializer_key, self._num_shards)
                    keys = keys.at[:, i].set(col)

        output: jax.Array = sharded_initializer(keys)
        return output
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
|
|
194
|
+
"""JAX implementation of the TPU embedding layer."""
|
|
195
|
+
|
|
196
|
+
    def _create_sparsecore_distribution(
        self, sparsecore_axis_name: str = "sparsecore"
    ) -> tuple[
        keras.distribution.ModelParallel, keras.distribution.TensorLayout
    ]:
        """SparseCore requires a specific layout.

        The mesh must be 1D, must use all TPUs available, and must shard all
        tables across all devices.

        Args:
            sparsecore_axis_name: The name of the sparsecore axis.

        Returns:
            A `(distribution, layout)` tuple: the Keras distribution to use
            for all sparsecore operations, and the tensor layout that shards
            tables along the sparsecore axis.
        """
        all_devices = jax.devices()
        axes = [sparsecore_axis_name]
        # 1D mesh over every available device.
        device_mesh = keras.distribution.DeviceMesh(
            (len(all_devices),), axes, all_devices
        )
        sparsecore_layout = keras.distribution.TensorLayout(axes, device_mesh)
        # Custom sparsecore layout with tiling.
        # NOTE(review): overwrites the layout's private `_backend_layout` to
        # force an 8-row tiled device-local layout — presumably matching the
        # sparsecore hardware tile; confirm against jax_tpu_embedding docs.
        # pylint: disable-next=protected-access
        sparsecore_layout._backend_layout = jax_layout.Layout(
            jax_layout.DeviceLocalLayout(
                major_to_minor=(0, 1),
                _tiling=((8,),),
            ),
            jax.sharding.NamedSharding(
                device_mesh.backend_mesh,
                jax.sharding.PartitionSpec(
                    axes  # type: ignore[no-untyped-call]
                ),
            ),
        )
        layout_map = keras.distribution.LayoutMap(device_mesh=device_mesh)
        path = self.path
        if path is None:
            # Layer hasn't been properly built yet. Use current layer name.
            path = self.name
        # All weights named "<path>/var…" get the sharded sparsecore layout.
        layout_map[path + "/var"] = sparsecore_layout
        sparsecore_distribution = keras.distribution.ModelParallel(
            layout_map=layout_map
        )
        return sparsecore_distribution, sparsecore_layout
|
|
242
|
+
|
|
243
|
+
def _create_cpu_distribution(
|
|
244
|
+
self, cpu_axis_name: str = "cpu"
|
|
245
|
+
) -> tuple[
|
|
246
|
+
keras.distribution.ModelParallel, keras.distribution.TensorLayout
|
|
247
|
+
]:
|
|
248
|
+
"""Share a variable across all CPU processes."""
|
|
249
|
+
cpu_devices = jax.devices("cpu")
|
|
250
|
+
device_mesh = keras.distribution.DeviceMesh(
|
|
251
|
+
(len(cpu_devices),), [cpu_axis_name], cpu_devices
|
|
252
|
+
)
|
|
253
|
+
replicated_layout = keras.distribution.TensorLayout([], device_mesh)
|
|
254
|
+
layout_map = keras.distribution.LayoutMap(device_mesh=device_mesh)
|
|
255
|
+
cpu_distribution = keras.distribution.ModelParallel(
|
|
256
|
+
layout_map=layout_map
|
|
257
|
+
)
|
|
258
|
+
return cpu_distribution, replicated_layout
|
|
259
|
+
|
|
260
|
+
def _add_sparsecore_weight(
|
|
261
|
+
self,
|
|
262
|
+
name: str,
|
|
263
|
+
shape: tuple[int, int],
|
|
264
|
+
initializer: jax.nn.initializers.Initializer,
|
|
265
|
+
dtype: Any,
|
|
266
|
+
overwrite_with_gradient: bool,
|
|
267
|
+
) -> keras.Variable:
|
|
268
|
+
var = self.add_weight(
|
|
269
|
+
name=name, shape=shape, initializer=initializer, dtype=dtype
|
|
270
|
+
)
|
|
271
|
+
var.overwrite_with_gradient = overwrite_with_gradient
|
|
272
|
+
return var
|
|
273
|
+
|
|
274
|
+
    def _add_table_variable(
        self,
        table_specs: Sequence[embedding_spec.TableSpec],
        num_shards: int,
        add_slot_variables: bool,
    ) -> tuple[keras.Variable, tuple[keras.Variable, ...] | None]:
        """Creates the variable(s) backing one stacked table.

        Args:
            table_specs: Specs of all tables in this stack; all must share
                the same `stacked_table_spec`.
            num_shards: Total number of table shards across all sparsecores.
            add_slot_variables: Whether to also create the optimizer slot
                variables for this stack.

        Returns:
            A `(table_variable, slot_variables)` tuple; `slot_variables` is
            None when `add_slot_variables` is false.
        """
        stacked_table_spec = typing.cast(
            embedding_spec.StackedTableSpec, table_specs[0].stacked_table_spec
        )
        optimizer = stacked_table_spec.optimizer
        num_slot_variables = optimizer.slot_variables_count()
        table_shape = (
            stacked_table_spec.stack_vocab_size,
            stacked_table_spec.stack_embedding_dim,
        )

        # Make a stacked embedding table initializer.
        table_initializers = [
            config_conversion.jax_to_keras_initializer(table_spec.initializer)
            for table_spec in table_specs
        ]
        # If all initializers are the same, we can use a single sharded
        # initializer. Otherwise, we need to interleave individual stacked table
        # shards.
        sparsecore_layout = self._sparsecore_layout
        stacked_table_initializer = ShardedInitializer(
            table_initializers[0], sparsecore_layout
        )
        if not all(
            initializer == table_initializers[0]
            for initializer in table_initializers
        ):
            stacked_table_initializer = StackedTableInitializer(
                table_specs, num_shards, sparsecore_layout
            )

        variable_name = f"var:{stacked_table_spec.stack_name}:table"
        table_variable = self._add_sparsecore_weight(
            name=variable_name,
            shape=table_shape,
            initializer=stacked_table_initializer,
            dtype="float32",
            overwrite_with_gradient=True,
        )

        slot_variables = None
        if add_slot_variables:
            # All optimizers for a given stacked table are guaranteed to be the
            # same, so we can use a single sharded initializer for the entire
            # stacked table.
            slot_initializers = optimizer.slot_variables_initializers()
            # Try extracting field names from variables, otherwise just use the
            # count.
            slot_names = range(num_slot_variables)
            if hasattr(slot_initializers, "_fields"):
                # NamedTuple of initializers: use its field names.
                slot_names = slot_initializers._fields

            slot_variables = tuple(
                self._add_sparsecore_weight(
                    name=f"{variable_name}:slot:{slot_name}",
                    shape=table_shape,
                    initializer=ShardedInitializer(
                        config_conversion.jax_to_keras_initializer(initializer),
                        sparsecore_layout,
                    ),
                    dtype=jnp.float32,
                    overwrite_with_gradient=True,
                )
                for slot_name, initializer in zip(slot_names, slot_initializers)
            )
            # Restore the optimizer's original container type (e.g. the
            # NamedTuple) around the created variables.
            slot_variables = keras.tree.pack_sequence_as(
                slot_initializers, slot_variables
            )

        return table_variable, slot_variables
|
|
349
|
+
|
|
350
|
+
@keras_utils.no_automatic_dependency_tracking
|
|
351
|
+
def _sparsecore_init(
|
|
352
|
+
self,
|
|
353
|
+
feature_configs: dict[str, FeatureConfig],
|
|
354
|
+
table_stacking: str | Sequence[str] | Sequence[Sequence[str]],
|
|
355
|
+
) -> None:
|
|
356
|
+
if not self._has_sparsecore():
|
|
357
|
+
raise ValueError(
|
|
358
|
+
"Not sparse cores available, cannot use explicit sparsecore"
|
|
359
|
+
" placement."
|
|
360
|
+
)
|
|
361
|
+
|
|
362
|
+
self._sc_feature_configs = feature_configs
|
|
363
|
+
self._sparsecore_built = False
|
|
364
|
+
# Fill in any empty default settings.
|
|
365
|
+
for feature_config in keras.tree.flatten(self._sc_feature_configs):
|
|
366
|
+
if feature_config.table.initializer is None:
|
|
367
|
+
table = feature_config.table
|
|
368
|
+
table.initializer = keras.initializers.TruncatedNormal(
|
|
369
|
+
mean=0.0, stddev=1.0 / math.sqrt(float(table.embedding_dim))
|
|
370
|
+
)
|
|
371
|
+
|
|
372
|
+
# Actual stacking of tables is done in build() to ensure the
|
|
373
|
+
# distribution is set up correctly.
|
|
374
|
+
self._table_stacking = table_stacking
|
|
375
|
+
|
|
376
|
+
    def _sparsecore_build(
        self, input_shapes: Nested[types.Shape] | None = None
    ) -> None:
        """Internal build hook; delegates to the public `sparsecore_build`."""
        self.sparsecore_build(input_shapes)
|
|
380
|
+
|
|
381
|
+
@keras_utils.no_automatic_dependency_tracking
|
|
382
|
+
def sparsecore_build(
|
|
383
|
+
self, input_shapes: Nested[types.Shape] | None = None
|
|
384
|
+
) -> None:
|
|
385
|
+
del input_shapes # Unused.
|
|
386
|
+
|
|
387
|
+
if self._sparsecore_built:
|
|
388
|
+
return
|
|
389
|
+
|
|
390
|
+
feature_specs = config_conversion.keras_to_jte_feature_configs(
|
|
391
|
+
self._sc_feature_configs
|
|
392
|
+
)
|
|
393
|
+
|
|
394
|
+
# Distribution for sparsecore operations.
|
|
395
|
+
sparsecore_distribution, sparsecore_layout = (
|
|
396
|
+
self._create_sparsecore_distribution()
|
|
397
|
+
)
|
|
398
|
+
self._sparsecore_layout = sparsecore_layout
|
|
399
|
+
self._sparsecore_distribution = sparsecore_distribution
|
|
400
|
+
|
|
401
|
+
# Distribution for CPU operations.
|
|
402
|
+
cpu_distribution, cpu_layout = self._create_cpu_distribution()
|
|
403
|
+
self._cpu_distribution = cpu_distribution
|
|
404
|
+
self._cpu_layout = cpu_layout
|
|
405
|
+
|
|
406
|
+
mesh = sparsecore_distribution.device_mesh.backend_mesh
|
|
407
|
+
global_device_count = mesh.devices.size
|
|
408
|
+
num_sc_per_device = jte_utils.num_sparsecores_per_device(
|
|
409
|
+
mesh.devices.item(0)
|
|
410
|
+
)
|
|
411
|
+
# One table shard per global sparsecore.
|
|
412
|
+
num_variable_shards = global_device_count * num_sc_per_device
|
|
413
|
+
|
|
414
|
+
# Maybe stack tables.
|
|
415
|
+
table_stacking = self._table_stacking
|
|
416
|
+
if table_stacking is not None:
|
|
417
|
+
if isinstance(table_stacking, str):
|
|
418
|
+
if table_stacking == "auto":
|
|
419
|
+
jte_table_stacking.auto_stack_tables(
|
|
420
|
+
feature_specs, global_device_count, num_sc_per_device
|
|
421
|
+
)
|
|
422
|
+
else:
|
|
423
|
+
raise ValueError(
|
|
424
|
+
f"Unsupported table stacking {table_stacking}, must be"
|
|
425
|
+
"None, 'auto', or sequences of table names to stack."
|
|
426
|
+
)
|
|
427
|
+
else:
|
|
428
|
+
if isinstance(table_stacking, list) and len(table_stacking) > 0:
|
|
429
|
+
elem = table_stacking[0]
|
|
430
|
+
# List of lists of table names.
|
|
431
|
+
if isinstance(elem, list):
|
|
432
|
+
for table_names in table_stacking:
|
|
433
|
+
jte_table_stacking.stack_tables(
|
|
434
|
+
feature_specs,
|
|
435
|
+
table_names,
|
|
436
|
+
global_device_count,
|
|
437
|
+
num_sc_per_device,
|
|
438
|
+
)
|
|
439
|
+
# Single list of table names.
|
|
440
|
+
elif isinstance(elem, str):
|
|
441
|
+
jte_table_stacking.stack_tables(
|
|
442
|
+
feature_specs,
|
|
443
|
+
table_stacking,
|
|
444
|
+
global_device_count,
|
|
445
|
+
num_sc_per_device,
|
|
446
|
+
)
|
|
447
|
+
else:
|
|
448
|
+
raise ValueError(
|
|
449
|
+
f"Unsupported table stacking {table_stacking}, "
|
|
450
|
+
"must be None, 'auto', or sequences of table names "
|
|
451
|
+
"to stack."
|
|
452
|
+
)
|
|
453
|
+
|
|
454
|
+
# Adjust any non-stacked tables to prepare for training.
|
|
455
|
+
embedding.prepare_feature_specs_for_training(
|
|
456
|
+
feature_specs, global_device_count, num_sc_per_device
|
|
457
|
+
)
|
|
458
|
+
|
|
459
|
+
# Collect all stacked tables.
|
|
460
|
+
table_specs = embedding_utils.get_table_specs(feature_specs)
|
|
461
|
+
table_stacks = embedding_utils.get_table_stacks(table_specs)
|
|
462
|
+
stacked_table_specs = {
|
|
463
|
+
stack_name: stack[0].stacked_table_spec
|
|
464
|
+
for stack_name, stack in table_stacks.items()
|
|
465
|
+
}
|
|
466
|
+
|
|
467
|
+
# Create variables for all stacked tables and slot variables.
|
|
468
|
+
with sparsecore_distribution.scope():
|
|
469
|
+
self._table_and_slot_variables = {
|
|
470
|
+
table_name: self._add_table_variable(
|
|
471
|
+
table_stack,
|
|
472
|
+
add_slot_variables=self.trainable,
|
|
473
|
+
num_shards=num_variable_shards,
|
|
474
|
+
)
|
|
475
|
+
for table_name, table_stack in table_stacks.items()
|
|
476
|
+
}
|
|
477
|
+
|
|
478
|
+
# Create a step-counter variable for use in custom table gradients.
|
|
479
|
+
# This must be a floating-point type so we can get a real gradient
|
|
480
|
+
# for it. It will automatically be updated with each application of
|
|
481
|
+
# the optimizer, since the next iteration is returned in the
|
|
482
|
+
# gradient.
|
|
483
|
+
sharded_zero_initializer = ShardedInitializer(
|
|
484
|
+
"zeros",
|
|
485
|
+
keras.distribution.TensorLayout(
|
|
486
|
+
[], sparsecore_layout.device_mesh
|
|
487
|
+
),
|
|
488
|
+
)
|
|
489
|
+
self._iterations = self.add_weight(
|
|
490
|
+
shape=(),
|
|
491
|
+
name="iteration",
|
|
492
|
+
initializer=sharded_zero_initializer,
|
|
493
|
+
dtype="float32",
|
|
494
|
+
trainable=True,
|
|
495
|
+
)
|
|
496
|
+
self._iterations.overwrite_with_gradient = True
|
|
497
|
+
|
|
498
|
+
with cpu_distribution.scope():
|
|
499
|
+
# Create variables to track static buffer size and max IDs for each
|
|
500
|
+
# table during preprocessing. These variables are shared across all
|
|
501
|
+
# processes on CPU. We don't add these via `add_weight` because we
|
|
502
|
+
# can't have them passed to the training function.
|
|
503
|
+
replicated_zeros_initializer = ShardedInitializer(
|
|
504
|
+
"zeros", cpu_layout
|
|
505
|
+
)
|
|
506
|
+
|
|
507
|
+
with backend.name_scope(self.name, caller=self):
|
|
508
|
+
self._preprocessing_buffer_size = {
|
|
509
|
+
table_name: backend.Variable(
|
|
510
|
+
initializer=replicated_zeros_initializer,
|
|
511
|
+
shape=(),
|
|
512
|
+
dtype=backend.standardize_dtype("int32"),
|
|
513
|
+
trainable=False,
|
|
514
|
+
name=table_name + ":preprocessing:buffer_size",
|
|
515
|
+
)
|
|
516
|
+
for table_name in stacked_table_specs.keys()
|
|
517
|
+
}
|
|
518
|
+
self._preprocessing_max_unique_ids_per_partition = {
|
|
519
|
+
table_name: backend.Variable(
|
|
520
|
+
shape=(),
|
|
521
|
+
name=table_name
|
|
522
|
+
+ ":preprocessing:max_unique_ids_per_partition",
|
|
523
|
+
initializer=replicated_zeros_initializer,
|
|
524
|
+
dtype=backend.standardize_dtype("int32"),
|
|
525
|
+
trainable=False,
|
|
526
|
+
)
|
|
527
|
+
for table_name in stacked_table_specs.keys()
|
|
528
|
+
}
|
|
529
|
+
|
|
530
|
+
self._preprocessing_max_ids_per_partition = {
|
|
531
|
+
table_name: backend.Variable(
|
|
532
|
+
shape=(),
|
|
533
|
+
name=table_name
|
|
534
|
+
+ ":preprocessing:max_ids_per_partition",
|
|
535
|
+
initializer=replicated_zeros_initializer,
|
|
536
|
+
dtype=backend.standardize_dtype("int32"),
|
|
537
|
+
trainable=False,
|
|
538
|
+
)
|
|
539
|
+
for table_name in stacked_table_specs.keys()
|
|
540
|
+
}
|
|
541
|
+
|
|
542
|
+
self._config = jte_embedding_lookup.EmbeddingLookupConfiguration(
|
|
543
|
+
feature_specs,
|
|
544
|
+
mesh=mesh,
|
|
545
|
+
table_partition=_get_partition_spec(sparsecore_layout),
|
|
546
|
+
samples_partition=_get_partition_spec(sparsecore_layout),
|
|
547
|
+
table_layout=sparsecore_layout.backend_layout,
|
|
548
|
+
)
|
|
549
|
+
|
|
550
|
+
self._sparsecore_built = True
|
|
551
|
+
|
|
552
|
+
    def _sparsecore_symbolic_preprocess(
        self,
        inputs: dict[str, types.Tensor],
        weights: dict[str, types.Tensor] | None,
        training: bool = False,
    ) -> dict[str, dict[str, embedding_utils.ShardedCooMatrix]]:
        """Allow preprocess(...) with `keras.Input`s.

        This is to support creating functional models via:
        ```python
        inputs = keras.Input(shape=(None), dtype="int32")
        weights = keras.Input(shape=(None), dtype="float32")
        preprocessed_inputs = distributed_embedding.preprocess(inputs, weights)
        outputs = distributed_embedding(preprocessed_inputs)
        model = keras.Model(inputs=preprocessed_inputs, outputs=outputs)
        ```

        Args:
            inputs: SparseCore path->tensor input ID's tensors.
            weights: Optional Sparsecore path->tensor input weights tensors.
            training: Whether the layer is training or not.

        Returns:
            Symbolic preprocessed input tensors to the layer/model.
        """
        # Arguments are currently ignored since the input shape is governed
        # by the stacked table configuration.
        del inputs, weights, training

        # Each stacked-table gets a ShardedCooMatrix.
        table_specs = embedding_utils.get_table_specs(
            self._config.feature_specs
        )
        table_stacks = embedding_utils.get_table_stacks(table_specs)
        stacked_table_specs = {
            stack_name: stack[0].stacked_table_spec
            for stack_name, stack in table_stacks.items()
        }

        def _compute_table_output_spec(
            stacked_table_spec: embedding_spec.StackedTableSpec,
        ) -> embedding_utils.ShardedCooMatrix:
            """Builds a symbolic COO placeholder for one stacked table."""
            # The true shape of the components in the ShardedCooMatrix depends
            # on the hardware configuration (# devices, sparsecores),
            # properties of the input data (# max IDs, unique IDs), and other
            # hints like a suggested internal buffer size. Some of the
            # calculations are currently a bit in flux as we experiment with
            # memory trade-offs. For the purposes of input/output sizes,
            # however, the size could be viewed as dynamic 1D without affecting
            # the output spec sizes.
            del stacked_table_spec
            return embedding_utils.ShardedCooMatrix(
                # Mark these as `Input`s since that's how they will be used when
                # constructing a functional Keras model.
                shard_starts=keras.Input(shape=tuple(), dtype="int32"),
                shard_ends=keras.Input(shape=tuple(), dtype="int32"),
                col_ids=keras.Input(shape=tuple(), dtype="int32"),
                row_ids=keras.Input(shape=tuple(), dtype="int32"),
                values=keras.Input(shape=tuple(), dtype="float32"),
            )

        preprocessed = keras.tree.map_structure(
            _compute_table_output_spec, stacked_table_specs
        )

        # Same {"inputs": ...} wrapper shape as the concrete preprocess path.
        return {"inputs": preprocessed}
|
|
618
|
+
|
|
619
|
+
    def _sparsecore_preprocess(
        self,
        inputs: dict[str, types.Tensor],
        weights: dict[str, types.Tensor] | None,
        training: bool = False,
    ) -> dict[str, dict[str, embedding_utils.ShardedCooMatrix]]:
        """Converts raw feature IDs/weights into sharded COO lookups.

        Stacks and shards the input samples per stacked table. During
        training, also grows (never shrinks) the tracked max-IDs and buffer
        sizes, aggregated by max across all CPU processes/devices.

        Args:
            inputs: SparseCore path->tensor input ID tensors.
            weights: Optional path->tensor input weight tensors.
            training: Whether preprocessing is for a training step.

        Returns:
            `{"inputs": {stack_name: ShardedCooMatrix}}`.

        Raises:
            ValueError: If called under jit (tracer inputs); the method reads
                and writes concrete variable values on the host.
        """
        if any(
            isinstance(x, jax.core.Tracer) for x in keras.tree.flatten(inputs)
        ):
            raise ValueError(
                "DistributedEmbedding.preprocess(...) does not support"
                " jit-compilation"
            )

        # Lazily build sparsecore state on first use.
        if not self._sparsecore_built:
            self._sparsecore_build()

        # Support symbolic KerasTensors (i.e. keras.Input).
        if any(
            isinstance(x, keras.KerasTensor) for x in keras.tree.flatten(inputs)
        ):
            return self._sparsecore_symbolic_preprocess(
                inputs, weights, training
            )

        samples = embedding_utils.create_feature_samples(
            self._config.feature_specs, inputs, weights
        )

        layout = self._sparsecore_layout
        mesh = layout.device_mesh.backend_mesh
        global_device_count = mesh.devices.size
        local_device_count = mesh.local_mesh.devices.size
        num_sc_per_device = jte_utils.num_sparsecores_per_device(
            mesh.devices.item(0)
        )

        # Get current buffer size/max_ids.
        previous_max_ids_per_partition = keras.tree.map_structure(
            lambda max_ids_per_partition: max_ids_per_partition.value.item(),
            self._preprocessing_max_ids_per_partition,
        )
        previous_max_unique_ids_per_partition = keras.tree.map_structure(
            lambda max_unique_ids_per_partition: (
                max_unique_ids_per_partition.value.item()
            ),
            self._preprocessing_max_unique_ids_per_partition,
        )
        previous_buffer_size = keras.tree.map_structure(
            lambda buffer_size: buffer_size.value.item(),
            self._preprocessing_buffer_size,
        )

        preprocessed, stats = embedding_utils.stack_and_shard_samples(
            self._config.feature_specs,
            samples,
            local_device_count,
            global_device_count,
            num_sc_per_device,
            static_buffer_size=previous_buffer_size,
        )

        # Extract max unique IDs and buffer sizes.
        # We need to replicate this value across all local CPU devices.
        if training:
            num_local_cpu_devices = jax.local_device_count("cpu")
            local_max_ids_per_partition = {
                table_name: np.repeat(
                    # Maximum across all partitions and previous max.
                    np.maximum(
                        np.max(elems),
                        previous_max_ids_per_partition[table_name],
                    ),
                    num_local_cpu_devices,
                )
                for table_name, elems in stats.max_ids_per_partition.items()
            }
            local_max_unique_ids_per_partition = {
                name: np.repeat(
                    # Maximum across all partitions and previous max.
                    np.maximum(
                        np.max(elems),
                        previous_max_unique_ids_per_partition[name],
                    ),
                    num_local_cpu_devices,
                )
                for name, elems in stats.max_unique_ids_per_partition.items()
            }
            local_buffer_size = {
                table_name: np.repeat(
                    np.maximum(
                        np.max(
                            # Round values up to the next multiple of 8.
                            # Currently using this as a proxy for the actual
                            # required buffer size.
                            ((elems + 7) // 8) * 8
                        )
                        * global_device_count
                        * num_sc_per_device
                        * local_device_count
                        * num_sc_per_device,
                        previous_buffer_size[table_name],
                    ),
                    num_local_cpu_devices,
                )
                for table_name, elems in stats.max_ids_per_partition.items()
            }

            # Aggregate variables across all processes/devices.
            max_across_cpus = jax.pmap(
                lambda x: jax.lax.pmax(  # type: ignore[no-untyped-call]
                    x, "all_cpus"
                ),
                axis_name="all_cpus",
                devices=self._cpu_layout.device_mesh.backend_mesh.devices,
            )
            new_max_ids_per_partition = max_across_cpus(
                local_max_ids_per_partition
            )
            new_max_unique_ids_per_partition = max_across_cpus(
                local_max_unique_ids_per_partition
            )
            new_buffer_size = max_across_cpus(local_buffer_size)

            # Assign new preprocessing parameters.
            with self._cpu_distribution.scope():
                # For each process, all max ids/buffer sizes are replicated
                # across all local devices. Take the value from the first
                # device.
                keras.tree.map_structure(
                    lambda var, values: var.assign(values[0]),
                    self._preprocessing_max_ids_per_partition,
                    new_max_ids_per_partition,
                )
                keras.tree.map_structure(
                    lambda var, values: var.assign(values[0]),
                    self._preprocessing_max_unique_ids_per_partition,
                    new_max_unique_ids_per_partition,
                )
                keras.tree.map_structure(
                    lambda var, values: var.assign(values[0]),
                    self._preprocessing_buffer_size,
                    new_buffer_size,
                )
                # Update parameters in the underlying feature specs.
                int_max_ids_per_partition = keras.tree.map_structure(
                    lambda varray: varray.item(), new_max_ids_per_partition
                )
                int_max_unique_ids_per_partition = keras.tree.map_structure(
                    lambda varray: varray.item(),
                    new_max_unique_ids_per_partition,
                )
                embedding_utils.update_stacked_table_specs(
                    self._config.feature_specs,
                    int_max_ids_per_partition,
                    int_max_unique_ids_per_partition,
                )

        return {"inputs": preprocessed}
|
|
778
|
+
|
|
779
|
+
def _sparsecore_call(
    self,
    inputs: dict[str, types.Tensor],
    weights: dict[str, types.Tensor] | None = None,
    training: bool = False,
    **kwargs: Any,
) -> dict[str, types.Tensor]:
    """Runs the embedding lookup for sparsecore-placed features.

    Args:
        inputs: Mapping of feature path to (preprocessed) input tensors.
        weights: Must be `None`; weights are not accepted at this stage.
        training: Present for call-signature uniformity; not used here.
        **kwargs: Additional unused call arguments.

    Returns:
        Mapping of feature path to looked-up embedding tensors.
    """
    assert weights is None

    # Finish the sparsecore-specific build lazily on first use.
    if not self._sparsecore_built:
        self._sparsecore_build()

    # Snapshot the raw backing arrays of all table and slot variables.
    current_tables = keras.tree.map_structure(
        lambda v: v.value, self._table_and_slot_variables
    )

    with self._sparsecore_distribution.scope():
        jitted_lookup = jax.jit(
            jte_embedding_lookup.embedding_lookup, static_argnames="config"
        )
        activations: dict[str, types.Tensor] = jitted_lookup(
            self._config,
            inputs,
            current_tables,
            self._iterations.value,
        )
    return activations
|
|
802
|
+
|
|
803
|
+
def set_embedding_tables(self, tables: Mapping[str, ArrayLike]) -> None:
|
|
804
|
+
"""Sets the embedding tables to specific (unsharded) values.
|
|
805
|
+
|
|
806
|
+
Args:
|
|
807
|
+
tables: Mapping of table name -> table values.
|
|
808
|
+
"""
|
|
809
|
+
if "default_device" in self._placement_to_path_to_feature_config:
|
|
810
|
+
self._default_device_set_tables(tables)
|
|
811
|
+
|
|
812
|
+
if "sparsecore" in self._placement_to_path_to_feature_config:
|
|
813
|
+
self._sparsecore_set_tables(tables)
|
|
814
|
+
|
|
815
|
+
def _default_device_set_tables(
|
|
816
|
+
self, tables: Mapping[str, ArrayLike]
|
|
817
|
+
) -> None:
|
|
818
|
+
if not self.built:
|
|
819
|
+
raise ValueError("Layer must first be built before setting tables.")
|
|
820
|
+
|
|
821
|
+
if "default_device" in self._placement_to_path_to_feature_config:
|
|
822
|
+
table_to_embedding_layer = {}
|
|
823
|
+
for (
|
|
824
|
+
path,
|
|
825
|
+
feature_config,
|
|
826
|
+
) in self._placement_to_path_to_feature_config[
|
|
827
|
+
"default_device"
|
|
828
|
+
].items():
|
|
829
|
+
table_to_embedding_layer[feature_config.table] = (
|
|
830
|
+
self._default_device_embedding_layers[path]
|
|
831
|
+
)
|
|
832
|
+
|
|
833
|
+
for table, embedding_layer in table_to_embedding_layer.items():
|
|
834
|
+
table_values = tables.get(table.name, None)
|
|
835
|
+
if table_values is not None:
|
|
836
|
+
if embedding_layer.lora_enabled:
|
|
837
|
+
raise ValueError("Cannot set table if LoRA is enabled.")
|
|
838
|
+
# pylint: disable-next=protected-access
|
|
839
|
+
embedding_layer._embeddings.assign(table_values)
|
|
840
|
+
|
|
841
|
+
def _sparsecore_set_tables(self, tables: Mapping[str, ArrayLike]) -> None:
    """Stacks, shards, and assigns unsharded table values on sparsecore."""
    if not self._sparsecore_built:
        self._sparsecore_build()

    cfg = self._config
    shard_count = cfg.mesh.devices.size * cfg.num_sc_per_device
    specs = embedding_utils.get_table_specs(cfg.feature_specs)
    stacked = embedding_utils.stack_and_shard_tables(
        specs, tables, shard_count
    )

    def _flatten_shards(stacked_table):
        # Collapse the shard dimension to allow auto-sharding to split the
        # array.
        return stacked_table.reshape((-1, stacked_table.shape[-1]))

    on_device = jax.device_put(
        jax.tree.map(_flatten_shards, stacked),
        self._sparsecore_layout.backend_layout,
    )

    def _assign_table(table_and_slot_variables, table_value):
        # Element 0 of each group is the table variable itself; slot
        # variables in the rest of the group are left untouched.
        return table_and_slot_variables[0].assign(table_value)

    # Assign stacked table variables to the device values.
    keras.tree.map_structure_up_to(
        on_device,
        _assign_table,
        self._table_and_slot_variables,
        on_device,
    )
|
|
872
|
+
|
|
873
|
+
def _sparsecore_get_embedding_tables(self) -> dict[str, ArrayLike]:
    """Returns the current embedding tables as unsharded host arrays.

    Returns:
        Mapping of table name to the full (unsharded, unstacked) table
        values, copied from device memory.
    """
    if not self._sparsecore_built:
        # Lazily finish the sparsecore build. Bug fix: this previously
        # called `self.sparsecore_build()`, which does not exist — the
        # private builder `_sparsecore_build` is the name used by
        # `_sparsecore_call` and `_sparsecore_set_tables`.
        self._sparsecore_build()

    config = self._config
    num_table_shards = config.mesh.devices.size * config.num_sc_per_device
    table_specs = embedding_utils.get_table_specs(config.feature_specs)

    # Extract only the table variables, not the gradient slot variables
    # (element 0 of each table/slot group), and copy them to host.
    table_variables = {
        name: jax.device_get(table_and_slots[0].value)
        for name, table_and_slots in self._table_and_slot_variables.items()
    }

    return typing.cast(
        dict[str, ArrayLike],
        embedding_utils.unshard_and_unstack_tables(
            table_specs, table_variables, num_table_shards
        ),
    )