keras-rs-nightly 0.2.2.dev202508060343__py3-none-any.whl → 0.2.2.dev202508070344__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of keras-rs-nightly has been flagged as possibly problematic; see the package's registry page for more details.
- keras_rs/src/layers/embedding/tensorflow/config_conversion.py +37 -4
- keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +3 -1
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.2.2.dev202508060343.dist-info → keras_rs_nightly-0.2.2.dev202508070344.dist-info}/METADATA +1 -1
- {keras_rs_nightly-0.2.2.dev202508060343.dist-info → keras_rs_nightly-0.2.2.dev202508070344.dist-info}/RECORD +7 -7
- {keras_rs_nightly-0.2.2.dev202508060343.dist-info → keras_rs_nightly-0.2.2.dev202508070344.dist-info}/WHEEL +0 -0
- {keras_rs_nightly-0.2.2.dev202508060343.dist-info → keras_rs_nightly-0.2.2.dev202508070344.dist-info}/top_level.txt +0 -0
|
keras_rs/src/layers/embedding/tensorflow/config_conversion.py
CHANGED
|
@@ -56,6 +56,7 @@ OPTIMIZER_MAPPINGS = {
|
|
|
56
56
|
def translate_keras_rs_configuration(
|
|
57
57
|
feature_configs: types.Nested[FeatureConfig],
|
|
58
58
|
table_stacking: str | Sequence[str] | Sequence[Sequence[str]],
|
|
59
|
+
num_replicas_in_sync: int,
|
|
59
60
|
) -> tuple[
|
|
60
61
|
types.Nested[tf.tpu.experimental.embedding.FeatureConfig],
|
|
61
62
|
tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig,
|
|
@@ -72,7 +73,10 @@ def translate_keras_rs_configuration(
|
|
|
72
73
|
"""
|
|
73
74
|
tables: dict[TableConfig, tf.tpu.experimental.embedding.TableConfig] = {}
|
|
74
75
|
feature_configs = keras.tree.map_structure(
|
|
75
|
-
lambda f: translate_keras_rs_feature_config(
|
|
76
|
+
lambda f: translate_keras_rs_feature_config(
|
|
77
|
+
f, tables, num_replicas_in_sync
|
|
78
|
+
),
|
|
79
|
+
feature_configs,
|
|
76
80
|
)
|
|
77
81
|
|
|
78
82
|
# max_ids_per_chip_per_sample
|
|
@@ -107,6 +111,7 @@ def translate_keras_rs_configuration(
|
|
|
107
111
|
def translate_keras_rs_feature_config(
|
|
108
112
|
feature_config: FeatureConfig,
|
|
109
113
|
tables: dict[TableConfig, tf.tpu.experimental.embedding.TableConfig],
|
|
114
|
+
num_replicas_in_sync: int,
|
|
110
115
|
) -> tf.tpu.experimental.embedding.FeatureConfig:
|
|
111
116
|
"""Translates a Keras RS feature config to a TensorFlow TPU feature config.
|
|
112
117
|
|
|
@@ -120,18 +125,46 @@ def translate_keras_rs_feature_config(
|
|
|
120
125
|
Returns:
|
|
121
126
|
The TensorFlow TPU feature config.
|
|
122
127
|
"""
|
|
128
|
+
if num_replicas_in_sync <= 0:
|
|
129
|
+
raise ValueError(
|
|
130
|
+
"`num_replicas_in_sync` must be positive, "
|
|
131
|
+
f"but got {num_replicas_in_sync}."
|
|
132
|
+
)
|
|
133
|
+
|
|
123
134
|
table = tables.get(feature_config.table, None)
|
|
124
135
|
if table is None:
|
|
125
136
|
table = translate_keras_rs_table_config(feature_config.table)
|
|
126
137
|
tables[feature_config.table] = table
|
|
127
138
|
|
|
139
|
+
if len(feature_config.output_shape) < 2:
|
|
140
|
+
raise ValueError(
|
|
141
|
+
f"Invalid `output_shape` {feature_config.output_shape} in "
|
|
142
|
+
f"`FeatureConfig` {feature_config}. It must have at least 2 "
|
|
143
|
+
"dimensions: a batch dimension and an embedding dimension."
|
|
144
|
+
)
|
|
145
|
+
|
|
146
|
+
# Exclude last dimension, TensorFlow's TPUEmbedding doesn't want it.
|
|
147
|
+
output_shape = list(feature_config.output_shape[0:-1])
|
|
148
|
+
|
|
149
|
+
batch_size = output_shape[0]
|
|
150
|
+
per_replica_batch_size: int | None = None
|
|
151
|
+
if batch_size is not None:
|
|
152
|
+
if batch_size % num_replicas_in_sync != 0:
|
|
153
|
+
raise ValueError(
|
|
154
|
+
f"Invalid `output_shape` {feature_config.output_shape} in "
|
|
155
|
+
f"`FeatureConfig` {feature_config}. Batch size {batch_size} is "
|
|
156
|
+
f"not a multiple of the number of TPUs {num_replicas_in_sync}."
|
|
157
|
+
)
|
|
158
|
+
per_replica_batch_size = batch_size // num_replicas_in_sync
|
|
159
|
+
|
|
160
|
+
# TensorFlow's TPUEmbedding wants the per replica batch size.
|
|
161
|
+
output_shape = [per_replica_batch_size] + output_shape[1:]
|
|
162
|
+
|
|
128
163
|
# max_sequence_length
|
|
129
164
|
return tf.tpu.experimental.embedding.FeatureConfig(
|
|
130
165
|
name=feature_config.name,
|
|
131
166
|
table=table,
|
|
132
|
-
output_shape=feature_config.output_shape[
|
|
133
|
-
0:-1
|
|
134
|
-
], # exclude last dimension
|
|
167
|
+
output_shape=output_shape,
|
|
135
168
|
)
|
|
136
169
|
|
|
137
170
|
|
|
keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py
CHANGED
|
@@ -107,7 +107,9 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
|
|
|
107
107
|
)
|
|
108
108
|
self._tpu_feature_configs, self._sparse_core_embedding_config = (
|
|
109
109
|
config_conversion.translate_keras_rs_configuration(
|
|
110
|
-
feature_configs,
|
|
110
|
+
feature_configs,
|
|
111
|
+
table_stacking,
|
|
112
|
+
strategy.num_replicas_in_sync,
|
|
111
113
|
)
|
|
112
114
|
)
|
|
113
115
|
if tpu_embedding_feature == EMBEDDING_FEATURE_V1:
|
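keras_rs/src/version.py
CHANGED
(The diff bodies for keras_rs/src/version.py and dist-info/METADATA were not captured; the hunks below belong to dist-info/RECORD.)
keras_rs_nightly-0.2.2.dev202508070344.dist-info/RECORD
CHANGED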
|
@@ -5,7 +5,7 @@ keras_rs/metrics/__init__.py,sha256=Qxpf6OFooIL9TIn2l3WgOea3HFRG0hq02glPAxtMZ9c,
|
|
|
5
5
|
keras_rs/src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
6
6
|
keras_rs/src/api_export.py,sha256=RsmG-DvO-cdFeAF9W6LRzms0kvtm-Yp9BAA_d-952zI,510
|
|
7
7
|
keras_rs/src/types.py,sha256=1A-oLRdX1-f2DsVZBcNl8qNsaH8pM-gnleLT9FWZWBw,1189
|
|
8
|
-
keras_rs/src/version.py,sha256=
|
|
8
|
+
keras_rs/src/version.py,sha256=u9dAyD9wPmdZRGUcU0EBtAFVCmSHpqmIJPDURbAXDoo,224
|
|
9
9
|
keras_rs/src/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
10
|
keras_rs/src/layers/embedding/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
11
|
keras_rs/src/layers/embedding/base_distributed_embedding.py,sha256=J1mW7sgOgfkR8s-n3uyuqMb_iwPgrD6j0J5OBSfXQt8,45006
|
|
@@ -19,8 +19,8 @@ keras_rs/src/layers/embedding/jax/distributed_embedding.py,sha256=jLqEuh_7hHM4ba
|
|
|
19
19
|
keras_rs/src/layers/embedding/jax/embedding_lookup.py,sha256=8LigXjPr7uQaUOdZM6yoLGoPYdRcbkXkFeL_sJoQ6uQ,8223
|
|
20
20
|
keras_rs/src/layers/embedding/jax/embedding_utils.py,sha256=EHrQjPLl94STLWf9g8Ew8nuwupXRq-a_QmvFlXV6G6A,20331
|
|
21
21
|
keras_rs/src/layers/embedding/tensorflow/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
22
|
-
keras_rs/src/layers/embedding/tensorflow/config_conversion.py,sha256=
|
|
23
|
-
keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py,sha256
|
|
22
|
+
keras_rs/src/layers/embedding/tensorflow/config_conversion.py,sha256=6aU7B-SX-WUqH8UC1amXg6BpFdalFmbr7rYQiiH_k4A,12883
|
|
23
|
+
keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py,sha256=-K6l3PpIqOEVCzcVZN_cJMx9ddJnRKBTj486CP3wkps,17278
|
|
24
24
|
keras_rs/src/layers/feature_interaction/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
25
25
|
keras_rs/src/layers/feature_interaction/dot_interaction.py,sha256=Rs8xIHXNWQNiwjp_xzvQRmTSV1AyhJjDgVc3K5pTmrQ,8530
|
|
26
26
|
keras_rs/src/layers/feature_interaction/feature_cross.py,sha256=Wq_eQvO0WTRlep69QbKi8TwY8bnFoF9vreP_j6ZHNFE,8666
|
|
@@ -50,7 +50,7 @@ keras_rs/src/metrics/utils.py,sha256=fGTo8j0ykVE5Y3yQCS2orSFcHY20Uxt0NazyPsybUsw
|
|
|
50
50
|
keras_rs/src/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
51
51
|
keras_rs/src/utils/doc_string_utils.py,sha256=CmqomepmaYcvpACpXEXkrJb8DMnvIgmYK-lJ53lYarY,1675
|
|
52
52
|
keras_rs/src/utils/keras_utils.py,sha256=dc-NFzs3a-qmRw0vBDiMslPLfrm9yymGduLWesXPhuY,2123
|
|
53
|
-
keras_rs_nightly-0.2.2.
|
|
54
|
-
keras_rs_nightly-0.2.2.
|
|
55
|
-
keras_rs_nightly-0.2.2.
|
|
56
|
-
keras_rs_nightly-0.2.2.
|
|
53
|
+
keras_rs_nightly-0.2.2.dev202508070344.dist-info/METADATA,sha256=rBjnDQocZX2HP0E7Rj72zR7icvhp_ziuzH46k1ESB20,5273
|
|
54
|
+
keras_rs_nightly-0.2.2.dev202508070344.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
55
|
+
keras_rs_nightly-0.2.2.dev202508070344.dist-info/top_level.txt,sha256=pWs8X78Z0cn6lfcIb9VYOW5UeJ-TpoaO9dByzo7_FFo,9
|
|
56
|
+
keras_rs_nightly-0.2.2.dev202508070344.dist-info/RECORD,,
|
|
keras_rs_nightly-0.2.2.dev202508070344.dist-info/WHEEL: file without changes
|
|
keras_rs_nightly-0.2.2.dev202508070344.dist-info/top_level.txt: file without changes
|