keras-rs-nightly 0.2.2.dev202508190331__py3-none-any.whl → 0.4.1.dev202601250348__py3-none-any.whl
This diff compares the contents of two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
- keras_rs/losses/__init__.py +1 -0
- keras_rs/src/layers/embedding/base_distributed_embedding.py +19 -10
- keras_rs/src/layers/embedding/distributed_embedding_config.py +2 -2
- keras_rs/src/layers/embedding/jax/distributed_embedding.py +133 -201
- keras_rs/src/layers/embedding/jax/embedding_lookup.py +25 -4
- keras_rs/src/layers/embedding/jax/embedding_utils.py +22 -401
- keras_rs/src/layers/embedding/tensorflow/config_conversion.py +26 -19
- keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +22 -5
- keras_rs/src/losses/list_mle_loss.py +212 -0
- keras_rs/src/metrics/ranking_metrics_utils.py +21 -2
- keras_rs/src/utils/tpu_test_utils.py +120 -0
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.2.2.dev202508190331.dist-info → keras_rs_nightly-0.4.1.dev202601250348.dist-info}/METADATA +4 -3
- {keras_rs_nightly-0.2.2.dev202508190331.dist-info → keras_rs_nightly-0.4.1.dev202601250348.dist-info}/RECORD +16 -14
- {keras_rs_nightly-0.2.2.dev202508190331.dist-info → keras_rs_nightly-0.4.1.dev202601250348.dist-info}/WHEEL +1 -1
- {keras_rs_nightly-0.2.2.dev202508190331.dist-info → keras_rs_nightly-0.4.1.dev202601250348.dist-info}/top_level.txt +0 -0
keras_rs/losses/__init__.py
CHANGED
@@ -4,6 +4,7 @@ This file was autogenerated. Do not edit it by hand,
 since your modifications would be overwritten.
 """
 
+from keras_rs.src.losses.list_mle_loss import ListMLELoss as ListMLELoss
 from keras_rs.src.losses.pairwise_hinge_loss import (
     PairwiseHingeLoss as PairwiseHingeLoss,
 )
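The only change to the public `keras_rs.losses` namespace is the new `ListMLELoss` export, backed by the new `keras_rs/src/losses/list_mle_loss.py` module. A minimal usage sketch, assuming the standard Keras loss calling convention used by the other keras-rs ranking losses (per-list relevance labels and predicted scores); shapes and values are illustrative only:

```python
import keras
import keras_rs

# One ranked list of four items per batch element (illustrative data).
y_true = keras.ops.convert_to_tensor([[1.0, 0.0, 2.0, 0.0]])  # relevance labels
y_pred = keras.ops.convert_to_tensor([[0.4, 0.1, 0.8, 0.2]])  # predicted scores

loss_fn = keras_rs.losses.ListMLELoss()
print(loss_fn(y_true, y_pred))  # scalar list-wise likelihood loss
```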
keras_rs/src/layers/embedding/base_distributed_embedding.py
CHANGED

@@ -457,6 +457,10 @@ class DistributedEmbedding(keras.layers.Layer):
             tables in the inner lists together. Note that table stacking is not
             supported on older TPUs, in which case the default value of `"auto"`
             will be interpreted as no table stacking.
+        update_stats: If True, `'max_ids_per_partition'`,
+            `'max_unique_ids_per_partition'` and
+            `'suggested_coo_buffer_size_per_device'` are updated during
+            training. This argument can be set to True only for the JAX backend.
         **kwargs: Additional arguments to pass to the layer base class.
     """
 
@@ -467,6 +471,7 @@ class DistributedEmbedding(keras.layers.Layer):
         table_stacking: (
             str | Sequence[str] | Sequence[Sequence[str]]
         ) = "auto",
+        update_stats: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -486,6 +491,8 @@ class DistributedEmbedding(keras.layers.Layer):
             table_stacking,
         )
 
+        self.update_stats = update_stats
+
     @keras_utils.no_automatic_dependency_tracking
     def _init_feature_configs_structures(
         self,
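The new `update_stats` flag is stored on the layer and, per the docstring above, only has an effect on the JAX backend. A construction sketch; the `TableConfig`/`FeatureConfig` arguments shown are abbreviated from memory and may not match the exact signatures:

```python
import keras_rs

# Illustrative, abbreviated configs (argument names are assumptions).
table = keras_rs.layers.TableConfig(
    name="movies", vocabulary_size=10_000, embedding_dim=32
)
feature = keras_rs.layers.FeatureConfig(
    name="movie_id",
    table=table,
    input_shape=(None, 1),
    output_shape=(None, 32),
)

# On the JAX backend, per-partition id statistics are refreshed during
# training when the new flag is set.
layer = keras_rs.layers.DistributedEmbedding(
    feature_configs={"movie_id": feature},
    update_stats=True,
)
```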
@@ -822,13 +829,13 @@ class DistributedEmbedding(keras.layers.Layer):
         table_stacking: str | Sequence[Sequence[str]],
     ) -> None:
         del table_stacking
-
+        table_config_id_to_embedding_layer: dict[int, EmbedReduce] = {}
         self._default_device_embedding_layers: dict[str, EmbedReduce] = {}
 
         for path, feature_config in feature_configs.items():
-            if feature_config.table in
+            if id(feature_config.table) in table_config_id_to_embedding_layer:
                 self._default_device_embedding_layers[path] = (
-
+                    table_config_id_to_embedding_layer[id(feature_config.table)]
                 )
             else:
                 embedding_layer = EmbedReduce(
@@ -838,7 +845,9 @@ class DistributedEmbedding(keras.layers.Layer):
                     embeddings_initializer=feature_config.table.initializer,
                     combiner=feature_config.table.combiner,
                 )
-
+                table_config_id_to_embedding_layer[id(feature_config.table)] = (
+                    embedding_layer
+                )
                 self._default_device_embedding_layers[path] = embedding_layer
 
     def _default_device_build(
@@ -1013,8 +1022,8 @@ class DistributedEmbedding(keras.layers.Layer):
 
         # The serialized `TableConfig` objects.
         table_config_dicts: list[dict[str, Any]] = []
-        # Mapping from `TableConfig` to index in `table_config_dicts`.
-
+        # Mapping from `TableConfig` id to index in `table_config_dicts`.
+        table_config_id_to_index: dict[int, int] = {}
 
         def serialize_feature_config(
             feature_config: FeatureConfig,
@@ -1024,17 +1033,17 @@ class DistributedEmbedding(keras.layers.Layer):
             # key.
             feature_config_dict = feature_config.get_config()
 
-            if feature_config.table not in
+            if id(feature_config.table) not in table_config_id_to_index:
                 # Save the serialized `TableConfig` the first time we see it and
                 # remember its index.
-
+                table_config_id_to_index[id(feature_config.table)] = len(
                     table_config_dicts
                 )
                 table_config_dicts.append(feature_config_dict["table"])
 
             # Replace the serialized `TableConfig` with its index.
-            feature_config_dict["table"] =
-                feature_config.table
+            feature_config_dict["table"] = table_config_id_to_index[
+                id(feature_config.table)
             ]
             return feature_config_dict
 
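Both the embedding-layer dedup above and the `get_config` serialization now key their maps by `id(feature_config.table)` instead of by the `TableConfig` object itself. The pattern in isolation, with a hypothetical `Cfg` stand-in: keying by object identity dedups shared instances without requiring the config objects to be hashable (plain, non-frozen dataclasses are not) and without conflating distinct-but-equal configs:

```python
import dataclasses


@dataclasses.dataclass
class Cfg:  # hypothetical stand-in for TableConfig; not hashable
    name: str
    dim: int


shared = Cfg("movies", 32)
features = {"a": shared, "b": shared, "c": Cfg("users", 32)}

# Dedup by identity: both features that share `shared` map to one entry.
by_id: dict[int, Cfg] = {}
for cfg in features.values():
    by_id.setdefault(id(cfg), cfg)

assert len(by_id) == 2
```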
keras_rs/src/layers/embedding/distributed_embedding_config.py
CHANGED

@@ -10,7 +10,7 @@ from keras_rs.src.api_export import keras_rs_export
 
 
 @keras_rs_export("keras_rs.layers.TableConfig")
-@dataclasses.dataclass(
+@dataclasses.dataclass(order=True)
 class TableConfig:
     """Configuration for one embedding table.
 
@@ -88,7 +88,7 @@ class TableConfig:
 
 
 @keras_rs_export("keras_rs.layers.FeatureConfig")
-@dataclasses.dataclass(
+@dataclasses.dataclass(order=True)
 class FeatureConfig:
     """Configuration for one embedding feature.
 
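With `order=True`, the generated dataclasses additionally get `__lt__`, `__le__`, `__gt__` and `__ge__` derived from their fields in declaration order, so config objects become sortable. A generic illustration with a hypothetical dataclass (not the real `TableConfig` fields):

```python
import dataclasses


@dataclasses.dataclass(order=True)
class Version:  # illustrative only
    major: int
    minor: int


assert Version(1, 2) < Version(1, 10)  # compares field by field
assert max(Version(0, 9), Version(1, 0)) == Version(1, 0)
```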
keras_rs/src/layers/embedding/jax/distributed_embedding.py
CHANGED

@@ -9,13 +9,13 @@ import keras
 import numpy as np
 from jax import numpy as jnp
 from jax.experimental import layout as jax_layout
+from jax.experimental import multihost_utils
 from jax_tpu_embedding.sparsecore.lib.nn import embedding
 from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
 from jax_tpu_embedding.sparsecore.lib.nn import (
     table_stacking as jte_table_stacking,
 )
 from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
-from keras.src import backend
 
 from keras_rs.src import types
 from keras_rs.src.layers.embedding import base_distributed_embedding
@@ -28,9 +28,14 @@ from keras_rs.src.layers.embedding.jax import embedding_utils
 from keras_rs.src.types import Nested
 from keras_rs.src.utils import keras_utils
 
+if jax.__version_info__ >= (0, 8, 0):
+    from jax import shard_map
+else:
+    from jax.experimental.shard_map import shard_map  # type: ignore[assignment]
+
+
 ArrayLike = Union[np.ndarray[Any, Any], jax.Array]
 FeatureConfig = config.FeatureConfig
-shard_map = jax.experimental.shard_map.shard_map  # type: ignore[attr-defined]
 
 
 def _get_partition_spec(
@@ -247,23 +252,6 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         )
         return sparsecore_distribution, sparsecore_layout
 
-    def _create_cpu_distribution(
-        self, cpu_axis_name: str = "cpu"
-    ) -> tuple[
-        keras.distribution.ModelParallel, keras.distribution.TensorLayout
-    ]:
-        """Share a variable across all CPU processes."""
-        cpu_devices = jax.devices("cpu")
-        device_mesh = keras.distribution.DeviceMesh(
-            (len(cpu_devices),), [cpu_axis_name], cpu_devices
-        )
-        replicated_layout = keras.distribution.TensorLayout([], device_mesh)
-        layout_map = keras.distribution.LayoutMap(device_mesh=device_mesh)
-        cpu_distribution = keras.distribution.ModelParallel(
-            layout_map=layout_map
-        )
-        return cpu_distribution, replicated_layout
-
     def _add_sparsecore_weight(
         self,
         name: str,
@@ -283,7 +271,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         table_specs: Sequence[embedding_spec.TableSpec],
         num_shards: int,
         add_slot_variables: bool,
-    ) ->
+    ) -> embedding.EmbeddingVariables:
         stacked_table_spec = typing.cast(
             embedding_spec.StackedTableSpec, table_specs[0].stacked_table_spec
         )
@@ -352,7 +340,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
             slot_initializers, slot_variables
         )
 
-        return table_variable, slot_variables
+        return embedding.EmbeddingVariables(table_variable, slot_variables)
 
     @keras_utils.no_automatic_dependency_tracking
     def _sparsecore_init(
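Returning `embedding.EmbeddingVariables(...)` instead of a bare `(table, slots)` tuple lets later code (see the `set_tables`/`get_tables` hunks below) access the table by attribute, e.g. `embedding_variables.table`, while still unpacking like the old tuple. A tiny illustration with a hypothetical named tuple, not the real `EmbeddingVariables` definition:

```python
from typing import Any, NamedTuple


class FakeEmbeddingVariables(NamedTuple):
    # Hypothetical stand-in: the embedding table plus its optimizer slots.
    table: Any
    slot: Any


pair = FakeEmbeddingVariables(table=[[0.0, 0.1]], slot=None)
assert pair.table is pair[0]  # attribute access, tuple unpacking still works
```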
@@ -405,11 +393,6 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         self._sparsecore_layout = sparsecore_layout
         self._sparsecore_distribution = sparsecore_distribution
 
-        # Distribution for CPU operations.
-        cpu_distribution, cpu_layout = self._create_cpu_distribution()
-        self._cpu_distribution = cpu_distribution
-        self._cpu_layout = cpu_layout
-
         mesh = sparsecore_distribution.device_mesh.backend_mesh
         global_device_count = mesh.devices.size
         num_sc_per_device = jte_utils.num_sparsecores_per_device(
@@ -424,7 +407,9 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         if isinstance(table_stacking, str):
             if table_stacking == "auto":
                 jte_table_stacking.auto_stack_tables(
-                    feature_specs,
+                    feature_specs,
+                    global_device_count,
+                    num_sc_per_device,
                 )
             else:
                 raise ValueError(
@@ -464,12 +449,51 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         )
 
         # Collect all stacked tables.
-        table_specs =
-        table_stacks =
-
-
-
-        }
+        table_specs = embedding.get_table_specs(feature_specs)
+        table_stacks = jte_table_stacking.get_table_stacks(table_specs)
+
+        # Update stacked table stats to max of values across involved tables.
+        max_ids_per_partition = {}
+        max_unique_ids_per_partition = {}
+        required_buffer_size_per_device = {}
+        id_drop_counters = {}
+        for stack_name, stack in table_stacks.items():
+            max_ids_per_partition[stack_name] = np.max(
+                np.asarray(
+                    [s.max_ids_per_partition for s in stack], dtype=np.int32
+                )
+            )
+            max_unique_ids_per_partition[stack_name] = np.max(
+                np.asarray(
+                    [s.max_unique_ids_per_partition for s in stack],
+                    dtype=np.int32,
+                )
+            )
+
+            # Only set the suggested buffer size if set on any individual table.
+            valid_buffer_sizes = [
+                s.suggested_coo_buffer_size_per_device
+                for s in stack
+                if s.suggested_coo_buffer_size_per_device is not None
+            ]
+            if valid_buffer_sizes:
+                required_buffer_size_per_device[stack_name] = np.max(
+                    np.asarray(valid_buffer_sizes, dtype=np.int32)
+                )
+
+            id_drop_counters[stack_name] = 0
+
+        aggregated_stats = embedding.SparseDenseMatmulInputStats(
+            max_ids_per_partition=max_ids_per_partition,
+            max_unique_ids_per_partition=max_unique_ids_per_partition,
+            required_buffer_size_per_sc=required_buffer_size_per_device,
+            id_drop_counters=id_drop_counters,
+        )
+        embedding.update_preprocessing_parameters(
+            feature_specs,
+            aggregated_stats,
+            num_sc_per_device,
+        )
 
         # Create variables for all stacked tables and slot variables.
         with sparsecore_distribution.scope():
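The block above folds per-table limits into per-stack limits by taking the maximum of each statistic over the tables in a stack, and only proposes a buffer size when at least one table in the stack sets one. The same reduction in isolation, over a toy stack of hypothetical spec objects (not the real sparsecore `TableSpec`):

```python
import dataclasses

import numpy as np


@dataclasses.dataclass
class FakeTableSpec:  # hypothetical stand-in for a sparsecore TableSpec
    max_ids_per_partition: int
    max_unique_ids_per_partition: int
    suggested_coo_buffer_size_per_device: int | None = None


stack = [FakeTableSpec(256, 64), FakeTableSpec(512, 32, 4096)]

max_ids = int(np.max([s.max_ids_per_partition for s in stack]))            # 512
max_unique = int(np.max([s.max_unique_ids_per_partition for s in stack]))  # 64
buffer_sizes = [
    s.suggested_coo_buffer_size_per_device
    for s in stack
    if s.suggested_coo_buffer_size_per_device is not None
]
buffer_size = int(np.max(buffer_sizes)) if buffer_sizes else None          # 4096
```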
@@ -502,50 +526,6 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
             )
             self._iterations.overwrite_with_gradient = True
 
-        with cpu_distribution.scope():
-            # Create variables to track static buffer size and max IDs for each
-            # table during preprocessing. These variables are shared across all
-            # processes on CPU. We don't add these via `add_weight` because we
-            # can't have them passed to the training function.
-            replicated_zeros_initializer = ShardedInitializer(
-                "zeros", cpu_layout
-            )
-
-            with backend.name_scope(self.name, caller=self):
-                self._preprocessing_buffer_size = {
-                    table_name: backend.Variable(
-                        initializer=replicated_zeros_initializer,
-                        shape=(),
-                        dtype=backend.standardize_dtype("int32"),
-                        trainable=False,
-                        name=table_name + ":preprocessing:buffer_size",
-                    )
-                    for table_name in stacked_table_specs.keys()
-                }
-                self._preprocessing_max_unique_ids_per_partition = {
-                    table_name: backend.Variable(
-                        shape=(),
-                        name=table_name
-                        + ":preprocessing:max_unique_ids_per_partition",
-                        initializer=replicated_zeros_initializer,
-                        dtype=backend.standardize_dtype("int32"),
-                        trainable=False,
-                    )
-                    for table_name in stacked_table_specs.keys()
-                }
-
-                self._preprocessing_max_ids_per_partition = {
-                    table_name: backend.Variable(
-                        shape=(),
-                        name=table_name
-                        + ":preprocessing:max_ids_per_partition",
-                        initializer=replicated_zeros_initializer,
-                        dtype=backend.standardize_dtype("int32"),
-                        trainable=False,
-                    )
-                    for table_name in stacked_table_specs.keys()
-                }
-
         self._config = jte_embedding_lookup.EmbeddingLookupConfiguration(
             feature_specs,
             mesh=mesh,
@@ -586,10 +566,8 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         del inputs, weights, training
 
         # Each stacked-table gets a ShardedCooMatrix.
-        table_specs =
-
-        )
-        table_stacks = embedding_utils.get_table_stacks(table_specs)
+        table_specs = embedding.get_table_specs(self._config.feature_specs)
+        table_stacks = jte_table_stacking.get_table_stacks(table_specs)
         stacked_table_specs = {
             stack_name: stack[0].stacked_table_spec
             for stack_name, stack in table_stacks.items()
@@ -660,125 +638,74 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
             mesh.devices.item(0)
         )
 
-        # Get current buffer size/max_ids.
-        previous_max_ids_per_partition = keras.tree.map_structure(
-            lambda max_ids_per_partition: max_ids_per_partition.value.item(),
-            self._preprocessing_max_ids_per_partition,
-        )
-        previous_max_unique_ids_per_partition = keras.tree.map_structure(
-            lambda max_unique_ids_per_partition: (
-                max_unique_ids_per_partition.value.item()
-            ),
-            self._preprocessing_max_unique_ids_per_partition,
-        )
-        previous_buffer_size = keras.tree.map_structure(
-            lambda buffer_size: buffer_size.value.item(),
-            self._preprocessing_buffer_size,
-        )
-
         preprocessed, stats = embedding_utils.stack_and_shard_samples(
             self._config.feature_specs,
             samples,
             local_device_count,
             global_device_count,
             num_sc_per_device,
-
-        )
-
-        # Extract max unique IDs and buffer sizes.
-        # We need to replicate this value across all local CPU devices.
-        if training:
-            num_local_cpu_devices = jax.local_device_count("cpu")
-            local_max_ids_per_partition = {
-                table_name: np.repeat(
-                    # Maximum across all partitions and previous max.
-                    np.maximum(
-                        np.max(elems),
-                        previous_max_ids_per_partition[table_name],
-                    ),
-                    num_local_cpu_devices,
-                )
-                for table_name, elems in stats.max_ids_per_partition.items()
-            }
-            local_max_unique_ids_per_partition = {
-                name: np.repeat(
-                    # Maximum across all partitions and previous max.
-                    np.maximum(
-                        np.max(elems),
-                        previous_max_unique_ids_per_partition[name],
-                    ),
-                    num_local_cpu_devices,
-                )
-                for name, elems in stats.max_unique_ids_per_partition.items()
-            }
-            local_buffer_size = {
-                table_name: np.repeat(
-                    np.maximum(
-                        np.max(
-                            # Round values up to the next multiple of 8.
-                            # Currently using this as a proxy for the actual
-                            # required buffer size.
-                            ((elems + 7) // 8) * 8
-                        )
-                        * global_device_count
-                        * num_sc_per_device
-                        * local_device_count
-                        * num_sc_per_device,
-                        previous_buffer_size[table_name],
-                    ),
-                    num_local_cpu_devices,
-                )
-                for table_name, elems in stats.max_ids_per_partition.items()
-            }
+        )
 
-
-
-
-
-
-
-
-
-
+        if training and self.update_stats:
+            # Synchronize input statistics across all devices and update the
+            # underlying stacked tables specs in the feature specs.
+
+            # Gather stats across all processes/devices via process_allgather.
+            all_stats = multihost_utils.process_allgather(stats)
+            all_stats = jax.tree.map(np.max, all_stats)
+
+            # Check if stats changed enough to warrant action.
+            stacked_table_specs = embedding.get_stacked_table_specs(
+                self._config.feature_specs
             )
-
-
+            changed = any(
+                all_stats.max_ids_per_partition[stack_name]
+                > spec.max_ids_per_partition
+                or all_stats.max_unique_ids_per_partition[stack_name]
+                > spec.max_unique_ids_per_partition
+                or all_stats.required_buffer_size_per_sc[stack_name]
+                * num_sc_per_device
+                > (spec.suggested_coo_buffer_size_per_device or 0)
+                for stack_name, spec in stacked_table_specs.items()
            )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+            # Update configuration and repeat preprocessing if stats changed.
+            if changed:
+                for stack_name, spec in stacked_table_specs.items():
+                    all_stats.max_ids_per_partition[stack_name] = np.max(
+                        [
+                            all_stats.max_ids_per_partition[stack_name],
+                            spec.max_ids_per_partition,
+                        ]
+                    )
+                    all_stats.max_unique_ids_per_partition[stack_name] = np.max(
+                        [
+                            all_stats.max_unique_ids_per_partition[stack_name],
+                            spec.max_unique_ids_per_partition,
+                        ]
+                    )
+                    all_stats.required_buffer_size_per_sc[stack_name] = np.max(
+                        [
+                            all_stats.required_buffer_size_per_sc[stack_name],
+                            (
+                                (spec.suggested_coo_buffer_size_per_device or 0)
+                                + (num_sc_per_device - 1)
+                            )
+                            // num_sc_per_device,
+                        ]
+                    )
+
+                embedding.update_preprocessing_parameters(
+                    self._config.feature_specs, all_stats, num_sc_per_device
                 )
-
+
+                # Re-execute preprocessing with consistent input statistics.
+                preprocessed, _ = embedding_utils.stack_and_shard_samples(
                     self._config.feature_specs,
-
-
+                    samples,
+                    local_device_count,
+                    global_device_count,
+                    num_sc_per_device,
                 )
 
         return {"inputs": preprocessed}
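The training-time path above replaces the old CPU-replicated tracking variables with a one-shot synchronization: each process gathers the per-batch statistics with `multihost_utils.process_allgather`, reduces every leaf to a global maximum, and only rewrites the preprocessing parameters (and re-runs preprocessing) when the gathered values exceed the currently configured ones. The synchronization step in isolation, over a toy pytree of per-stack stats (values are illustrative; on a single process the gather is effectively a no-op):

```python
import jax
import numpy as np
from jax.experimental import multihost_utils

# Hypothetical per-process statistics, keyed by stacked-table name.
local_stats = {
    "stack0": {"max_ids_per_partition": np.int32(48)},
    "stack1": {"max_ids_per_partition": np.int32(96)},
}

# Gather the pytree from every process, then reduce each leaf to its
# global maximum across processes and partitions.
gathered = multihost_utils.process_allgather(local_stats)
global_stats = jax.tree.map(np.max, gathered)
```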
@@ -826,19 +753,22 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
             raise ValueError("Layer must first be built before setting tables.")
 
         if "default_device" in self._placement_to_path_to_feature_config:
-
+            table_name_to_embedding_layer = {}
             for (
                 path,
                 feature_config,
             ) in self._placement_to_path_to_feature_config[
                 "default_device"
             ].items():
-
+                table_name_to_embedding_layer[feature_config.table.name] = (
                     self._default_device_embedding_layers[path]
                 )
 
-            for
-
+            for (
+                table_name,
+                embedding_layer,
+            ) in table_name_to_embedding_layer.items():
+                table_values = tables.get(table_name, None)
                 if table_values is not None:
                     if embedding_layer.lora_enabled:
                         raise ValueError("Cannot set table if LoRA is enabled.")
@@ -851,8 +781,8 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
 
         config = self._config
         num_table_shards = config.mesh.devices.size * config.num_sc_per_device
-        table_specs =
-        sharded_tables =
+        table_specs = embedding.get_table_specs(config.feature_specs)
+        sharded_tables = jte_table_stacking.stack_and_shard_tables(
            table_specs,
            tables,
            num_table_shards,
@@ -871,8 +801,8 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         # Assign stacked table variables to the device values.
         keras.tree.map_structure_up_to(
             device_tables,
-            lambda
-            table_value:
+            lambda embedding_variables,
+            table_value: embedding_variables.table.assign(table_value),
             self._table_and_slot_variables,
             device_tables,
         )
@@ -883,17 +813,19 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
 
         config = self._config
         num_table_shards = config.mesh.devices.size * config.num_sc_per_device
-        table_specs =
+        table_specs = embedding.get_table_specs(config.feature_specs)
 
         # Extract only the table variables, not the gradient slot variables.
         table_variables = {
-            name: jax.device_get(
-            for name,
+            name: jax.device_get(embedding_variables.table.value)
+            for name, embedding_variables in (
+                self._table_and_slot_variables.items()
+            )
         }
 
         return typing.cast(
             dict[str, ArrayLike],
-
+            jte_table_stacking.unshard_and_unstack_tables(
                 table_specs, table_variables, num_table_shards
             ),
         )
keras_rs/src/layers/embedding/jax/embedding_lookup.py
CHANGED

@@ -16,9 +16,30 @@ from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
 from keras_rs.src.layers.embedding.jax import embedding_utils
 from keras_rs.src.types import Nested
 
-
-
+if jax.__version_info__ >= (0, 8, 0):
+    from jax import shard_map
+else:
+    from jax.experimental.shard_map import shard_map as exp_shard_map
+
+    def shard_map(  # type: ignore[misc]
+        f: Any = None,
+        /,
+        *,
+        out_specs: Any,
+        in_specs: Any,
+        mesh: Any = None,
+        check_vma: bool = True,
+    ) -> Any:
+        return exp_shard_map(
+            f,
+            mesh=mesh,
+            in_specs=in_specs,
+            out_specs=out_specs,
+            check_rep=check_vma,
+        )  # type: ignore[no-untyped-call]
+
 
+ShardedCooMatrix = embedding_utils.ShardedCooMatrix
 ArrayLike: TypeAlias = jax.Array | np.ndarray[Any, Any]
 JaxLayout: TypeAlias = jax.sharding.NamedSharding | jax_layout.Format
 
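The shim keeps a single calling convention across JAX versions: call sites pass `mesh`, `in_specs`, `out_specs` and the newer `check_vma` flag by keyword, and on older JAX the wrapper forwards `check_vma` to the old `check_rep` argument. A sketch of that convention (mesh and specs are placeholders, not the embedding lookup's real configuration; the direct import assumes JAX >= 0.8, otherwise the shim above applies):

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax import shard_map  # JAX >= 0.8; older versions use the shim above
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.asarray(jax.devices()), ("x",))


@functools.partial(
    shard_map, mesh=mesh, in_specs=P("x"), out_specs=P("x"), check_vma=False
)
def scale(block: jax.Array) -> jax.Array:
    # Runs once per shard; `block` is this device's slice of the input.
    return block * 2.0


out = scale(jnp.arange(jax.device_count() * 4, dtype=jnp.float32))
```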
@@ -121,7 +142,7 @@ def embedding_lookup(
            mesh=config.mesh,
            in_specs=(pd, pt),
            out_specs=pd,
-
+           check_vma=False,
        ),
    )
 
@@ -220,7 +241,7 @@ def embedding_lookup_bwd(
            mesh=config.mesh,
            in_specs=(pd, pd, pt, preplicate),
            out_specs=pt,
-
+           check_vma=False,
        ),
        # in_shardings=(
        #     activation_layout,