keras-rs-nightly 0.2.2.dev202508050345__tar.gz → 0.2.2.dev202508070344__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of keras-rs-nightly might be problematic; see the package's registry page for more details.

Files changed (61)
  1. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/PKG-INFO +1 -1
  2. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/embedding/base_distributed_embedding.py +4 -4
  3. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/embedding/distributed_embedding_config.py +4 -1
  4. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/embedding/tensorflow/config_conversion.py +37 -4
  5. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +3 -1
  6. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/version.py +1 -1
  7. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs_nightly.egg-info/PKG-INFO +1 -1
  8. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/README.md +0 -0
  9. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/api/__init__.py +0 -0
  10. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/api/layers/__init__.py +0 -0
  11. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/api/losses/__init__.py +0 -0
  12. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/api/metrics/__init__.py +0 -0
  13. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/__init__.py +0 -0
  14. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/api_export.py +0 -0
  15. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/__init__.py +0 -0
  16. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/embedding/__init__.py +0 -0
  17. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/embedding/distributed_embedding.py +0 -0
  18. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/embedding/embed_reduce.py +0 -0
  19. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/embedding/jax/__init__.py +0 -0
  20. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/embedding/jax/checkpoint_utils.py +0 -0
  21. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/embedding/jax/config_conversion.py +0 -0
  22. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/embedding/jax/distributed_embedding.py +0 -0
  23. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/embedding/jax/embedding_lookup.py +0 -0
  24. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/embedding/jax/embedding_utils.py +0 -0
  25. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
  26. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/feature_interaction/__init__.py +0 -0
  27. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/feature_interaction/dot_interaction.py +0 -0
  28. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/feature_interaction/feature_cross.py +0 -0
  29. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/retrieval/__init__.py +0 -0
  30. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/retrieval/brute_force_retrieval.py +0 -0
  31. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/retrieval/hard_negative_mining.py +0 -0
  32. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/retrieval/remove_accidental_hits.py +0 -0
  33. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/retrieval/retrieval.py +0 -0
  34. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/layers/retrieval/sampling_probability_correction.py +0 -0
  35. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/losses/__init__.py +0 -0
  36. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/losses/pairwise_hinge_loss.py +0 -0
  37. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/losses/pairwise_logistic_loss.py +0 -0
  38. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/losses/pairwise_loss.py +0 -0
  39. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/losses/pairwise_loss_utils.py +0 -0
  40. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/losses/pairwise_mean_squared_error.py +0 -0
  41. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/losses/pairwise_soft_zero_one_loss.py +0 -0
  42. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/metrics/__init__.py +0 -0
  43. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/metrics/dcg.py +0 -0
  44. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/metrics/mean_average_precision.py +0 -0
  45. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/metrics/mean_reciprocal_rank.py +0 -0
  46. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/metrics/ndcg.py +0 -0
  47. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/metrics/precision_at_k.py +0 -0
  48. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/metrics/ranking_metric.py +0 -0
  49. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/metrics/ranking_metrics_utils.py +0 -0
  50. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/metrics/recall_at_k.py +0 -0
  51. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/metrics/utils.py +0 -0
  52. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/types.py +0 -0
  53. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/utils/__init__.py +0 -0
  54. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/utils/doc_string_utils.py +0 -0
  55. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs/src/utils/keras_utils.py +0 -0
  56. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs_nightly.egg-info/SOURCES.txt +0 -0
  57. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs_nightly.egg-info/dependency_links.txt +0 -0
  58. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs_nightly.egg-info/requires.txt +0 -0
  59. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/keras_rs_nightly.egg-info/top_level.txt +0 -0
  60. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/pyproject.toml +0 -0
  61. {keras_rs_nightly-0.2.2.dev202508050345 → keras_rs_nightly-0.2.2.dev202508070344}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: keras-rs-nightly
3
- Version: 0.2.2.dev202508050345
3
+ Version: 0.2.2.dev202508070344
4
4
  Summary: Multi-backend recommender systems with Keras 3.
5
5
  Author-email: Keras team <keras-users@googlegroups.com>
6
6
  License: Apache License 2.0
@@ -146,14 +146,14 @@ class DistributedEmbedding(keras.layers.Layer):
146
146
  feature1 = keras_rs.layers.FeatureConfig(
147
147
  name="feature1",
148
148
  table=table1,
149
- input_shape=(PER_REPLICA_BATCH_SIZE,),
150
- output_shape=(PER_REPLICA_BATCH_SIZE, TABLE1_EMBEDDING_SIZE),
149
+ input_shape=(GLOBAL_BATCH_SIZE,),
150
+ output_shape=(GLOBAL_BATCH_SIZE, TABLE1_EMBEDDING_SIZE),
151
151
  )
152
152
  feature2 = keras_rs.layers.FeatureConfig(
153
153
  name="feature2",
154
154
  table=table2,
155
- input_shape=(PER_REPLICA_BATCH_SIZE,),
156
- output_shape=(PER_REPLICA_BATCH_SIZE, TABLE2_EMBEDDING_SIZE),
155
+ input_shape=(GLOBAL_BATCH_SIZE,),
156
+ output_shape=(GLOBAL_BATCH_SIZE, TABLE2_EMBEDDING_SIZE),
157
157
  )
158
158
 
