keras-rs-nightly 0.2.2.dev202508190331__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_rs/losses/__init__.py +1 -0
- keras_rs/src/layers/embedding/base_distributed_embedding.py +12 -10
- keras_rs/src/layers/embedding/distributed_embedding_config.py +2 -2
- keras_rs/src/layers/embedding/jax/distributed_embedding.py +127 -197
- keras_rs/src/layers/embedding/jax/embedding_lookup.py +25 -4
- keras_rs/src/layers/embedding/jax/embedding_utils.py +22 -401
- keras_rs/src/layers/embedding/tensorflow/config_conversion.py +26 -19
- keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +15 -5
- keras_rs/src/losses/list_mle_loss.py +212 -0
- keras_rs/src/metrics/ranking_metrics_utils.py +21 -2
- keras_rs/src/utils/tpu_test_utils.py +120 -0
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.2.2.dev202508190331.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +4 -3
- {keras_rs_nightly-0.2.2.dev202508190331.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/RECORD +16 -14
- {keras_rs_nightly-0.2.2.dev202508190331.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +0 -0
- {keras_rs_nightly-0.2.2.dev202508190331.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
keras_rs/losses/__init__.py
CHANGED
@@ -4,6 +4,7 @@ This file was autogenerated. Do not edit it by hand,
 since your modifications would be overwritten.
 """
 
+from keras_rs.src.losses.list_mle_loss import ListMLELoss as ListMLELoss
 from keras_rs.src.losses.pairwise_hinge_loss import (
     PairwiseHingeLoss as PairwiseHingeLoss,
 )
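The hunk above re-exports the new ListMLE ranking loss as `keras_rs.losses.ListMLELoss`. A minimal usage sketch, assuming it follows the standard Keras loss call convention `loss(y_true, y_pred)` over per-list relevance labels and scores; the shapes and values below are illustrative, not taken from the package:

import numpy as np
import keras_rs

# Assumption: ListMLELoss behaves like the other keras_rs ranking losses and
# accepts batches of shape (num_lists, list_size).
loss_fn = keras_rs.losses.ListMLELoss()
y_true = np.array([[1.0, 0.0, 2.0]])  # graded relevance per item
y_pred = np.array([[0.3, 0.1, 0.8]])  # predicted ranking scores
print(float(loss_fn(y_true, y_pred)))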
keras_rs/src/layers/embedding/base_distributed_embedding.py
CHANGED

@@ -822,13 +822,13 @@ class DistributedEmbedding(keras.layers.Layer):
         table_stacking: str | Sequence[Sequence[str]],
     ) -> None:
         del table_stacking
-
+        table_config_id_to_embedding_layer: dict[int, EmbedReduce] = {}
         self._default_device_embedding_layers: dict[str, EmbedReduce] = {}
 
         for path, feature_config in feature_configs.items():
-            if feature_config.table in
+            if id(feature_config.table) in table_config_id_to_embedding_layer:
                 self._default_device_embedding_layers[path] = (
-
+                    table_config_id_to_embedding_layer[id(feature_config.table)]
                 )
             else:
                 embedding_layer = EmbedReduce(
@@ -838,7 +838,9 @@ class DistributedEmbedding(keras.layers.Layer):
                     embeddings_initializer=feature_config.table.initializer,
                     combiner=feature_config.table.combiner,
                 )
-
+                table_config_id_to_embedding_layer[id(feature_config.table)] = (
+                    embedding_layer
+                )
                 self._default_device_embedding_layers[path] = embedding_layer
 
     def _default_device_build(
@@ -1013,8 +1015,8 @@ class DistributedEmbedding(keras.layers.Layer):
 
         # The serialized `TableConfig` objects.
         table_config_dicts: list[dict[str, Any]] = []
-        # Mapping from `TableConfig` to index in `table_config_dicts`.
-
+        # Mapping from `TableConfig` id to index in `table_config_dicts`.
+        table_config_id_to_index: dict[int, int] = {}
 
         def serialize_feature_config(
             feature_config: FeatureConfig,
@@ -1024,17 +1026,17 @@ class DistributedEmbedding(keras.layers.Layer):
             # key.
             feature_config_dict = feature_config.get_config()
 
-            if feature_config.table not in
+            if id(feature_config.table) not in table_config_id_to_index:
                 # Save the serialized `TableConfig` the first time we see it and
                 # remember its index.
-
+                table_config_id_to_index[id(feature_config.table)] = len(
                     table_config_dicts
                 )
                 table_config_dicts.append(feature_config_dict["table"])
 
             # Replace the serialized `TableConfig` with its index.
-            feature_config_dict["table"] =
-                feature_config.table
+            feature_config_dict["table"] = table_config_id_to_index[
+                id(feature_config.table)
             ]
             return feature_config_dict
 
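These hunks switch the deduplication dictionaries from `TableConfig`-keyed to `id(TableConfig)`-keyed, so shared tables are deduplicated by object identity rather than by value comparison. A standalone sketch of the same pattern, using a hypothetical dataclass rather than the real keras_rs configs:

import dataclasses

# Assumption: illustrative stand-in for a config that compares by value.
# With order=True (which implies eq=True), the default hash is removed, so
# instances can no longer be used as dict keys directly; id() still works.
@dataclasses.dataclass(order=True)
class Table:
    name: str
    dim: int

shared = Table("users", 32)
features = {"clicks": shared, "views": shared, "items": Table("items", 16)}

layer_by_table_id: dict[int, str] = {}
for path, table in features.items():
    if id(table) not in layer_by_table_id:
        layer_by_table_id[id(table)] = f"embedding_{table.name}"

# The shared "users" table produced one layer despite backing two features.
assert len(layer_by_table_id) == 2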
keras_rs/src/layers/embedding/distributed_embedding_config.py
CHANGED

@@ -10,7 +10,7 @@ from keras_rs.src.api_export import keras_rs_export
 
 
 @keras_rs_export("keras_rs.layers.TableConfig")
-@dataclasses.dataclass(
+@dataclasses.dataclass(order=True)
 class TableConfig:
     """Configuration for one embedding table.
 
@@ -88,7 +88,7 @@ class TableConfig:
 
 
 @keras_rs_export("keras_rs.layers.FeatureConfig")
-@dataclasses.dataclass(
+@dataclasses.dataclass(order=True)
 class FeatureConfig:
     """Configuration for one embedding feature.
 
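With `order=True`, the generated comparison methods presumably let collections of `TableConfig` and `FeatureConfig` objects be compared and sorted deterministically by their field values. A tiny sketch of what the flag enables on a dataclass in general (hypothetical class, not the real configs):

import dataclasses

@dataclasses.dataclass(order=True)
class Cfg:
    name: str
    dim: int

# Comparisons use the fields in declaration order: name first, then dim.
configs = [Cfg("items", 16), Cfg("users", 32), Cfg("items", 8)]
print(sorted(configs))  # items/8, items/16, users/32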
keras_rs/src/layers/embedding/jax/distributed_embedding.py
CHANGED

@@ -9,13 +9,13 @@ import keras
 import numpy as np
 from jax import numpy as jnp
 from jax.experimental import layout as jax_layout
+from jax.experimental import multihost_utils
 from jax_tpu_embedding.sparsecore.lib.nn import embedding
 from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
 from jax_tpu_embedding.sparsecore.lib.nn import (
     table_stacking as jte_table_stacking,
 )
 from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
-from keras.src import backend
 
 from keras_rs.src import types
 from keras_rs.src.layers.embedding import base_distributed_embedding
@@ -28,9 +28,14 @@ from keras_rs.src.layers.embedding.jax import embedding_utils
 from keras_rs.src.types import Nested
 from keras_rs.src.utils import keras_utils
 
+if jax.__version_info__ >= (0, 8, 0):
+    from jax import shard_map
+else:
+    from jax.experimental.shard_map import shard_map  # type: ignore[assignment]
+
+
 ArrayLike = Union[np.ndarray[Any, Any], jax.Array]
 FeatureConfig = config.FeatureConfig
-shard_map = jax.experimental.shard_map.shard_map  # type: ignore[attr-defined]
 
 
 def _get_partition_spec(
@@ -247,23 +252,6 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         )
         return sparsecore_distribution, sparsecore_layout
 
-    def _create_cpu_distribution(
-        self, cpu_axis_name: str = "cpu"
-    ) -> tuple[
-        keras.distribution.ModelParallel, keras.distribution.TensorLayout
-    ]:
-        """Share a variable across all CPU processes."""
-        cpu_devices = jax.devices("cpu")
-        device_mesh = keras.distribution.DeviceMesh(
-            (len(cpu_devices),), [cpu_axis_name], cpu_devices
-        )
-        replicated_layout = keras.distribution.TensorLayout([], device_mesh)
-        layout_map = keras.distribution.LayoutMap(device_mesh=device_mesh)
-        cpu_distribution = keras.distribution.ModelParallel(
-            layout_map=layout_map
-        )
-        return cpu_distribution, replicated_layout
-
     def _add_sparsecore_weight(
         self,
         name: str,
@@ -283,7 +271,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         table_specs: Sequence[embedding_spec.TableSpec],
         num_shards: int,
         add_slot_variables: bool,
-    ) ->
+    ) -> embedding.EmbeddingVariables:
         stacked_table_spec = typing.cast(
             embedding_spec.StackedTableSpec, table_specs[0].stacked_table_spec
         )
@@ -352,7 +340,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
             slot_initializers, slot_variables
         )
 
-        return table_variable, slot_variables
+        return embedding.EmbeddingVariables(table_variable, slot_variables)
 
     @keras_utils.no_automatic_dependency_tracking
     def _sparsecore_init(
@@ -405,11 +393,6 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         self._sparsecore_layout = sparsecore_layout
         self._sparsecore_distribution = sparsecore_distribution
 
-        # Distribution for CPU operations.
-        cpu_distribution, cpu_layout = self._create_cpu_distribution()
-        self._cpu_distribution = cpu_distribution
-        self._cpu_layout = cpu_layout
-
         mesh = sparsecore_distribution.device_mesh.backend_mesh
         global_device_count = mesh.devices.size
         num_sc_per_device = jte_utils.num_sparsecores_per_device(
@@ -464,12 +447,51 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         )
 
         # Collect all stacked tables.
-        table_specs =
-        table_stacks =
-
-
-
-        }
+        table_specs = embedding.get_table_specs(feature_specs)
+        table_stacks = jte_table_stacking.get_table_stacks(table_specs)
+
+        # Update stacked table stats to max of values across involved tables.
+        max_ids_per_partition = {}
+        max_unique_ids_per_partition = {}
+        required_buffer_size_per_device = {}
+        id_drop_counters = {}
+        for stack_name, stack in table_stacks.items():
+            max_ids_per_partition[stack_name] = np.max(
+                np.asarray(
+                    [s.max_ids_per_partition for s in stack], dtype=np.int32
+                )
+            )
+            max_unique_ids_per_partition[stack_name] = np.max(
+                np.asarray(
+                    [s.max_unique_ids_per_partition for s in stack],
+                    dtype=np.int32,
+                )
+            )
+
+            # Only set the suggested buffer size if set on any individual table.
+            valid_buffer_sizes = [
+                s.suggested_coo_buffer_size_per_device
+                for s in stack
+                if s.suggested_coo_buffer_size_per_device is not None
+            ]
+            if valid_buffer_sizes:
+                required_buffer_size_per_device[stack_name] = np.max(
+                    np.asarray(valid_buffer_sizes, dtype=np.int32)
+                )
+
+            id_drop_counters[stack_name] = 0
+
+        aggregated_stats = embedding.SparseDenseMatmulInputStats(
+            max_ids_per_partition=max_ids_per_partition,
+            max_unique_ids_per_partition=max_unique_ids_per_partition,
+            required_buffer_size_per_sc=required_buffer_size_per_device,
+            id_drop_counters=id_drop_counters,
+        )
+        embedding.update_preprocessing_parameters(
+            feature_specs,
+            aggregated_stats,
+            num_sc_per_device,
+        )
 
         # Create variables for all stacked tables and slot variables.
         with sparsecore_distribution.scope():
@@ -502,50 +524,6 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         )
         self._iterations.overwrite_with_gradient = True
 
-        with cpu_distribution.scope():
-            # Create variables to track static buffer size and max IDs for each
-            # table during preprocessing. These variables are shared across all
-            # processes on CPU. We don't add these via `add_weight` because we
-            # can't have them passed to the training function.
-            replicated_zeros_initializer = ShardedInitializer(
-                "zeros", cpu_layout
-            )
-
-            with backend.name_scope(self.name, caller=self):
-                self._preprocessing_buffer_size = {
-                    table_name: backend.Variable(
-                        initializer=replicated_zeros_initializer,
-                        shape=(),
-                        dtype=backend.standardize_dtype("int32"),
-                        trainable=False,
-                        name=table_name + ":preprocessing:buffer_size",
-                    )
-                    for table_name in stacked_table_specs.keys()
-                }
-                self._preprocessing_max_unique_ids_per_partition = {
-                    table_name: backend.Variable(
-                        shape=(),
-                        name=table_name
-                        + ":preprocessing:max_unique_ids_per_partition",
-                        initializer=replicated_zeros_initializer,
-                        dtype=backend.standardize_dtype("int32"),
-                        trainable=False,
-                    )
-                    for table_name in stacked_table_specs.keys()
-                }
-
-                self._preprocessing_max_ids_per_partition = {
-                    table_name: backend.Variable(
-                        shape=(),
-                        name=table_name
-                        + ":preprocessing:max_ids_per_partition",
-                        initializer=replicated_zeros_initializer,
-                        dtype=backend.standardize_dtype("int32"),
-                        trainable=False,
-                    )
-                    for table_name in stacked_table_specs.keys()
-                }
-
         self._config = jte_embedding_lookup.EmbeddingLookupConfiguration(
             feature_specs,
             mesh=mesh,
@@ -586,10 +564,8 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         del inputs, weights, training
 
         # Each stacked-table gets a ShardedCooMatrix.
-        table_specs =
-
-        )
-        table_stacks = embedding_utils.get_table_stacks(table_specs)
+        table_specs = embedding.get_table_specs(self._config.feature_specs)
+        table_stacks = jte_table_stacking.get_table_stacks(table_specs)
         stacked_table_specs = {
             stack_name: stack[0].stacked_table_spec
             for stack_name, stack in table_stacks.items()
@@ -660,125 +636,74 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
             mesh.devices.item(0)
         )
 
-        # Get current buffer size/max_ids.
-        previous_max_ids_per_partition = keras.tree.map_structure(
-            lambda max_ids_per_partition: max_ids_per_partition.value.item(),
-            self._preprocessing_max_ids_per_partition,
-        )
-        previous_max_unique_ids_per_partition = keras.tree.map_structure(
-            lambda max_unique_ids_per_partition: (
-                max_unique_ids_per_partition.value.item()
-            ),
-            self._preprocessing_max_unique_ids_per_partition,
-        )
-        previous_buffer_size = keras.tree.map_structure(
-            lambda buffer_size: buffer_size.value.item(),
-            self._preprocessing_buffer_size,
-        )
-
         preprocessed, stats = embedding_utils.stack_and_shard_samples(
             self._config.feature_specs,
             samples,
             local_device_count,
             global_device_count,
             num_sc_per_device,
-            static_buffer_size=previous_buffer_size,
         )
 
-        # Extract max unique IDs and buffer sizes.
-        # We need to replicate this value across all local CPU devices.
         if training:
-
-
-                table_name: np.repeat(
-                    # Maximum across all partitions and previous max.
-                    np.maximum(
-                        np.max(elems),
-                        previous_max_ids_per_partition[table_name],
-                    ),
-                    num_local_cpu_devices,
-                )
-                for table_name, elems in stats.max_ids_per_partition.items()
-            }
-            local_max_unique_ids_per_partition = {
-                name: np.repeat(
-                    # Maximum across all partitions and previous max.
-                    np.maximum(
-                        np.max(elems),
-                        previous_max_unique_ids_per_partition[name],
-                    ),
-                    num_local_cpu_devices,
-                )
-                for name, elems in stats.max_unique_ids_per_partition.items()
-            }
-            local_buffer_size = {
-                table_name: np.repeat(
-                    np.maximum(
-                        np.max(
-                            # Round values up to the next multiple of 8.
-                            # Currently using this as a proxy for the actual
-                            # required buffer size.
-                            ((elems + 7) // 8) * 8
-                        )
-                        * global_device_count
-                        * num_sc_per_device
-                        * local_device_count
-                        * num_sc_per_device,
-                        previous_buffer_size[table_name],
-                    ),
-                    num_local_cpu_devices,
-                )
-                for table_name, elems in stats.max_ids_per_partition.items()
-            }
+            # Synchronize input statistics across all devices and update the
+            # underlying stacked tables specs in the feature specs.
 
-            #
-
-
-
-
-
-
-            )
-            new_max_ids_per_partition = max_across_cpus(
-                local_max_ids_per_partition
+            # Gather stats across all processes/devices via process_allgather.
+            all_stats = multihost_utils.process_allgather(stats)
+            all_stats = jax.tree.map(np.max, all_stats)
+
+            # Check if stats changed enough to warrant action.
+            stacked_table_specs = embedding.get_stacked_table_specs(
+                self._config.feature_specs
             )
-
-
+            changed = any(
+                all_stats.max_ids_per_partition[stack_name]
+                > spec.max_ids_per_partition
+                or all_stats.max_unique_ids_per_partition[stack_name]
+                > spec.max_unique_ids_per_partition
+                or all_stats.required_buffer_size_per_sc[stack_name]
+                * num_sc_per_device
+                > (spec.suggested_coo_buffer_size_per_device or 0)
+                for stack_name, spec in stacked_table_specs.items()
            )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+            # Update configuration and repeat preprocessing if stats changed.
+            if changed:
+                for stack_name, spec in stacked_table_specs.items():
+                    all_stats.max_ids_per_partition[stack_name] = np.max(
+                        [
+                            all_stats.max_ids_per_partition[stack_name],
+                            spec.max_ids_per_partition,
+                        ]
+                    )
+                    all_stats.max_unique_ids_per_partition[stack_name] = np.max(
+                        [
+                            all_stats.max_unique_ids_per_partition[stack_name],
+                            spec.max_unique_ids_per_partition,
+                        ]
+                    )
+                    all_stats.required_buffer_size_per_sc[stack_name] = np.max(
+                        [
+                            all_stats.required_buffer_size_per_sc[stack_name],
+                            (
+                                (spec.suggested_coo_buffer_size_per_device or 0)
+                                + (num_sc_per_device - 1)
+                            )
+                            // num_sc_per_device,
+                        ]
+                    )
+
+                embedding.update_preprocessing_parameters(
+                    self._config.feature_specs, all_stats, num_sc_per_device
                )
-
+
+                # Re-execute preprocessing with consistent input statistics.
+                preprocessed, _ = embedding_utils.stack_and_shard_samples(
                     self._config.feature_specs,
-
-
+                    samples,
+                    local_device_count,
+                    global_device_count,
+                    num_sc_per_device,
                 )
 
         return {"inputs": preprocessed}
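The rewritten preprocessing path above drops the CPU-replicated tracking variables and instead synchronizes the gathered input statistics once per training step. A minimal standalone sketch of that gather-and-reduce pattern (the stats dict below is hypothetical, not the keras_rs structure):

import jax
import numpy as np
from jax.experimental import multihost_utils

# Per-process statistics, e.g. max IDs observed per stacked table locally.
local_stats = {
    "users_table": np.array([12, 7], dtype=np.int32),
    "items_table": np.array([3, 9], dtype=np.int32),
}

# Gather every process's copy of the pytree, then reduce each leaf to a single
# global maximum, mirroring the process_allgather + jax.tree.map(np.max, ...)
# step in the hunk above.
gathered = multihost_utils.process_allgather(local_stats)
global_max = jax.tree.map(np.max, gathered)
print(global_max)  # global elementwise max per table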
@@ -826,19 +751,22 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
             raise ValueError("Layer must first be built before setting tables.")
 
         if "default_device" in self._placement_to_path_to_feature_config:
-
+            table_name_to_embedding_layer = {}
             for (
                 path,
                 feature_config,
             ) in self._placement_to_path_to_feature_config[
                 "default_device"
             ].items():
-
+                table_name_to_embedding_layer[feature_config.table.name] = (
                     self._default_device_embedding_layers[path]
                 )
 
-            for
-
+            for (
+                table_name,
+                embedding_layer,
+            ) in table_name_to_embedding_layer.items():
+                table_values = tables.get(table_name, None)
                 if table_values is not None:
                     if embedding_layer.lora_enabled:
                         raise ValueError("Cannot set table if LoRA is enabled.")
@@ -851,8 +779,8 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
 
         config = self._config
         num_table_shards = config.mesh.devices.size * config.num_sc_per_device
-        table_specs =
-        sharded_tables =
+        table_specs = embedding.get_table_specs(config.feature_specs)
+        sharded_tables = jte_table_stacking.stack_and_shard_tables(
            table_specs,
            tables,
            num_table_shards,
@@ -871,8 +799,8 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         # Assign stacked table variables to the device values.
         keras.tree.map_structure_up_to(
             device_tables,
-            lambda
-            table_value:
+            lambda embedding_variables,
+            table_value: embedding_variables.table.assign(table_value),
             self._table_and_slot_variables,
             device_tables,
         )
@@ -883,17 +811,19 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
 
         config = self._config
         num_table_shards = config.mesh.devices.size * config.num_sc_per_device
-        table_specs =
+        table_specs = embedding.get_table_specs(config.feature_specs)
 
         # Extract only the table variables, not the gradient slot variables.
         table_variables = {
-            name: jax.device_get(
-            for name,
+            name: jax.device_get(embedding_variables.table.value)
+            for name, embedding_variables in (
+                self._table_and_slot_variables.items()
+            )
         }
 
         return typing.cast(
             dict[str, ArrayLike],
-
+            jte_table_stacking.unshard_and_unstack_tables(
                 table_specs, table_variables, num_table_shards
             ),
         )
keras_rs/src/layers/embedding/jax/embedding_lookup.py
CHANGED

@@ -16,9 +16,30 @@ from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
 from keras_rs.src.layers.embedding.jax import embedding_utils
 from keras_rs.src.types import Nested
 
-
-
+if jax.__version_info__ >= (0, 8, 0):
+    from jax import shard_map
+else:
+    from jax.experimental.shard_map import shard_map as exp_shard_map
+
+    def shard_map(  # type: ignore[misc]
+        f: Any = None,
+        /,
+        *,
+        out_specs: Any,
+        in_specs: Any,
+        mesh: Any = None,
+        check_vma: bool = True,
+    ) -> Any:
+        return exp_shard_map(
+            f,
+            mesh=mesh,
+            in_specs=in_specs,
+            out_specs=out_specs,
+            check_rep=check_vma,
+        )  # type: ignore[no-untyped-call]
+
 
+ShardedCooMatrix = embedding_utils.ShardedCooMatrix
 ArrayLike: TypeAlias = jax.Array | np.ndarray[Any, Any]
 JaxLayout: TypeAlias = jax.sharding.NamedSharding | jax_layout.Format
 
@@ -121,7 +142,7 @@ def embedding_lookup(
             mesh=config.mesh,
             in_specs=(pd, pt),
             out_specs=pd,
-
+            check_vma=False,
         ),
     )
 
@@ -220,7 +241,7 @@ def embedding_lookup_bwd(
             mesh=config.mesh,
             in_specs=(pd, pd, pt, preplicate),
             out_specs=pt,
-
+            check_vma=False,
         ),
         # in_shardings=(
         #     activation_layout,