keras-rs-nightly 0.3.1.dev202510280332__py3-none-any.whl → 0.3.1.dev202510300334__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of keras-rs-nightly might be problematic; see the package's registry page for more details.

@@ -1,5 +1,6 @@
1
1
  """JAX implementation of the TPU embedding layer."""
2
2
 
3
+ import dataclasses
3
4
  import math
4
5
  import typing
5
6
  from typing import Any, Mapping, Sequence, Union
@@ -445,6 +446,30 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
445
446
  table_specs = embedding.get_table_specs(feature_specs)
446
447
  table_stacks = jte_table_stacking.get_table_stacks(table_specs)
447
448
 
449
+ # Create new instances of StackTableSpec with updated values that are
450
+ # the maximum from stacked tables.
451
+ stacked_table_specs = embedding.get_stacked_table_specs(feature_specs)
452
+ stacked_table_specs = {
453
+ stack_name: dataclasses.replace(
454
+ stacked_table_spec,
455
+ max_ids_per_partition=max(
456
+ table.max_ids_per_partition
457
+ for table in table_stacks[stack_name]
458
+ ),
459
+ max_unique_ids_per_partition=max(
460
+ table.max_unique_ids_per_partition
461
+ for table in table_stacks[stack_name]
462
+ ),
463
+ )
464
+ for stack_name, stacked_table_spec in stacked_table_specs.items()
465
+ }
466
+
467
+ # Rewrite the stacked_table_spec in all TableSpecs.
468
+ for stack_name, table_specs in table_stacks.items():
469
+ stacked_table_spec = stacked_table_specs[stack_name]
470
+ for table_spec in table_specs:
471
+ table_spec.stacked_table_spec = stacked_table_spec
472
+
448
473
  # Create variables for all stacked tables and slot variables.
449
474
  with sparsecore_distribution.scope():
450
475
  self._table_and_slot_variables = {
keras_rs/src/version.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from keras_rs.src.api_export import keras_rs_export
2
2
 
3
3
  # Unique source of truth for the version number.
4
- __version__ = "0.3.1.dev202510280332"
4
+ __version__ = "0.3.1.dev202510300334"
5
5
 
6
6
 
7
7
  @keras_rs_export("keras_rs.version")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: keras-rs-nightly
3
- Version: 0.3.1.dev202510280332
3
+ Version: 0.3.1.dev202510300334
4
4
  Summary: Multi-backend recommender systems with Keras 3.
5
5
  Author-email: Keras team <keras-users@googlegroups.com>
6
6
  License: Apache License 2.0
@@ -5,7 +5,7 @@ keras_rs/metrics/__init__.py,sha256=Qxpf6OFooIL9TIn2l3WgOea3HFRG0hq02glPAxtMZ9c,
5
5
  keras_rs/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
6
  keras_rs/src/api_export.py,sha256=RsmG-DvO-cdFeAF9W6LRzms0kvtm-Yp9BAA_d-952zI,510
7
7
  keras_rs/src/types.py,sha256=1A-oLRdX1-f2DsVZBcNl8qNsaH8pM-gnleLT9FWZWBw,1189
8
- keras_rs/src/version.py,sha256=LBNXhlFa6P1nQhY9SUb1spImnuIwMrfVQVmJfnUfGGM,224
8
+ keras_rs/src/version.py,sha256=TSXfr6OYMhKNF7BAmoXC-DV8ckEnzJyki8qYNvkM-Rk,224
9
9
  keras_rs/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
10
  keras_rs/src/layers/embedding/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  keras_rs/src/layers/embedding/base_distributed_embedding.py,sha256=RkXZ6notj3Cq6ryR9w30Wb8UlaWjLcUK2Os9ZUQvuhY,45568
@@ -15,7 +15,7 @@ keras_rs/src/layers/embedding/embed_reduce.py,sha256=c-MnEw1-KWs0jTf0JJ_ZBOY-9hR
15
15
  keras_rs/src/layers/embedding/jax/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  keras_rs/src/layers/embedding/jax/checkpoint_utils.py,sha256=wZ4I5WZVNg5WnrD2j7nhAXgLzDc7xMrUEkSAOx5Sz5c,3495
17
17
  keras_rs/src/layers/embedding/jax/config_conversion.py,sha256=Di1UzRwLgGHd7RuWYJMj2mCOr1u9MseFEWaYKnwD9Bs,16742
18
- keras_rs/src/layers/embedding/jax/distributed_embedding.py,sha256=O3G0AFRzukYdXPRyx7ZDqDvNgJrcbFwTCYTHigfdiKw,29628
18
+ keras_rs/src/layers/embedding/jax/distributed_embedding.py,sha256=d0QEb_0JHIs_cqYvmkowfpC1EWuMGsG-nwqggjPiSYY,30710
19
19
  keras_rs/src/layers/embedding/jax/embedding_lookup.py,sha256=8LigXjPr7uQaUOdZM6yoLGoPYdRcbkXkFeL_sJoQ6uQ,8223
20
20
  keras_rs/src/layers/embedding/jax/embedding_utils.py,sha256=slJ0XwkI1z4vTAnRXQwm39LFnK9AL3CODuGRn5BufgE,8292
21
21
  keras_rs/src/layers/embedding/tensorflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -50,7 +50,7 @@ keras_rs/src/metrics/utils.py,sha256=fGTo8j0ykVE5Y3yQCS2orSFcHY20Uxt0NazyPsybUsw
50
50
  keras_rs/src/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
51
51
  keras_rs/src/utils/doc_string_utils.py,sha256=CmqomepmaYcvpACpXEXkrJb8DMnvIgmYK-lJ53lYarY,1675
52
52
  keras_rs/src/utils/keras_utils.py,sha256=dc-NFzs3a-qmRw0vBDiMslPLfrm9yymGduLWesXPhuY,2123
53
- keras_rs_nightly-0.3.1.dev202510280332.dist-info/METADATA,sha256=VVWxJMaJj1ItnuQmyfTAFAAdcBJrQ-sxUAg1SYs6c8Q,5324
54
- keras_rs_nightly-0.3.1.dev202510280332.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
55
- keras_rs_nightly-0.3.1.dev202510280332.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
56
- keras_rs_nightly-0.3.1.dev202510280332.dist-info/RECORD,,
53
+ keras_rs_nightly-0.3.1.dev202510300334.dist-info/METADATA,sha256=6rVE87pwIQxWpXVZ6oNm659wv9MgUJLzX-bpi4mrS3o,5324
54
+ keras_rs_nightly-0.3.1.dev202510300334.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
55
+ keras_rs_nightly-0.3.1.dev202510300334.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
56
+ keras_rs_nightly-0.3.1.dev202510300334.dist-info/RECORD,,