159
159
  feature_configs = {
@@ -102,7 +102,10 @@ class FeatureConfig:
102
102
  input_shape: The input shape of the feature. The feature fed into the
103
103
  layer has to match the shape. Note that for ragged dimensions in the
104
104
  input, the dimension provided here presents the maximum value;
105
- anything larger will be truncated.
105
+ anything larger will be truncated. Also note that the first
106
+ dimension represents the global batch size. For example, on TPU,
107
+ this represents the total number of samples that are dispatched to
108
+ all the TPUs connected to the current host.
106
109
  output_shape: The output shape of the feature activation. What is
107
110
  returned by the embedding layer has to match this shape.
108
111
  """
@@ -56,6 +56,7 @@ OPTIMIZER_MAPPINGS = {
56
56
  def translate_keras_rs_configuration(
57
57
  feature_configs: types.Nested[FeatureConfig],
58
58
  table_stacking: str | Sequence[str] | Sequence[Sequence[str]],
59
+ num_replicas_in_sync: int,
59
60
  ) -> tuple[
60
61
  types.Nested[tf.tpu.experimental.embedding.FeatureConfig],
61
62
  tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig,
@@ -72,7 +73,10 @@ def translate_keras_rs_configuration(
72
73
  """
73
74
  tables: dict[TableConfig, tf.tpu.experimental.embedding.TableConfig] = {}
74
75
  feature_configs = keras.tree.map_structure(
75
- lambda f: translate_keras_rs_feature_config(f, tables), feature_configs
76
+ lambda f: translate_keras_rs_feature_config(
77
+ f, tables, num_replicas_in_sync
78
+ ),
79
+ feature_configs,
76
80
  )
77
81
 
78
82
  # max_ids_per_chip_per_sample
@@ -107,6 +111,7 @@ def translate_keras_rs_configuration(
107
111
  def translate_keras_rs_feature_config(
108
112
  feature_config: FeatureConfig,
109
113
  tables: dict[TableConfig, tf.tpu.experimental.embedding.TableConfig],
114
+ num_replicas_in_sync: int,
110
115
  ) -> tf.tpu.experimental.embedding.FeatureConfig:
111
116
  """Translates a Keras RS feature config to a TensorFlow TPU feature config.
112
117
 
@@ -120,18 +125,46 @@ def translate_keras_rs_feature_config(
120
125
  Returns:
121
126
  The TensorFlow TPU feature config.
122
127
  """
128
+ if num_replicas_in_sync <= 0:
129
+ raise ValueError(
130
+ "`num_replicas_in_sync` must be positive, "
131
+ f"but got {num_replicas_in_sync}."
132
+ )
133
+
123
134
  table = tables.get(feature_config.table, None)
124
135
  if table is None:
125
136
  table = translate_keras_rs_table_config(feature_config.table)
126
137
  tables[feature_config.table] = table
127
138
 
139
+ if len(feature_config.output_shape) < 2:
140
+ raise ValueError(
141
+ f"Invalid `output_shape` {feature_config.output_shape} in "
142
+ f"`FeatureConfig` {feature_config}. It must have at least 2 "
143
+ "dimensions: a batch dimension and an embedding dimension."
144
+ )
145
+
146
+ # Exclude last dimension, TensorFlow's TPUEmbedding doesn't want it.
147
+ output_shape = list(feature_config.output_shape[0:-1])
148
+
149
+ batch_size = output_shape[0]
150
+ per_replica_batch_size: int | None = None
151
+ if batch_size is not None:
152
+ if batch_size % num_replicas_in_sync != 0:
153
+ raise ValueError(
154
+ f"Invalid `output_shape` {feature_config.output_shape} in "
155
+ f"`FeatureConfig` {feature_config}. Batch size {batch_size} is "
156
+ f"not a multiple of the number of TPUs {num_replicas_in_sync}."
157
+ )
158
+ per_replica_batch_size = batch_size // num_replicas_in_sync
159
+
160
+ # TensorFlow's TPUEmbedding wants the per replica batch size.
161
+ output_shape = [per_replica_batch_size] + output_shape[1:]
162
+
128
163
  # max_sequence_length
129
164
  return tf.tpu.experimental.embedding.FeatureConfig(
130
165
  name=feature_config.name,
131
166
  table=table,
132
- output_shape=feature_config.output_shape[
133
- 0:-1
134
- ], # exclude last dimension
167
+ output_shape=output_shape,
135
168
  )
136
169
 
137
170
 
@@ -107,7 +107,9 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
107
107
  )
108
108
  self._tpu_feature_configs, self._sparse_core_embedding_config = (
109
109
  config_conversion.translate_keras_rs_configuration(
110
- feature_configs, table_stacking
110
+ feature_configs,
111
+ table_stacking,
112
+ strategy.num_replicas_in_sync,
111
113
  )
112
114
  )
113
115
  if tpu_embedding_feature == EMBEDDING_FEATURE_V1:
@@ -1,7 +1,7 @@
1
1
  from keras_rs.src.api_export import keras_rs_export
2
2
 
3
3
  # Unique source of truth for the version number.
4
- __version__ = "0.2.2.dev202508050345"
4
+ __version__ = "0.2.2.dev202508070344"
5
5
 
6
6
 
7
7
  @keras_rs_export("keras_rs.version")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: keras-rs-nightly
3
- Version: 0.2.2.dev202508050345
3
+ Version: 0.2.2.dev202508070344
4
4
  Summary: Multi-backend recommender systems with Keras 3.
5
5
  Author-email: Keras team <keras-users@googlegroups.com>
6
6
  License: Apache License 2.0