keras-rs-nightly 0.2.2.dev202508190331__py3-none-any.whl → 0.4.1.dev202601250348__py3-none-any.whl
- keras_rs/losses/__init__.py +1 -0
- keras_rs/src/layers/embedding/base_distributed_embedding.py +19 -10
- keras_rs/src/layers/embedding/distributed_embedding_config.py +2 -2
- keras_rs/src/layers/embedding/jax/distributed_embedding.py +133 -201
- keras_rs/src/layers/embedding/jax/embedding_lookup.py +25 -4
- keras_rs/src/layers/embedding/jax/embedding_utils.py +22 -401
- keras_rs/src/layers/embedding/tensorflow/config_conversion.py +26 -19
- keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +22 -5
- keras_rs/src/losses/list_mle_loss.py +212 -0
- keras_rs/src/metrics/ranking_metrics_utils.py +21 -2
- keras_rs/src/utils/tpu_test_utils.py +120 -0
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.2.2.dev202508190331.dist-info → keras_rs_nightly-0.4.1.dev202601250348.dist-info}/METADATA +4 -3
- {keras_rs_nightly-0.2.2.dev202508190331.dist-info → keras_rs_nightly-0.4.1.dev202601250348.dist-info}/RECORD +16 -14
- {keras_rs_nightly-0.2.2.dev202508190331.dist-info → keras_rs_nightly-0.4.1.dev202601250348.dist-info}/WHEEL +1 -1
- {keras_rs_nightly-0.2.2.dev202508190331.dist-info → keras_rs_nightly-0.4.1.dev202601250348.dist-info}/top_level.txt +0 -0
```diff
--- a/keras_rs/src/layers/embedding/jax/embedding_utils.py
+++ b/keras_rs/src/layers/embedding/jax/embedding_utils.py
@@ -1,17 +1,13 @@
 """Utility functions for manipulating JAX embedding tables and inputs."""

 import collections
-import dataclasses
-import typing
 from typing import Any, Mapping, NamedTuple, Sequence, TypeAlias, TypeVar

 import jax
 import numpy as np
-from jax import numpy as jnp
 from jax_tpu_embedding.sparsecore.lib.nn import embedding
+from jax_tpu_embedding.sparsecore.lib.nn import table_stacking
 from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import FeatureSpec
-from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import StackedTableSpec
-from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import TableSpec

 from keras_rs.src.types import Nested

@@ -24,7 +20,7 @@ Shape: TypeAlias = tuple[int, ...]

 class FeatureSamples(NamedTuple):
     tokens: ArrayLike
-    weights: ArrayLike
+    weights: ArrayLike | None


 class ShardedCooMatrix(NamedTuple):
```
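With `weights` now typed `ArrayLike | None`, a feature's samples can simply omit weights instead of carrying a materialized array of ones. A minimal sketch of the two cases (the id and weight values are invented; the NamedTuple is redefined locally only to keep the snippet self-contained):

```python
from typing import Any, NamedTuple

import numpy as np


class FeatureSamples(NamedTuple):  # mirrors the updated definition above
    tokens: Any
    weights: Any | None


ids = np.array([[1, 2], [3, 4]], dtype=np.int32)

unweighted = FeatureSamples(tokens=ids, weights=None)  # no per-id weights
weighted = FeatureSamples(
    tokens=ids,
    weights=np.array([[1.0, 0.5], [1.0, 1.0]], dtype=np.float32),
)
```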
```diff
@@ -35,357 +31,6 @@ class ShardedCooMatrix(NamedTuple):
     values: ArrayLike


-def _round_up_to_multiple(value: int, multiple: int) -> int:
-    return ((value + multiple - 1) // multiple) * multiple
-
-
-def _default_stacked_table_spec(
-    table_spec: TableSpec, num_shards: int, batch_size: int
-) -> StackedTableSpec:
-    return StackedTableSpec(
-        stack_name=table_spec.name,
-        stack_vocab_size=_round_up_to_multiple(
-            table_spec.vocabulary_size, 8 * num_shards
-        ),
-        stack_embedding_dim=_round_up_to_multiple(table_spec.embedding_dim, 8),
-        optimizer=table_spec.optimizer,
-        combiner=table_spec.combiner,
-        total_sample_count=batch_size,
-        max_ids_per_partition=table_spec.max_ids_per_partition,
-        max_unique_ids_per_partition=table_spec.max_unique_ids_per_partition,
-    )
-
-
-def _get_stacked_table_spec(
-    table_spec: TableSpec, num_shards: int, batch_size: int = 0
-) -> StackedTableSpec:
-    return table_spec.stacked_table_spec or _default_stacked_table_spec(
-        table_spec, num_shards, batch_size
-    )
-
-
-def pad_table(
-    table_spec: TableSpec,
-    table_values: jax.Array,
-    num_shards: int,
-    pad_value: jnp.float32 = jnp.nan,
-) -> jax.Array:
-    """Adds appropriate padding to a table to prepare for stacking.
-
-    Args:
-        table_spec: Table specification describing the table to pad.
-        table_values: Table values array to pad.
-        num_shards: Number of shards in the table (typically
-            `global_device_count * num_sc_per_device`).
-        pad_value: Value to use for padding.
-
-    Returns:
-        Padded table values.
-    """
-    vocabulary_size = table_spec.vocabulary_size
-    embedding_dim = table_spec.embedding_dim
-    padded_vocabulary_size = _round_up_to_multiple(
-        vocabulary_size, 8 * num_shards
-    )
-    stack_embedding_dim = _get_stacked_table_spec(
-        table_spec, num_shards
-    ).stack_embedding_dim
-    return jnp.pad(
-        table_values,
-        (
-            (0, padded_vocabulary_size - vocabulary_size),
-            (0, stack_embedding_dim - embedding_dim),
-        ),
-        constant_values=pad_value,
-    )
-
-
-def _stack_and_shard_table(
-    stacked_table: jax.Array,
-    table_spec: TableSpec,
-    table: jax.Array,
-    num_shards: int,
-    pad_value: jnp.float32,
-) -> jax.Array:
-    """Stacks and shards a single table for use in sparsecore lookups."""
-    padded_values = pad_table(table_spec, table, num_shards, pad_value)
-    sharded_padded_vocabulary_size = padded_values.shape[0] // num_shards
-    stack_embedding_dim = stacked_table.shape[-1]
-
-    # Mod-shard vocabulary across devices.
-    sharded_values = jnp.swapaxes(
-        padded_values.reshape(-1, num_shards, stack_embedding_dim),
-        0,
-        1,
-    )
-
-    # Rotate shards.
-    setting_in_stack = table_spec.setting_in_stack
-    rotated_values = jnp.roll(
-        sharded_values, setting_in_stack.shard_rotation, axis=0
-    )
-
-    # Insert table into the stack.
-    table_row = setting_in_stack.row_offset_in_shard
-    stacked_table = stacked_table.at[
-        :, table_row : (table_row + sharded_padded_vocabulary_size), :
-    ].set(rotated_values)
-
-    return stacked_table
-
-
-def stack_and_shard_tables(
-    table_specs: Nested[TableSpec],
-    tables: Nested[ArrayLike],
-    num_shards: int,
-    pad_value: jnp.float32 = jnp.nan,
-) -> dict[str, Nested[jax.Array]]:
-    """Stacks and shards tables for use in sparsecore lookups.
-
-    Args:
-        table_specs: Nested collection of unstacked table specifications.
-        tables: Table values corresponding to the table_specs.
-        num_shards: Number of shards in the table (typically
-            `global_device_count * num_sc_per_device`).
-        pad_value: Value to use for padding.
-
-    Returns:
-        A mapping of stacked table names to stacked table values.
-    """
-
-    # Gather stacked table information.
-    stacked_table_map: dict[
-        str,
-        tuple[StackedTableSpec, list[TableSpec]],
-    ] = {}
-
-    def collect_stacked_tables(table_spec: TableSpec) -> None:
-        stacked_table_spec = _get_stacked_table_spec(table_spec, num_shards)
-        stacked_table_name = stacked_table_spec.stack_name
-        if stacked_table_name not in stacked_table_map:
-            stacked_table_map[stacked_table_name] = (stacked_table_spec, [])
-        stacked_table_map[stacked_table_name][1].append(table_spec)
-
-    _ = jax.tree.map(collect_stacked_tables, table_specs)
-
-    table_map: dict[str, Nested[jax.Array]] = {}
-
-    def collect_tables(table_spec: TableSpec, table: Nested[jax.Array]) -> None:
-        table_map[table_spec.name] = table
-
-    _ = jax.tree.map(collect_tables, table_specs, tables)
-
-    stacked_tables: dict[str, Nested[jax.Array]] = {}
-    for (
-        stacked_table_spec,
-        table_specs,
-    ) in stacked_table_map.values():
-        stack_vocab_size = stacked_table_spec.stack_vocab_size
-        sharded_vocab_size = stack_vocab_size // num_shards
-        stack_embedding_dim = stacked_table_spec.stack_embedding_dim
-
-        # Allocate initial buffer. The stacked table will be divided among
-        # shards by splitting the vocabulary dimension:
-        # [ v, e ] -> [s, v/s, e]
-        stacked_table_tree = jax.tree.map(
-            lambda _: jnp.zeros(
-                # pylint: disable-next=cell-var-from-loop, used only in loop body.
-                shape=(num_shards, sharded_vocab_size, stack_embedding_dim),
-                dtype=jnp.float32,
-            ),
-            table_map[table_specs[0].name],
-        )
-
-        for table_spec in table_specs:
-            table_tree = table_map[table_spec.name]
-            stacked_table_tree = jax.tree.map(
-                lambda stacked_table, table: _stack_and_shard_table(
-                    # pylint: disable-next=cell-var-from-loop, used only in loop body.
-                    stacked_table,
-                    # pylint: disable-next=cell-var-from-loop, used only in loop body.
-                    table_spec,
-                    table,
-                    num_shards,
-                    pad_value,
-                ),
-                stacked_table_tree,
-                table_tree,
-            )
-
-        stacked_tables[stacked_table_spec.stack_name] = stacked_table_tree
-
-    return stacked_tables
-
-
-def _unshard_and_unstack_table(
-    table_spec: TableSpec,
-    stacked_table_tree: Nested[jax.Array],
-    num_shards: int,
-) -> Nested[jax.Array]:
-    """Unshards and unstacks a single table."""
-    vocabulary_size = table_spec.vocabulary_size
-    embedding_dim = table_spec.embedding_dim
-
-    def _unshard_and_unstack_single_table(
-        table_spec: TableSpec, stacked_table: jax.Array
-    ) -> jax.Array:
-        stack_embedding_dim = stacked_table.shape[-1]
-
-        # Maybe re-shape in case it was flattened.
-        stacked_table = stacked_table.reshape(
-            num_shards, -1, stack_embedding_dim
-        )
-        sharded_vocabulary_size = (
-            _round_up_to_multiple(vocabulary_size, 8 * num_shards) // num_shards
-        )
-
-        # Extract padded values from the stacked table.
-        setting_in_stack = table_spec.setting_in_stack
-        row = setting_in_stack.row_offset_in_shard
-        padded_values = stacked_table[
-            :, row : (row + sharded_vocabulary_size), :
-        ]
-
-        # Un-rotate shards.
-        padded_values = jnp.roll(
-            padded_values, -setting_in_stack.shard_rotation, axis=0
-        )
-
-        # Un-mod-shard.
-        padded_values = jnp.swapaxes(padded_values, 0, 1).reshape(
-            -1, stack_embedding_dim
-        )
-
-        # Un-pad.
-        return padded_values[:vocabulary_size, :embedding_dim]
-
-    output: Nested[jax.Array] = jax.tree.map(
-        lambda stacked_table: _unshard_and_unstack_single_table(
-            table_spec, stacked_table
-        ),
-        stacked_table_tree,
-    )
-    return output
-
-
-def unshard_and_unstack_tables(
-    table_specs: Nested[TableSpec],
-    stacked_tables: Mapping[str, Nested[jax.Array]],
-    num_shards: int,
-) -> Nested[jax.Array]:
-    """Unshards and unstacks a collection of tables.
-
-    Args:
-        table_specs: Nested collection of unstacked table specifications.
-        stacked_tables: Mapping of stacked table names to stacked table values.
-        num_shards: Number of shards in the table (typically
-            `global_device_count * num_sc_per_device`).
-
-    Returns:
-        A mapping of table names to unstacked table values.
-    """
-    output: Nested[jax.Array] = jax.tree.map(
-        lambda table_spec: _unshard_and_unstack_table(
-            table_spec,
-            stacked_tables[
-                _get_stacked_table_spec(table_spec, num_shards=1).stack_name
-            ],
-            num_shards,
-        ),
-        table_specs,
-    )
-    return output
-
-
-def get_table_specs(feature_specs: Nested[FeatureSpec]) -> dict[str, TableSpec]:
-    table_spec_map: dict[str, TableSpec] = {}
-    flat_feature_specs, _ = jax.tree.flatten(feature_specs)
-    for feature_spec in flat_feature_specs:
-        table_spec = feature_spec.table_spec
-        table_spec_map[table_spec.name] = table_spec
-    return table_spec_map
-
-
-def get_table_stacks(
-    table_specs: Nested[TableSpec],
-) -> dict[str, list[TableSpec]]:
-    """Extracts lists of tables that are stacked together.
-
-    Args:
-        table_specs: Nested collection of table specifications.
-
-    Returns:
-        A mapping of stacked table names to lists of table specifications for
-        each stack.
-    """
-    stacked_table_specs: dict[str, list[TableSpec]] = collections.defaultdict(
-        list
-    )
-    flat_table_specs, _ = jax.tree.flatten(table_specs)
-    for table_spec in flat_table_specs:
-        table_spec = typing.cast(TableSpec, table_spec)
-        stacked_table_spec = table_spec.stacked_table_spec
-        if stacked_table_spec is not None:
-            stacked_table_specs[stacked_table_spec.stack_name].append(
-                table_spec
-            )
-        else:
-            stacked_table_specs[table_spec.name].append(table_spec)
-
-    return stacked_table_specs
-
-
-def update_stacked_table_specs(
-    feature_specs: Nested[FeatureSpec],
-    max_ids_per_partition: Mapping[str, int],
-    max_unique_ids_per_partition: Mapping[str, int],
-) -> None:
-    """Updates properties in the supplied feature specs.
-
-    Args:
-        feature_specs: Feature specs to update in-place.
-        max_ids_per_partition: Mapping of table stack name to
-            new `max_ids_per_partition` for the stack.
-        max_unique_ids_per_partition: Mapping of table stack name to
-            new `max_unique_ids_per_partition` for the stack.
-    """
-    # Collect table specs and stacked table specs.
-    table_specs: dict[str, TableSpec] = {}
-    for feature_spec in jax.tree.flatten(feature_specs)[0]:
-        feature_spec = typing.cast(FeatureSpec, feature_spec)
-        table_specs[feature_spec.table_spec.name] = feature_spec.table_spec
-
-    stacked_table_specs: dict[str, StackedTableSpec] = {}
-    for table_spec in table_specs.values():
-        stacked_table_spec = typing.cast(
-            StackedTableSpec, table_spec.stacked_table_spec
-        )
-        stacked_table_specs[stacked_table_spec.stack_name] = stacked_table_spec
-
-    # Replace fields in the stacked_table_specs.
-    stacked_table_specs = {
-        stack_name: dataclasses.replace(
-            stacked_table_spec,
-            max_ids_per_partition=max_ids_per_partition[
-                stacked_table_spec.stack_name
-            ],
-            max_unique_ids_per_partition=max_unique_ids_per_partition[
-                stacked_table_spec.stack_name
-            ],
-        )
-        for stack_name, stacked_table_spec in stacked_table_specs.items()
-    }
-
-    # Insert new stacked tables into tables.
-    for table_spec in table_specs.values():
-        stacked_table_spec = typing.cast(
-            StackedTableSpec, table_spec.stacked_table_spec
-        )
-        table_spec.stacked_table_spec = stacked_table_specs[
-            stacked_table_spec.stack_name
-        ]
-
-
 def convert_to_numpy(
     ragged_or_dense: np.ndarray[Any, Any] | Sequence[Sequence[Any]] | Any,
     dtype: Any,
```
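For reference, the deleted helpers padded every table so its vocabulary size became a multiple of `8 * num_shards` and its embedding dimension a multiple of 8 before stacking. A worked example of the rounding arithmetic used by the removed `_round_up_to_multiple` (the sizes below are invented):

```python
def round_up_to_multiple(value: int, multiple: int) -> int:
    # Same arithmetic as the removed `_round_up_to_multiple` helper.
    return ((value + multiple - 1) // multiple) * multiple


num_shards = 16  # e.g. 4 devices * 4 sparsecores per device
vocabulary_size, embedding_dim = 1000, 12

padded_vocab = round_up_to_multiple(vocabulary_size, 8 * num_shards)  # 1024
padded_dim = round_up_to_multiple(embedding_dim, 8)                   # 16
print(padded_vocab, padded_dim)
```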
```diff
@@ -439,36 +84,6 @@ def convert_to_numpy(
     )


-def ones_like(
-    ragged_or_dense: np.ndarray[Any, Any], dtype: Any = None
-) -> np.ndarray[Any, Any]:
-    """Creates an array of ones the same as as the input.
-
-    This differs from traditional numpy in that a ragged input will lead to
-    a resulting ragged array of ones, whereas np.ones_like(...) will instead
-    only consider the outer array and return a 1D dense array of ones.
-
-    Args:
-        ragged_or_dense: The ragged or dense input whose shape and data-type
-            define these same attributes of the returned array.
-        dtype: The data-type of the returned array.
-
-    Returns:
-        An array of ones with the same shape as the input, and specified data
-        type.
-    """
-    dtype = dtype or ragged_or_dense.dtype
-    if ragged_or_dense.dtype == np.ndarray:
-        # Ragged.
-        return np.array(
-            [np.ones_like(row, dtype=dtype) for row in ragged_or_dense],
-            dtype=np.ndarray,
-        )
-    else:
-        # Dense.
-        return np.ones_like(ragged_or_dense, dtype=dtype)
-
-
 def create_feature_samples(
     feature_structure: Nested[T],
     feature_ids: Nested[ArrayLike | Sequence[int] | Sequence[Sequence[int]]],
@@ -496,18 +111,17 @@ def create_feature_samples(
     )

     if feature_weights is None:
-        #
-
-            lambda _, ids: ones_like(ids, np.float32),
+        return jax.tree.map(  # type: ignore[no-any-return]
+            lambda _, ids: FeatureSamples(ids, None),
             feature_structure,
             feature_ids,
         )
-
-
-
-
-
-
+
+    feature_weights = jax.tree.map(
+        lambda _, wgts: convert_to_numpy(wgts, np.float32),
+        feature_structure,
+        feature_weights,
+    )

     # Assemble.
     def _create_feature_samples(
```
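In the new `create_feature_samples` path, missing weights stay `None` (each leaf becomes `FeatureSamples(ids, None)`), and supplied weights are converted to `float32` NumPy arrays via `convert_to_numpy`. A small sketch of the unweighted branch using `jax.tree.map` over an invented feature structure (plain tuples stand in for `FeatureSamples`):

```python
import jax
import numpy as np

# Hypothetical structure: the string leaves stand in for per-feature configs.
feature_structure = {"user_id": "user_feature", "item_id": "item_feature"}
feature_ids = {
    "user_id": np.array([[3], [7]], dtype=np.int32),
    "item_id": np.array([[10, 11], [12, 13]], dtype=np.int32),
}

# Equivalent of the new unweighted branch: pair each leaf of ids with None
# weights instead of materializing a ragged array of ones.
samples = jax.tree.map(
    lambda _, ids: (ids, None),  # stand-in for FeatureSamples(ids, None)
    feature_structure,
    feature_ids,
)
print(samples["user_id"][1])  # None: no weights were synthesized
```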
```diff
@@ -544,8 +158,8 @@ def stack_and_shard_samples(
         global_device_count: Number of global JAX devices.
         num_sc_per_device: Number of sparsecores per device.
         static_buffer_size: The static buffer size to use for the samples.
-
-
+            Defaults to None, in which case an upper-bound for the buffer size
+            will be automatically determined.

     Returns:
         The preprocessed inputs, and statistics useful for updating FeatureSpecs
@@ -555,17 +169,21 @@ def stack_and_shard_samples(
     flat_feature_specs, _ = jax.tree.flatten(feature_specs)

     feature_tokens = []
-
+    collected_weights = []

     def collect_tokens_and_weights(
         feature_spec: FeatureSpec, samples: FeatureSamples
     ) -> None:
         del feature_spec
         feature_tokens.append(samples.tokens)
-
+        collected_weights.append(samples.weights)

     jax.tree.map(collect_tokens_and_weights, feature_specs, feature_samples)

+    feature_weights = (
+        None if all(w is None for w in collected_weights) else collected_weights
+    )
+
     preprocessed_inputs, stats = embedding.preprocess_sparse_dense_matmul_input(
         feature_tokens,
         feature_weights,
```
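`stack_and_shard_samples` now gathers the per-feature weights and forwards `None` to `embedding.preprocess_sparse_dense_matmul_input` unless at least one feature actually carries weights. The selection expression in isolation, with made-up inputs:

```python
import numpy as np


def collapse_weights(collected_weights):
    # Same expression as in the diff: pass None unless some feature has weights.
    return (
        None
        if all(w is None for w in collected_weights)
        else collected_weights
    )


print(collapse_weights([None, None]))                            # None
print(collapse_weights([None, np.array([0.5], dtype=np.float32)]))  # list kept
```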
```diff
@@ -583,7 +201,10 @@ def stack_and_shard_samples(
     for table_name in tables_names:
         shard_ends = preprocessed_inputs.lhs_row_pointers[table_name]
         shard_starts = np.concatenate(
-            [
+            [
+                np.asarray([0]),
+                table_stacking._next_largest_multiple(shard_ends[:-1], 8),
+            ]
         )
         out[table_name] = ShardedCooMatrix(
             shard_starts=shard_starts,
```
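Shard start offsets are now built by prepending 0 to the previous shards' end pointers rounded up to the next multiple of 8. `table_stacking._next_largest_multiple` is assumed here to perform exactly that rounding; a NumPy sketch with invented row pointers:

```python
import numpy as np


def next_largest_multiple(values, multiple):
    # Assumed behaviour of table_stacking._next_largest_multiple: round each
    # value up to the next multiple of `multiple`.
    return ((values + multiple - 1) // multiple) * multiple


# Hypothetical end-of-shard row pointers for a 4-shard COO matrix.
shard_ends = np.array([5, 12, 20, 27])
shard_starts = np.concatenate(
    [np.asarray([0]), next_largest_multiple(shard_ends[:-1], 8)]
)
print(shard_starts)  # [ 0  8 16 24]
```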
```diff
--- a/keras_rs/src/layers/embedding/tensorflow/config_conversion.py
+++ b/keras_rs/src/layers/embedding/tensorflow/config_conversion.py
@@ -53,7 +53,7 @@ OPTIMIZER_MAPPINGS = {
 # KerasRS to TensorFlow


-def translate_keras_rs_configuration(
+def keras_to_tf_tpu_configuration(
     feature_configs: types.Nested[FeatureConfig],
     table_stacking: str | Sequence[str] | Sequence[Sequence[str]],
     num_replicas_in_sync: int,
@@ -66,14 +66,15 @@ def translate_keras_rs_configuration(
     Args:
         feature_configs: The nested Keras RS feature configs.
         table_stacking: The Keras RS table stacking.
+        num_replicas_in_sync: The number of replicas in sync from the strategy.

     Returns:
         A tuple containing the TensorFlow TPU feature configs and the TensorFlow
         TPU sparse core embedding config.
     """
-    tables: dict[
+    tables: dict[int, tf.tpu.experimental.embedding.TableConfig] = {}
     feature_configs = keras.tree.map_structure(
-        lambda f:
+        lambda f: keras_to_tf_tpu_feature_config(
             f, tables, num_replicas_in_sync
         ),
         feature_configs,
@@ -108,9 +109,9 @@ def translate_keras_rs_configuration(
     return feature_configs, sparse_core_embedding_config


-def translate_keras_rs_feature_config(
+def keras_to_tf_tpu_feature_config(
     feature_config: FeatureConfig,
-    tables: dict[
+    tables: dict[int, tf.tpu.experimental.embedding.TableConfig],
     num_replicas_in_sync: int,
 ) -> tf.tpu.experimental.embedding.FeatureConfig:
     """Translates a Keras RS feature config to a TensorFlow TPU feature config.
@@ -120,7 +121,8 @@ def translate_keras_rs_feature_config(

     Args:
         feature_config: The Keras RS feature config to translate.
-        tables: A mapping of KerasRS table
+        tables: A mapping of KerasRS table config ids to TF TPU table configs.
+        num_replicas_in_sync: The number of replicas in sync from the strategy.

     Returns:
         The TensorFlow TPU feature config.
@@ -131,10 +133,10 @@ def translate_keras_rs_feature_config(
             f"but got {num_replicas_in_sync}."
         )

-    table = tables.get(feature_config.table, None)
+    table = tables.get(id(feature_config.table), None)
     if table is None:
-        table =
-        tables[feature_config.table] = table
+        table = keras_to_tf_tpu_table_config(feature_config.table)
+        tables[id(feature_config.table)] = table

     if len(feature_config.output_shape) < 2:
         raise ValueError(
```
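Translated tables are now cached by `id(feature_config.table)` instead of by the `TableConfig` object itself, so feature configs that share one table object reuse a single TF TPU table config without requiring the Keras config to be hashable. A stand-alone sketch of the pattern (the `TableConfig` stub and `translate` function below are stand-ins, not the package's classes):

```python
class TableConfig:  # stand-in for the KerasRS TableConfig
    def __init__(self, name: str):
        self.name = name


def translate(table: TableConfig) -> str:
    # Stand-in for keras_to_tf_tpu_table_config.
    return f"tf-table({table.name})"


shared_table = TableConfig("users")
feature_tables = [shared_table, shared_table, TableConfig("items")]

tables: dict[int, str] = {}
for t in feature_tables:
    translated = tables.get(id(t), None)
    if translated is None:
        translated = translate(t)
        tables[id(t)] = translated

print(len(tables))  # 2: the shared table is translated only once
```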
```diff
@@ -168,7 +170,7 @@ def translate_keras_rs_feature_config(
     )


-def translate_keras_rs_table_config(
+def keras_to_tf_tpu_table_config(
     table_config: TableConfig,
 ) -> tf.tpu.experimental.embedding.TableConfig:
     initializer = table_config.initializer
@@ -179,13 +181,13 @@ def translate_keras_rs_table_config(
         vocabulary_size=table_config.vocabulary_size,
         dim=table_config.embedding_dim,
         initializer=initializer,
-        optimizer=
+        optimizer=to_tf_tpu_optimizer(table_config.optimizer),
         combiner=table_config.combiner,
         name=table_config.name,
     )


-def translate_keras_optimizer(
+def keras_to_tf_tpu_optimizer(
     optimizer: keras.optimizers.Optimizer,
 ) -> TfTpuOptimizer:
     """Translates a Keras optimizer to a TensorFlow TPU `_Optimizer`.
@@ -238,7 +240,12 @@ def translate_keras_optimizer(
             "Unsupported optimizer option `Optimizer.loss_scale_factor`."
         )

-    optimizer_mapping =
+    optimizer_mapping = None
+    for optimizer_class, mapping in OPTIMIZER_MAPPINGS.items():
+        # Handle subclasses of the main optimizer class.
+        if isinstance(optimizer, optimizer_class):
+            optimizer_mapping = mapping
+            break
     if optimizer_mapping is None:
         raise ValueError(
             f"Unsupported optimizer type {type(optimizer)}. Optimizer must be "
```
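The optimizer mapping is now resolved by scanning `OPTIMIZER_MAPPINGS` with `isinstance`, which, per the added comment, also matches subclasses of the supported Keras optimizer classes. A minimal sketch contrasting this with a plain exact-type lookup (the mapping values and the subclass are invented):

```python
import keras

# Placeholder mapping; the real OPTIMIZER_MAPPINGS maps Keras optimizer
# classes to TF TPU optimizer translation info.
OPTIMIZER_MAPPINGS = {
    keras.optimizers.Adam: "adam-mapping",
    keras.optimizers.SGD: "sgd-mapping",
}


class MyAdam(keras.optimizers.Adam):  # hypothetical user subclass
    pass


optimizer = MyAdam()

# An exact-type dict lookup misses the subclass:
assert OPTIMIZER_MAPPINGS.get(type(optimizer)) is None

# The isinstance-based scan, as in the new code, finds it:
optimizer_mapping = None
for optimizer_class, mapping in OPTIMIZER_MAPPINGS.items():
    if isinstance(optimizer, optimizer_class):
        optimizer_mapping = mapping
        break
assert optimizer_mapping == "adam-mapping"
```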
```diff
@@ -258,7 +265,7 @@ def translate_keras_optimizer(
     return optimizer_mapping.tpu_optimizer_class(**tpu_optimizer_kwargs)


-def translate_optimizer(
+def to_tf_tpu_optimizer(
     optimizer: str | keras.optimizers.Optimizer | TfTpuOptimizer | None,
 ) -> TfTpuOptimizer:
     """Translates a Keras optimizer into a TensorFlow TPU `_Optimizer`.
@@ -299,7 +306,7 @@ def translate_optimizer(
             "'sgd', 'adagrad', 'adam', or 'ftrl'"
         )
     elif isinstance(optimizer, keras.optimizers.Optimizer):
-        return
+        return keras_to_tf_tpu_optimizer(optimizer)
     else:
         raise ValueError(
             f"Unknown optimizer type {type(optimizer)}. Please pass an "
@@ -312,7 +319,7 @@ def translate_optimizer(
 # TensorFlow to TensorFlow


-def clone_tf_feature_configs(
+def clone_tf_tpu_feature_configs(
     feature_configs: types.Nested[tf.tpu.experimental.embedding.FeatureConfig],
 ) -> types.Nested[tf.tpu.experimental.embedding.FeatureConfig]:
     """Clones and resolves TensorFlow TPU feature configs.
@@ -327,7 +334,7 @@ def clone_tf_feature_configs(
     """
     table_configs_dict = {}

-    def
+    def clone_and_resolve_tf_tpu_feature_config(
         fc: tf.tpu.experimental.embedding.FeatureConfig,
     ) -> tf.tpu.experimental.embedding.FeatureConfig:
         if fc.table not in table_configs_dict:
@@ -336,7 +343,7 @@ def clone_tf_feature_configs(
                 vocabulary_size=fc.table.vocabulary_size,
                 dim=fc.table.dim,
                 initializer=fc.table.initializer,
-                optimizer=
+                optimizer=to_tf_tpu_optimizer(fc.table.optimizer),
                 combiner=fc.table.combiner,
                 name=fc.table.name,
                 quantization_config=fc.table.quantization_config,
@@ -352,5 +359,5 @@ def clone_tf_feature_configs(
     )

     return keras.tree.map_structure(
-
+        clone_and_resolve_tf_tpu_feature_config, feature_configs
     )
```
```diff
--- a/keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py
+++ b/keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py
@@ -35,8 +35,15 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         table_stacking: (
             str | Sequence[str] | Sequence[Sequence[str]]
         ) = "auto",
+        update_stats: bool = False,
         **kwargs: Any,
     ) -> None:
+        # `update_stats` is supported only on JAX.
+        if update_stats:
+            raise ValueError(
+                "`update_stats` cannot be True for the TensorFlow backend."
+            )
+
         # Intercept arguments that are supported only on TensorFlow.
         self._optimizer = kwargs.pop("optimizer", None)
         self._pipeline_execution_with_tensor_core = kwargs.pop(
```
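On the TensorFlow backend, the new `update_stats` argument must stay `False`; passing `True` fails fast because the statistics-update path exists only in the JAX implementation. A rough usage sketch (the `keras_rs.layers.DistributedEmbedding` import path and the empty placeholder configs are assumptions, not taken from this diff):

```python
import keras_rs

feature_configs = {}  # placeholder; real code passes keras_rs FeatureConfig objects

try:
    layer = keras_rs.layers.DistributedEmbedding(
        feature_configs,
        update_stats=True,  # only supported on the JAX backend
    )
except ValueError as err:
    print(err)  # `update_stats` cannot be True for the TensorFlow backend.
```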
```diff
@@ -106,7 +113,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
                 "for the configuration."
             )
         self._tpu_feature_configs, self._sparse_core_embedding_config = (
-            config_conversion.
+            config_conversion.keras_to_tf_tpu_configuration(
                 feature_configs,
                 table_stacking,
                 strategy.num_replicas_in_sync,
@@ -135,10 +142,10 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
                 "supported with this TPU generation."
             )
         self._tpu_feature_configs = (
-            config_conversion.
+            config_conversion.clone_tf_tpu_feature_configs(feature_configs)
         )

-        self._tpu_optimizer = config_conversion.
+        self._tpu_optimizer = config_conversion.to_tf_tpu_optimizer(
             self._optimizer
         )

@@ -281,8 +288,18 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
     def _sparsecore_get_embedding_tables(self) -> dict[str, types.Tensor]:
         tables: dict[str, types.Tensor] = {}
         strategy = tf.distribute.get_strategy()
-
-
+        if not self._is_tpu_strategy(strategy):
+            raise RuntimeError(
+                "`DistributedEmbedding.get_embedding_tables` needs to be "
+                "called under the TPUStrategy that DistributedEmbedding was "
+                f"created with, but is being called under strategy {strategy}. "
+                "Please use `with strategy.scope()` when calling "
+                "`get_embedding_tables`."
+            )
+
+        tpu_hardware = strategy.extended.tpu_hardware_feature
+        num_sc_per_device = tpu_hardware.num_embedding_devices_per_chip
+        num_shards = strategy.num_replicas_in_sync * num_sc_per_device

         def populate_table(
             feature_config: tf.tpu.experimental.embedding.FeatureConfig,
```