keras-rs-nightly 0.3.1.dev202512130338__py3-none-any.whl → 0.4.1.dev202601250348__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -457,6 +457,10 @@ class DistributedEmbedding(keras.layers.Layer):
457
457
  tables in the inner lists together. Note that table stacking is not
458
458
  supported on older TPUs, in which case the default value of `"auto"`
459
459
  will be interpreted as no table stacking.
460
+ update_stats: If True, `'max_ids_per_partition'`,
461
+ `'max_unique_ids_per_partition'` and
462
+ `'suggested_coo_buffer_size_per_device'` are updated during
463
+ training. This argument can be set to True only for the JAX backend.
460
464
  **kwargs: Additional arguments to pass to the layer base class.
461
465
  """
462
466
 
@@ -467,6 +471,7 @@ class DistributedEmbedding(keras.layers.Layer):
467
471
  table_stacking: (
468
472
  str | Sequence[str] | Sequence[Sequence[str]]
469
473
  ) = "auto",
474
+ update_stats: bool = False,
470
475
  **kwargs: Any,
471
476
  ) -> None:
472
477
  super().__init__(**kwargs)
@@ -486,6 +491,8 @@ class DistributedEmbedding(keras.layers.Layer):
486
491
  table_stacking,
487
492
  )
488
493
 
494
+ self.update_stats = update_stats
495
+
489
496
  @keras_utils.no_automatic_dependency_tracking
490
497
  def _init_feature_configs_structures(
491
498
  self,
@@ -407,7 +407,9 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
407
407
  if isinstance(table_stacking, str):
408
408
  if table_stacking == "auto":
409
409
  jte_table_stacking.auto_stack_tables(
410
- feature_specs, global_device_count, num_sc_per_device
410
+ feature_specs,
411
+ global_device_count,
412
+ num_sc_per_device,
411
413
  )
412
414
  else:
413
415
  raise ValueError(
@@ -644,7 +646,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
644
646
  num_sc_per_device,
645
647
  )
646
648
 
647
- if training:
649
+ if training and self.update_stats:
648
650
  # Synchronize input statistics across all devices and update the
649
651
  # underlying stacked tables specs in the feature specs.
650
652
 
@@ -35,8 +35,15 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
35
35
  table_stacking: (
36
36
  str | Sequence[str] | Sequence[Sequence[str]]
37
37
  ) = "auto",
38
+ update_stats: bool = False,
38
39
  **kwargs: Any,
39
40
  ) -> None:
41
+ # `update_stats` is supported only on JAX.
42
+ if update_stats:
43
+ raise ValueError(
44
+ "`update_stats` cannot be True for the TensorFlow backend."
45
+ )
46
+
40
47
  # Intercept arguments that are supported only on TensorFlow.
41
48
  self._optimizer = kwargs.pop("optimizer", None)
42
49
  self._pipeline_execution_with_tensor_core = kwargs.pop(
keras_rs/src/version.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from keras_rs.src.api_export import keras_rs_export
2
2
 
3
3
  # Unique source of truth for the version number.
4
- __version__ = "0.3.1.dev202512130338"
4
+ __version__ = "0.4.1.dev202601250348"
5
5
 
6
6
 
7
7
  @keras_rs_export("keras_rs.version")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: keras-rs-nightly
3
- Version: 0.3.1.dev202512130338
3
+ Version: 0.4.1.dev202601250348
4
4
  Summary: Multi-backend recommender systems with Keras 3.
5
5
  Author-email: Keras team <keras-users@googlegroups.com>
6
6
  License: Apache License 2.0
@@ -5,22 +5,22 @@ keras_rs/metrics/__init__.py,sha256=Qxpf6OFooIL9TIn2l3WgOea3HFRG0hq02glPAxtMZ9c,
5
5
  keras_rs/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  keras_rs/src/api_export.py,sha256=RsmG-DvO-cdFeAF9W6LRzms0kvtm-Yp9BAA_d-952zI,510
7
7
  keras_rs/src/types.py,sha256=1A-oLRdX1-f2DsVZBcNl8qNsaH8pM-gnleLT9FWZWBw,1189
8
- keras_rs/src/version.py,sha256=E7EoQd2LJhzxrYtMh6GLQgtbm2hzCwUNr9AcZw0MOgc,224
8
+ keras_rs/src/version.py,sha256=7F19b6JBtXTYKOxUk6K2Y_nS6JGidlVq305CdH3935o,224
9
9
  keras_rs/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
10
  keras_rs/src/layers/embedding/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- keras_rs/src/layers/embedding/base_distributed_embedding.py,sha256=RkXZ6notj3Cq6ryR9w30Wb8UlaWjLcUK2Os9ZUQvuhY,45568
11
+ keras_rs/src/layers/embedding/base_distributed_embedding.py,sha256=REpNKTvkS4eoAi9t1DohnDMfWwMgfxDN4ByWji3aALM,45906
12
12
  keras_rs/src/layers/embedding/distributed_embedding.py,sha256=94jxUHoGK3Gs9yfV0KxFTuqPo7XFnhgCNlO2FEeiSgM,1072
13
13
  keras_rs/src/layers/embedding/distributed_embedding_config.py,sha256=L41x6W1xcXI-3m94nOh_OsHn6OYpoynakKJvNboJnvE,5762
14
14
  keras_rs/src/layers/embedding/embed_reduce.py,sha256=c-MnEw1-KWs0jTf0JJ_ZBOY-9hRkiFyu989Dof3DnS8,12343
15
15
  keras_rs/src/layers/embedding/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  keras_rs/src/layers/embedding/jax/checkpoint_utils.py,sha256=wZ4I5WZVNg5WnrD2j7nhAXgLzDc7xMrUEkSAOx5Sz5c,3495
17
17
  keras_rs/src/layers/embedding/jax/config_conversion.py,sha256=Di1UzRwLgGHd7RuWYJMj2mCOr1u9MseFEWaYKnwD9Bs,16742
18
- keras_rs/src/layers/embedding/jax/distributed_embedding.py,sha256=lY7cwJgVUDpYUH-n7AcA1qlnSzhtckNtD3UqtpAYXz8,32267
18
+ keras_rs/src/layers/embedding/jax/distributed_embedding.py,sha256=D2tJehcy10w30KzYKnnV30WwCUvgLUm2YQM3Twwge9M,32338
19
19
  keras_rs/src/layers/embedding/jax/embedding_lookup.py,sha256=a90tWTbU9tkFdESG3xir9PTtcvb1cmYR8vl5dDK9PSY,8703
20
20
  keras_rs/src/layers/embedding/jax/embedding_utils.py,sha256=5rQGli4Qflg0BU2-j_-4xbBxSqopqbtjkY2KKYWq64Y,7329
21
21
  keras_rs/src/layers/embedding/tensorflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
22
  keras_rs/src/layers/embedding/tensorflow/config_conversion.py,sha256=HpuDthRQQ3X0EO8dW6OAdcgTODkujZlx_swgreVwXyk,13220
23
- keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py,sha256=TBPYV8gP3ZnAFEwtxmWr_Rp3s-Cj0RrKzF6UOALJ4B0,17817
23
+ keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py,sha256=rkxnPzMHmq82FEzrLrO13NhDHPiX-3PxRM3AUE6Rv10,18050
24
24
  keras_rs/src/layers/feature_interaction/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
25
  keras_rs/src/layers/feature_interaction/dot_interaction.py,sha256=Rs8xIHXNWQNiwjp_xzvQRmTSV1AyhJjDgVc3K5pTmrQ,8530
26
26
  keras_rs/src/layers/feature_interaction/feature_cross.py,sha256=Wq_eQvO0WTRlep69QbKi8TwY8bnFoF9vreP_j6ZHNFE,8666
@@ -52,7 +52,7 @@ keras_rs/src/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuF
52
52
  keras_rs/src/utils/doc_string_utils.py,sha256=CmqomepmaYcvpACpXEXkrJb8DMnvIgmYK-lJ53lYarY,1675
53
53
  keras_rs/src/utils/keras_utils.py,sha256=dc-NFzs3a-qmRw0vBDiMslPLfrm9yymGduLWesXPhuY,2123
54
54
  keras_rs/src/utils/tpu_test_utils.py,sha256=mQVBrI-CCBbXwQxBq1yDKGMwYhm4g4k3_AaSy_sCs0U,4028
55
- keras_rs_nightly-0.3.1.dev202512130338.dist-info/METADATA,sha256=keexSmADAe0Wq2qD1bSC-g8JlUV17XrLU-zldDZ3ozM,5324
56
- keras_rs_nightly-0.3.1.dev202512130338.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
57
- keras_rs_nightly-0.3.1.dev202512130338.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
58
- keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD,,
55
+ keras_rs_nightly-0.4.1.dev202601250348.dist-info/METADATA,sha256=APNEvzS76AMD7Km5fXAWMnmc7VJspHKopOCOAd2xM1s,5324
56
+ keras_rs_nightly-0.4.1.dev202601250348.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
57
+ keras_rs_nightly-0.4.1.dev202601250348.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
58
+ keras_rs_nightly-0.4.1.dev202601250348.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.9.0)
2
+ Generator: setuptools (80.10.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5