keras-rs-nightly 0.2.2.dev202507030337.tar.gz → 0.3.1.dev202511120334.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of keras-rs-nightly might be problematic.
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/PKG-INFO +4 -3
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/api/losses/__init__.py +1 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/base_distributed_embedding.py +85 -55
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/distributed_embedding_config.py +6 -3
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/jax/distributed_embedding.py +111 -195
- keras_rs_nightly-0.3.1.dev202511120334/keras_rs/src/layers/embedding/jax/embedding_utils.py +244 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/tensorflow/config_conversion.py +62 -22
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +18 -6
- keras_rs_nightly-0.3.1.dev202511120334/keras_rs/src/losses/list_mle_loss.py +212 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/ranking_metrics_utils.py +19 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs_nightly.egg-info/PKG-INFO +4 -3
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs_nightly.egg-info/SOURCES.txt +1 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/pyproject.toml +4 -3
- keras_rs_nightly-0.2.2.dev202507030337/keras_rs/src/layers/embedding/jax/embedding_utils.py +0 -596
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/README.md +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/api/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/api/layers/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/api/metrics/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/api_export.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/distributed_embedding.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/embed_reduce.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/jax/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/jax/checkpoint_utils.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/jax/config_conversion.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/jax/embedding_lookup.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/feature_interaction/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/feature_interaction/dot_interaction.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/feature_interaction/feature_cross.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/brute_force_retrieval.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/hard_negative_mining.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/remove_accidental_hits.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/retrieval.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/layers/retrieval/sampling_probability_correction.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_hinge_loss.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_logistic_loss.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_loss.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_loss_utils.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_mean_squared_error.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/losses/pairwise_soft_zero_one_loss.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/dcg.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/mean_average_precision.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/mean_reciprocal_rank.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/ndcg.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/precision_at_k.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/ranking_metric.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/recall_at_k.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/metrics/utils.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/types.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/utils/__init__.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/utils/doc_string_utils.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs/src/utils/keras_utils.py +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs_nightly.egg-info/dependency_links.txt +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs_nightly.egg-info/requires.txt +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/keras_rs_nightly.egg-info/top_level.txt +0 -0
- {keras_rs_nightly-0.2.2.dev202507030337 → keras_rs_nightly-0.3.1.dev202511120334}/setup.cfg +0 -0
--- keras_rs_nightly-0.2.2.dev202507030337/PKG-INFO
+++ keras_rs_nightly-0.3.1.dev202511120334/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: keras-rs-nightly
-Version: 0.2.2.dev202507030337
+Version: 0.3.1.dev202511120334
 Summary: Multi-backend recommender systems with Keras 3.
 Author-email: Keras team <keras-users@googlegroups.com>
 License: Apache License 2.0
@@ -8,8 +8,9 @@ Project-URL: Home, https://keras.io/keras_rs
 Project-URL: Repository, https://github.com/keras-team/keras-rs
 Classifier: Development Status :: 3 - Alpha
 Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3.13
 Classifier: Programming Language :: Python :: 3 :: Only
 Classifier: Operating System :: Unix
 Classifier: Operating System :: Microsoft :: Windows
@@ -17,7 +18,7 @@ Classifier: Operating System :: MacOS
 Classifier: Intended Audience :: Science/Research
 Classifier: Topic :: Scientific/Engineering
 Classifier: Topic :: Software Development
-Requires-Python: >=3.10
+Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 Requires-Dist: keras
 Requires-Dist: ml-dtypes
--- keras_rs_nightly-0.2.2.dev202507030337/keras_rs/api/losses/__init__.py
+++ keras_rs_nightly-0.3.1.dev202511120334/keras_rs/api/losses/__init__.py
@@ -4,6 +4,7 @@ This file was autogenerated. Do not edit it by hand,
 since your modifications would be overwritten.
 """
 
+from keras_rs.src.losses.list_mle_loss import ListMLELoss as ListMLELoss
 from keras_rs.src.losses.pairwise_hinge_loss import (
     PairwiseHingeLoss as PairwiseHingeLoss,
 )
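For orientation: this new export corresponds to the added `keras_rs/src/losses/list_mle_loss.py` file (+212 lines), a listwise ranking loss. A minimal usage sketch, assuming only the standard Keras loss call interface; the constructor arguments are not visible in this diff:

```python
import keras
import keras_rs

# Model scores and graded relevance labels for two lists of four items each.
y_pred = keras.ops.array([[0.2, 1.5, 0.3, 0.9], [1.1, 0.4, 2.0, 0.7]])
y_true = keras.ops.array([[0.0, 1.0, 0.0, 2.0], [1.0, 0.0, 2.0, 0.0]])

# ListMLE is a listwise loss: it maximizes the likelihood of the item
# permutation induced by the labels.
loss_fn = keras_rs.losses.ListMLELoss()
loss = loss_fn(y_true, y_pred)
```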
--- keras_rs_nightly-0.2.2.dev202507030337/keras_rs/src/layers/embedding/base_distributed_embedding.py
+++ keras_rs_nightly-0.3.1.dev202511120334/keras_rs/src/layers/embedding/base_distributed_embedding.py
@@ -1,4 +1,5 @@
 import collections
+import dataclasses
 import importlib.util
 import typing
 from typing import Any, Sequence
@@ -20,9 +21,10 @@ EmbedReduce = embed_reduce.EmbedReduce
 SUPPORTED_PLACEMENTS = ("auto", "default_device", "sparsecore")
 
 
-
-
-
+@dataclasses.dataclass(eq=True, unsafe_hash=True, order=True)
+class PlacementAndPath:
+    placement: str
+    path: str
 
 
 def _ragged_to_dense_inputs(
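The decorator arguments are load-bearing for the refactor that follows: `eq=True` gives field-wise equality, `unsafe_hash=True` restores hashability (which defining `__eq__` would otherwise remove), and `order=True` makes instances sortable for deterministic traversal. A standalone illustration of those semantics, not package code:

```python
import dataclasses

@dataclasses.dataclass(eq=True, unsafe_hash=True, order=True)
class PlacementAndPath:
    placement: str
    path: str

a = PlacementAndPath("sparsecore", "feature1")
b = PlacementAndPath("sparsecore", "feature1")

assert a == b              # eq=True: field-wise equality
assert hash(a) == hash(b)  # unsafe_hash=True: usable as a dict key
assert {a: 1}[b] == 1      # equal instances index the same entry
assert not (a < b)         # order=True: tuple-like lexicographic ordering
```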
@@ -146,14 +148,14 @@ class DistributedEmbedding(keras.layers.Layer):
     feature1 = keras_rs.layers.FeatureConfig(
         name="feature1",
         table=table1,
-        input_shape=(
-        output_shape=(
+        input_shape=(GLOBAL_BATCH_SIZE,),
+        output_shape=(GLOBAL_BATCH_SIZE, TABLE1_EMBEDDING_SIZE),
     )
     feature2 = keras_rs.layers.FeatureConfig(
         name="feature2",
         table=table2,
-        input_shape=(
-        output_shape=(
+        input_shape=(GLOBAL_BATCH_SIZE,),
+        output_shape=(GLOBAL_BATCH_SIZE, TABLE2_EMBEDDING_SIZE),
     )
 
     feature_configs = {
@@ -337,18 +339,33 @@ class DistributedEmbedding(keras.layers.Layer):
     embedding_layer = DistributedEmbedding(feature_configs)
 
     # Add preprocessing to a data input pipeline.
-    def
-    for (inputs, weights), labels in iter(
+    def preprocessed_dataset_generator(dataset):
+        for (inputs, weights), labels in iter(dataset):
             yield embedding_layer.preprocess(
                 inputs, weights, training=True
             ), labels
 
-    preprocessed_train_dataset =
+    preprocessed_train_dataset = preprocessed_dataset_generator(train_dataset)
     ```
 
     This explicit preprocessing stage combines the input and optional weights,
     so the new data can be passed directly into the `inputs` argument of the
     layer or model.
 
+    **NOTE**: When working in a multi-host setting with data parallelism, the
+    data needs to be sharded properly across hosts. If the original dataset is
+    of type `tf.data.Dataset`, it will need to be manually sharded _prior_ to
+    applying the preprocess generator:
+    ```python
+    # Manually shard the dataset across hosts.
+    train_dataset = distribution.distribute_dataset(train_dataset)
+    distribution.auto_shard_dataset = False  # Dataset is already sharded.
+
+    # Add a preprocessing stage to the distributed data input pipeline.
+    train_dataset = preprocessed_dataset_generator(train_dataset)
+    ```
+    If the original dataset is _not_ a `tf.data.Dataset`, it must already be
+    pre-sharded across hosts.
+
     #### Usage in a Keras model
 
     Once the global distribution is set and the input preprocessing pipeline
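Since the generator in the updated docstring yields `(preprocessed_inputs, labels)` tuples, the consuming side can pass it straight to training. A hedged sketch of that side (`model`, `STEPS_PER_EPOCH`, and `NUM_EPOCHS` are illustrative names, not from this diff):

```python
# `model` is assumed to be a compiled Keras model whose first layer is the
# DistributedEmbedding built from `feature_configs`.
model.fit(
    preprocessed_dataset_generator(train_dataset),
    # A plain Python generator has no length; pass explicit steps if the
    # generator is infinite.
    steps_per_epoch=STEPS_PER_EPOCH,
    epochs=NUM_EPOCHS,
)
```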
@@ -503,12 +520,12 @@ class DistributedEmbedding(keras.layers.Layer):
         With these structures in place, the steps to:
         - go from the deeply nested structure to the two-level structure are:
           - `assert_same_struct` as `self._feature_configs`
-          - `
-
+          - use `self._feature_deeply_nested_placement_and_paths` to map from
+            deeply nested to two-level
         - go from the two-level structure to the deeply nested structure:
-
-
-
+          - `assert_same_struct` as `self._placement_to_path_to_feature_config`
+          - use `self._feature_deeply_nested_placement_and_paths` to locate each
+            output in the two-level dicts
 
         Args:
             feature_configs: The deeply nested structure of `FeatureConfig` or
@@ -575,14 +592,14 @@ class DistributedEmbedding(keras.layers.Layer):
         ] = collections.defaultdict(dict)
 
         def populate_placement_to_path_to_input_shape(
-
+            pp: PlacementAndPath, input_shape: types.Shape
         ) -> None:
-            placement_to_path_to_input_shape[
-
-
+            placement_to_path_to_input_shape[pp.placement][pp.path] = (
+                input_shape
+            )
 
         keras.tree.map_structure_up_to(
-            self.
+            self._feature_deeply_nested_placement_and_paths,
             populate_placement_to_path_to_input_shape,
             self._feature_deeply_nested_placement_and_paths,
             input_shapes,
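The call above relies on `keras.tree.map_structure_up_to`, which recurses only to the depth of its first (shallow) argument, so each whole per-feature input shape reaches the callback instead of being flattened into individual dimensions. A standalone sketch of that behavior:

```python
import keras

# The shallow structure bounds the recursion: values at its leaf positions
# are passed to the function whole rather than traversed further.
shallow = {"a": None, "b": None}
deep = {"a": [1, 2], "b": [3, 4]}

# Sums each sub-list instead of visiting the individual integers.
result = keras.tree.map_structure_up_to(shallow, sum, deep)
assert result == {"a": 3, "b": 7}
```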
@@ -630,35 +647,40 @@ class DistributedEmbedding(keras.layers.Layer):
         """
        # Verify input structure.
         keras.tree.assert_same_structure(self._feature_configs, inputs)
+        if weights is not None:
+            keras.tree.assert_same_structure(self._feature_configs, weights)
 
         if not self.built:
-            input_shapes = keras.tree.
-                self._feature_configs,
+            input_shapes = keras.tree.map_structure(
                 lambda array: backend.standardize_shape(array.shape),
                 inputs,
             )
             self.build(input_shapes)
 
-        # Go from deeply nested
-
+        # Go from deeply nested to nested dict placement -> path -> input.
+        def to_placement_to_path(
+            tensors: types.Nested[types.Tensor],
+        ) -> dict[str, dict[str, types.Tensor]]:
+            result: dict[str, dict[str, types.Tensor]] = {
+                p: dict() for p in self._placement_to_path_to_feature_config
+            }
 
-
-
-            self._placement_to_path_to_feature_config, flat_inputs
-        )
+            def populate(pp: PlacementAndPath, x: types.Tensor) -> None:
+                result[pp.placement][pp.path] = x
 
-
-
-
-
-        placement_to_path_to_weights = keras.tree.pack_sequence_as(
-            self._placement_to_path_to_feature_config, flat_weights
+            keras.tree.map_structure(
+                populate,
+                self._feature_deeply_nested_placement_and_paths,
+                tensors,
             )
-
-
-
-
-
+            return result
+
+        placement_to_path_to_inputs = to_placement_to_path(inputs)
+
+        # Same for weights if present.
+        placement_to_path_to_weights = (
+            to_placement_to_path(weights) if weights is not None else None
+        )
 
         placement_to_path_to_preprocessed: dict[
             str, dict[str, dict[str, types.Nested[types.Tensor]]]
@@ -669,7 +691,9 @@ class DistributedEmbedding(keras.layers.Layer):
             placement_to_path_to_preprocessed["sparsecore"] = (
                 self._sparsecore_preprocess(
                     placement_to_path_to_inputs["sparsecore"],
-                    placement_to_path_to_weights["sparsecore"]
+                    placement_to_path_to_weights["sparsecore"]
+                    if placement_to_path_to_weights is not None
+                    else None,
                     training,
                 )
             )
@@ -679,7 +703,9 @@ class DistributedEmbedding(keras.layers.Layer):
             placement_to_path_to_preprocessed["default_device"] = (
                 self._default_device_preprocess(
                     placement_to_path_to_inputs["default_device"],
-                    placement_to_path_to_weights["default_device"]
+                    placement_to_path_to_weights["default_device"]
+                    if placement_to_path_to_weights is not None
+                    else None,
                     training,
                 )
             )
@@ -765,11 +791,13 @@ class DistributedEmbedding(keras.layers.Layer):
             placement_to_path_to_outputs,
         )
 
-        # Go from placement -> path -> output to
-
+        # Go from placement -> path -> output to deeply nested structure.
+        def populate_output(pp: PlacementAndPath) -> types.Tensor:
+            return placement_to_path_to_outputs[pp.placement][pp.path]
 
-
-
+        return keras.tree.map_structure(
+            populate_output, self._feature_deeply_nested_placement_and_paths
+        )
 
     def get_embedding_tables(self) -> dict[str, types.Tensor]:
         """Return the content of the embedding tables by table name.
@@ -794,13 +822,13 @@ class DistributedEmbedding(keras.layers.Layer):
         table_stacking: str | Sequence[Sequence[str]],
     ) -> None:
         del table_stacking
-
+        table_config_id_to_embedding_layer: dict[int, EmbedReduce] = {}
         self._default_device_embedding_layers: dict[str, EmbedReduce] = {}
 
         for path, feature_config in feature_configs.items():
-            if feature_config.table in
+            if id(feature_config.table) in table_config_id_to_embedding_layer:
                 self._default_device_embedding_layers[path] = (
-
+                    table_config_id_to_embedding_layer[id(feature_config.table)]
                 )
             else:
                 embedding_layer = EmbedReduce(
@@ -810,7 +838,9 @@ class DistributedEmbedding(keras.layers.Layer):
                     embeddings_initializer=feature_config.table.initializer,
                     combiner=feature_config.table.combiner,
                 )
-
+                table_config_id_to_embedding_layer[id(feature_config.table)] = (
+                    embedding_layer
+                )
                 self._default_device_embedding_layers[path] = embedding_layer
 
     def _default_device_build(
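The move from `feature_config.table` keys to `id(feature_config.table)` keys follows from the dataclass change in `distributed_embedding_config.py` below: `order=True` implies `eq=True`, defining `__eq__` sets `__hash__` to `None` (so instances are no longer valid dict keys), and two distinct tables with identical fields would compare equal in any case. A standalone sketch of the failure mode and the fix, with illustrative field names:

```python
import dataclasses

@dataclasses.dataclass(order=True)
class TableConfig:
    name: str
    vocabulary_size: int

t1 = TableConfig("movies", 1000)
t2 = TableConfig("movies", 1000)  # a distinct table with equal field values

# Unhashable: eq=True without frozen/unsafe_hash sets __hash__ to None.
try:
    {t1: "layer"}
except TypeError:
    pass

# Keying on id() tracks object identity, so distinct-but-equal tables still
# get their own embedding layers; sharing requires the same object.
layers_by_id = {id(t1): "shared_layer"}
assert id(t2) not in layers_by_id
```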
@@ -985,8 +1015,8 @@ class DistributedEmbedding(keras.layers.Layer):
 
         # The serialized `TableConfig` objects.
         table_config_dicts: list[dict[str, Any]] = []
-        # Mapping from `TableConfig` to index in `table_config_dicts`.
-
+        # Mapping from `TableConfig` id to index in `table_config_dicts`.
+        table_config_id_to_index: dict[int, int] = {}
 
         def serialize_feature_config(
             feature_config: FeatureConfig,
@@ -996,17 +1026,17 @@ class DistributedEmbedding(keras.layers.Layer):
             # key.
             feature_config_dict = feature_config.get_config()
 
-            if feature_config.table not in
+            if id(feature_config.table) not in table_config_id_to_index:
                 # Save the serialized `TableConfig` the first time we see it and
                 # remember its index.
-
+                table_config_id_to_index[id(feature_config.table)] = len(
                     table_config_dicts
                 )
                 table_config_dicts.append(feature_config_dict["table"])
 
             # Replace the serialized `TableConfig` with its index.
-            feature_config_dict["table"] =
-                feature_config.table
+            feature_config_dict["table"] = table_config_id_to_index[
+                id(feature_config.table)
             ]
             return feature_config_dict
 
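This serialization side stores each distinct `TableConfig` once and has features reference it by list index; the deserialization side is not shown in this diff, but the scheme implies rebuilding one shared object per index. An illustrative sketch of that round trip, not package code:

```python
# Each distinct table dict is stored once; features reference it by index.
table_config_dicts = [{"name": "movies", "vocabulary_size": 1000}]
feature_config_dicts = {
    "feature1": {"name": "feature1", "table": 0},
    "feature2": {"name": "feature2", "table": 0},  # same table, same index
}

# Rebuild: one object per serialized table, shared by every referencing
# feature, which preserves table sharing across a save/load cycle.
tables = [dict(d) for d in table_config_dicts]
features = {
    path: {**cfg, "table": tables[cfg["table"]]}
    for path, cfg in feature_config_dicts.items()
}
assert features["feature1"]["table"] is features["feature2"]["table"]
```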
--- keras_rs_nightly-0.2.2.dev202507030337/keras_rs/src/layers/embedding/distributed_embedding_config.py
+++ keras_rs_nightly-0.3.1.dev202511120334/keras_rs/src/layers/embedding/distributed_embedding_config.py
@@ -10,7 +10,7 @@ from keras_rs.src.api_export import keras_rs_export
 
 
 @keras_rs_export("keras_rs.layers.TableConfig")
-@dataclasses.dataclass(
+@dataclasses.dataclass(order=True)
 class TableConfig:
     """Configuration for one embedding table.
 
@@ -88,7 +88,7 @@ class TableConfig:
 
 
 @keras_rs_export("keras_rs.layers.FeatureConfig")
-@dataclasses.dataclass(
+@dataclasses.dataclass(order=True)
 class FeatureConfig:
     """Configuration for one embedding feature.
 
@@ -102,7 +102,10 @@ class FeatureConfig:
         input_shape: The input shape of the feature. The feature fed into the
             layer has to match the shape. Note that for ragged dimensions in the
             input, the dimension provided here presents the maximum value;
-            anything larger will be truncated.
+            anything larger will be truncated. Also note that the first
+            dimension represents the global batch size. For example, on TPU,
+            this represents the total number of samples that are dispatched to
+            all the TPUs connected to the current host.
         output_shape: The output shape of the feature activation. What is
             returned by the embedding layer has to match this shape.
     """
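The clarified `input_shape` docstring makes the first dimension the global batch size rather than the per-replica batch. A hedged sketch of the arithmetic on a JAX backend, with illustrative names (`PER_REPLICA_BATCH_SIZE` is not from this diff):

```python
import jax

PER_REPLICA_BATCH_SIZE = 16

# Per the docstring, on TPU the leading dimension covers all samples
# dispatched to the TPUs attached to the current host.
GLOBAL_BATCH_SIZE = PER_REPLICA_BATCH_SIZE * jax.local_device_count()

# Matching shapes in the FeatureConfig example earlier in this diff:
#   input_shape=(GLOBAL_BATCH_SIZE,)
#   output_shape=(GLOBAL_BATCH_SIZE, EMBEDDING_SIZE)
```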