keras-rs-nightly 0.3.1.dev202512130338__py3-none-any.whl → 0.4.1.dev202601250348__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_rs/src/layers/embedding/base_distributed_embedding.py +7 -0
- keras_rs/src/layers/embedding/jax/distributed_embedding.py +4 -2
- keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +7 -0
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.3.1.dev202512130338.dist-info → keras_rs_nightly-0.4.1.dev202601250348.dist-info}/METADATA +1 -1
- {keras_rs_nightly-0.3.1.dev202512130338.dist-info → keras_rs_nightly-0.4.1.dev202601250348.dist-info}/RECORD +8 -8
- {keras_rs_nightly-0.3.1.dev202512130338.dist-info → keras_rs_nightly-0.4.1.dev202601250348.dist-info}/WHEEL +1 -1
- {keras_rs_nightly-0.3.1.dev202512130338.dist-info → keras_rs_nightly-0.4.1.dev202601250348.dist-info}/top_level.txt +0 -0
|
@@ -457,6 +457,10 @@ class DistributedEmbedding(keras.layers.Layer):
|
|
|
457
457
|
tables in the inner lists together. Note that table stacking is not
|
|
458
458
|
supported on older TPUs, in which case the default value of `"auto"`
|
|
459
459
|
will be interpreted as no table stacking.
|
|
460
|
+
update_stats: If True, `'max_ids_per_partition'`,
|
|
461
|
+
`'max_unique_ids_per_partition'` and
|
|
462
|
+
`'suggested_coo_buffer_size_per_device'` are updated during
|
|
463
|
+
training. This argument can be set to True only for the JAX backend.
|
|
460
464
|
**kwargs: Additional arguments to pass to the layer base class.
|
|
461
465
|
"""
|
|
462
466
|
|
|
@@ -467,6 +471,7 @@ class DistributedEmbedding(keras.layers.Layer):
|
|
|
467
471
|
table_stacking: (
|
|
468
472
|
str | Sequence[str] | Sequence[Sequence[str]]
|
|
469
473
|
) = "auto",
|
|
474
|
+
update_stats: bool = False,
|
|
470
475
|
**kwargs: Any,
|
|
471
476
|
) -> None:
|
|
472
477
|
super().__init__(**kwargs)
|
|
@@ -486,6 +491,8 @@ class DistributedEmbedding(keras.layers.Layer):
|
|
|
486
491
|
table_stacking,
|
|
487
492
|
)
|
|
488
493
|
|
|
494
|
+
self.update_stats = update_stats
|
|
495
|
+
|
|
489
496
|
@keras_utils.no_automatic_dependency_tracking
|
|
490
497
|
def _init_feature_configs_structures(
|
|
491
498
|
self,
|
|
@@ -407,7 +407,9 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
|
|
|
407
407
|
if isinstance(table_stacking, str):
|
|
408
408
|
if table_stacking == "auto":
|
|
409
409
|
jte_table_stacking.auto_stack_tables(
|
|
410
|
-
feature_specs,
|
|
410
|
+
feature_specs,
|
|
411
|
+
global_device_count,
|
|
412
|
+
num_sc_per_device,
|
|
411
413
|
)
|
|
412
414
|
else:
|
|
413
415
|
raise ValueError(
|
|
@@ -644,7 +646,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
|
|
|
644
646
|
num_sc_per_device,
|
|
645
647
|
)
|
|
646
648
|
|
|
647
|
-
if training:
|
|
649
|
+
if training and self.update_stats:
|
|
648
650
|
# Synchronize input statistics across all devices and update the
|
|
649
651
|
# underlying stacked tables specs in the feature specs.
|
|
650
652
|
|
|
@@ -35,8 +35,15 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
|
|
|
35
35
|
table_stacking: (
|
|
36
36
|
str | Sequence[str] | Sequence[Sequence[str]]
|
|
37
37
|
) = "auto",
|
|
38
|
+
update_stats: bool = False,
|
|
38
39
|
**kwargs: Any,
|
|
39
40
|
) -> None:
|
|
41
|
+
# `update_stats` is supported only on JAX.
|
|
42
|
+
if update_stats:
|
|
43
|
+
raise ValueError(
|
|
44
|
+
"`update_stats` cannot be True for the TensorFlow backend."
|
|
45
|
+
)
|
|
46
|
+
|
|
40
47
|
# Intercept arguments that are supported only on TensorFlow.
|
|
41
48
|
self._optimizer = kwargs.pop("optimizer", None)
|
|
42
49
|
self._pipeline_execution_with_tensor_core = kwargs.pop(
|
keras_rs/src/version.py
CHANGED
keras_rs_nightly-0.4.1.dev202601250348.dist-info/RECORD
CHANGED
|
@@ -5,22 +5,22 @@ keras_rs/metrics/__init__.py,sha256=Qxpf6OFooIL9TIn2l3WgOea3HFRG0hq02glPAxtMZ9c,
|
|
|
5
5
|
keras_rs/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
6
6
|
keras_rs/src/api_export.py,sha256=RsmG-DvO-cdFeAF9W6LRzms0kvtm-Yp9BAA_d-952zI,510
|
|
7
7
|
keras_rs/src/types.py,sha256=1A-oLRdX1-f2DsVZBcNl8qNsaH8pM-gnleLT9FWZWBw,1189
|
|
8
|
-
keras_rs/src/version.py,sha256=
|
|
8
|
+
keras_rs/src/version.py,sha256=7F19b6JBtXTYKOxUk6K2Y_nS6JGidlVq305CdH3935o,224
|
|
9
9
|
keras_rs/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
10
|
keras_rs/src/layers/embedding/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
|
-
keras_rs/src/layers/embedding/base_distributed_embedding.py,sha256=
|
|
11
|
+
keras_rs/src/layers/embedding/base_distributed_embedding.py,sha256=REpNKTvkS4eoAi9t1DohnDMfWwMgfxDN4ByWji3aALM,45906
|
|
12
12
|
keras_rs/src/layers/embedding/distributed_embedding.py,sha256=94jxUHoGK3Gs9yfV0KxFTuqPo7XFnhgCNlO2FEeiSgM,1072
|
|
13
13
|
keras_rs/src/layers/embedding/distributed_embedding_config.py,sha256=L41x6W1xcXI-3m94nOh_OsHn6OYpoynakKJvNboJnvE,5762
|
|
14
14
|
keras_rs/src/layers/embedding/embed_reduce.py,sha256=c-MnEw1-KWs0jTf0JJ_ZBOY-9hRkiFyu989Dof3DnS8,12343
|
|
15
15
|
keras_rs/src/layers/embedding/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
16
16
|
keras_rs/src/layers/embedding/jax/checkpoint_utils.py,sha256=wZ4I5WZVNg5WnrD2j7nhAXgLzDc7xMrUEkSAOx5Sz5c,3495
|
|
17
17
|
keras_rs/src/layers/embedding/jax/config_conversion.py,sha256=Di1UzRwLgGHd7RuWYJMj2mCOr1u9MseFEWaYKnwD9Bs,16742
|
|
18
|
-
keras_rs/src/layers/embedding/jax/distributed_embedding.py,sha256=
|
|
18
|
+
keras_rs/src/layers/embedding/jax/distributed_embedding.py,sha256=D2tJehcy10w30KzYKnnV30WwCUvgLUm2YQM3Twwge9M,32338
|
|
19
19
|
keras_rs/src/layers/embedding/jax/embedding_lookup.py,sha256=a90tWTbU9tkFdESG3xir9PTtcvb1cmYR8vl5dDK9PSY,8703
|
|
20
20
|
keras_rs/src/layers/embedding/jax/embedding_utils.py,sha256=5rQGli4Qflg0BU2-j_-4xbBxSqopqbtjkY2KKYWq64Y,7329
|
|
21
21
|
keras_rs/src/layers/embedding/tensorflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
22
22
|
keras_rs/src/layers/embedding/tensorflow/config_conversion.py,sha256=HpuDthRQQ3X0EO8dW6OAdcgTODkujZlx_swgreVwXyk,13220
|
|
23
|
-
keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py,sha256=
|
|
23
|
+
keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py,sha256=rkxnPzMHmq82FEzrLrO13NhDHPiX-3PxRM3AUE6Rv10,18050
|
|
24
24
|
keras_rs/src/layers/feature_interaction/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
25
25
|
keras_rs/src/layers/feature_interaction/dot_interaction.py,sha256=Rs8xIHXNWQNiwjp_xzvQRmTSV1AyhJjDgVc3K5pTmrQ,8530
|
|
26
26
|
keras_rs/src/layers/feature_interaction/feature_cross.py,sha256=Wq_eQvO0WTRlep69QbKi8TwY8bnFoF9vreP_j6ZHNFE,8666
|
|
@@ -52,7 +52,7 @@ keras_rs/src/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuF
|
|
|
52
52
|
keras_rs/src/utils/doc_string_utils.py,sha256=CmqomepmaYcvpACpXEXkrJb8DMnvIgmYK-lJ53lYarY,1675
|
|
53
53
|
keras_rs/src/utils/keras_utils.py,sha256=dc-NFzs3a-qmRw0vBDiMslPLfrm9yymGduLWesXPhuY,2123
|
|
54
54
|
keras_rs/src/utils/tpu_test_utils.py,sha256=mQVBrI-CCBbXwQxBq1yDKGMwYhm4g4k3_AaSy_sCs0U,4028
|
|
55
|
-
keras_rs_nightly-0.
|
|
56
|
-
keras_rs_nightly-0.
|
|
57
|
-
keras_rs_nightly-0.
|
|
58
|
-
keras_rs_nightly-0.
|
|
55
|
+
keras_rs_nightly-0.4.1.dev202601250348.dist-info/METADATA,sha256=APNEvzS76AMD7Km5fXAWMnmc7VJspHKopOCOAd2xM1s,5324
|
|
56
|
+
keras_rs_nightly-0.4.1.dev202601250348.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
|
|
57
|
+
keras_rs_nightly-0.4.1.dev202601250348.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
|
|
58
|
+
keras_rs_nightly-0.4.1.dev202601250348.dist-info/RECORD,,
|
|
File without changes
|