keras-rs-nightly 0.0.1.dev2025021903__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_rs/__init__.py +9 -28
- keras_rs/layers/__init__.py +37 -0
- keras_rs/losses/__init__.py +19 -0
- keras_rs/metrics/__init__.py +16 -0
- keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
- keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
- keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
- keras_rs/src/layers/embedding/embed_reduce.py +309 -0
- keras_rs/src/layers/embedding/jax/__init__.py +0 -0
- keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
- keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
- keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
- keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
- keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
- keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
- keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
- keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
- keras_rs/src/layers/feature_interaction/__init__.py +0 -0
- keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
- keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
- keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
- keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
- keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
- keras_rs/src/layers/retrieval/retrieval.py +127 -0
- keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
- keras_rs/src/losses/__init__.py +0 -0
- keras_rs/src/losses/list_mle_loss.py +212 -0
- keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
- keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
- keras_rs/src/losses/pairwise_loss.py +165 -0
- keras_rs/src/losses/pairwise_loss_utils.py +39 -0
- keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
- keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
- keras_rs/src/metrics/__init__.py +0 -0
- keras_rs/src/metrics/dcg.py +161 -0
- keras_rs/src/metrics/mean_average_precision.py +130 -0
- keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
- keras_rs/src/metrics/ndcg.py +197 -0
- keras_rs/src/metrics/precision_at_k.py +117 -0
- keras_rs/src/metrics/ranking_metric.py +260 -0
- keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
- keras_rs/src/metrics/recall_at_k.py +108 -0
- keras_rs/src/metrics/utils.py +70 -0
- keras_rs/src/types.py +43 -14
- keras_rs/src/utils/doc_string_utils.py +53 -0
- keras_rs/src/utils/keras_utils.py +52 -3
- keras_rs/src/utils/tpu_test_utils.py +120 -0
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
- keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
- keras_rs/api/__init__.py +0 -9
- keras_rs/api/layers/__init__.py +0 -11
- keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
- /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
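
The most significant addition in this release is the new `DistributedEmbedding` stack under `keras_rs/src/layers/embedding/`, whose base implementation (`base_distributed_embedding.py`, +1151 lines) is shown in full below. As an illustrative sketch only, not part of the diff, the configuration pattern documented in that file looks like the following; the feature name, table sizes, and batch size are placeholder values:

```python
import keras_rs

# Placeholder sizes for illustration; the docstring in the diff below uses
# TABLE*_VOCABULARY_SIZE, TABLE*_EMBEDDING_SIZE and GLOBAL_BATCH_SIZE instead.
movie_table = keras_rs.layers.TableConfig(
    name="movie_table",
    vocabulary_size=100_000,  # placeholder
    embedding_dim=32,         # placeholder
    placement="auto",         # "sparsecore" when available, else the default device
)
movie_feature = keras_rs.layers.FeatureConfig(
    name="movie_id",
    table=movie_table,
    input_shape=(256,),       # (GLOBAL_BATCH_SIZE,)
    output_shape=(256, 32),   # (GLOBAL_BATCH_SIZE, EMBEDDING_DIM)
)

# A single layer handles the lookup and reduction for all configured features.
embedding = keras_rs.layers.DistributedEmbedding({"movie_id": movie_feature})
```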
keras_rs/src/layers/embedding/base_distributed_embedding.py
@@ -0,0 +1,1151 @@
+import collections
+import dataclasses
+import importlib.util
+import typing
+from typing import Any, Sequence
+
+import keras
+import numpy as np
+from keras.src import backend
+
+from keras_rs.src import types
+from keras_rs.src.layers.embedding import distributed_embedding_config
+from keras_rs.src.layers.embedding import embed_reduce
+from keras_rs.src.utils import keras_utils
+
+FeatureConfig = distributed_embedding_config.FeatureConfig
+TableConfig = distributed_embedding_config.TableConfig
+EmbedReduce = embed_reduce.EmbedReduce
+
+
+SUPPORTED_PLACEMENTS = ("auto", "default_device", "sparsecore")
+
+
+@dataclasses.dataclass(eq=True, unsafe_hash=True, order=True)
+class PlacementAndPath:
+    placement: str
+    path: str
+
+
+def _ragged_to_dense_inputs(
+    inputs: Any, weights: Any | None = None, dense_row_length: int | None = None
+) -> Any:
+    """Converts a ragged set of inputs and weights to dense.
+
+    If inputs are ragged and weights are `None`, will create a dense set of
+    weights to mask out the new padded values.
+
+    If inputs are not ragged, returns the original `inputs` and `weights`
+    unmodified.
+
+    Args:
+        inputs: The inputs array.
+        weights: The optional weights array.
+        dense_row_length: The output dense row length. If None, uses the length
+            of the longest row of the input.
+
+    Returns:
+        Tuple of new (inputs, weights). If the input is a ragged array, returns
+        dense numpy arrays. Otherwise, returns the original input and weights.
+    """
+    x = inputs
+    w = weights
+    # tf.Ragged or other .numpy()-able types.
+    if hasattr(x, "numpy") and callable(getattr(x, "numpy")):
+        x = x.numpy()
+
+    # Ragged numpy array to dense numpy array.
+    if isinstance(x, np.ndarray) and len(x) > 0 and x.dtype == np.ndarray:
+        # Maybe convert weights to numpy.
+        if (
+            w is not None
+            and hasattr(w, "numpy")
+            and callable(getattr(w, "numpy"))
+        ):
+            w = w.numpy()
+
+        if dense_row_length is None:
+            # Use length of longest row.
+            dense_row_length = max([len(row) for row in x])
+
+        output = np.zeros((len(x), dense_row_length), dtype=x[0].dtype)
+        for i, row in enumerate(x):
+            output[i, : len(row)] = row
+
+        output_weights = np.zeros((len(x), dense_row_length), dtype=np.float32)
+        if w is None:
+            for i, row in enumerate(x):
+                output_weights[i, : len(row)] = 1.0
+        else:
+            for i, row in enumerate(w):
+                output_weights[i, : len(row)] = row
+
+        return output, output_weights
+
+    # Convert symbolic ragged/sparse keras tensors to dense tensors.
+    if isinstance(x, keras.KerasTensor) and (x.ragged or x.sparse):
+        inputs = keras.ops.convert_to_tensor(x, ragged=False)
+        weights = keras.ops.convert_to_tensor(x, dtype="float32", ragged=False)
+
+    # If not a ragged array, return the original, unmodified.
+    return inputs, weights
+
+
+class DistributedEmbedding(keras.layers.Layer):
+    """DistributedEmbedding, a layer for accelerated large embedding lookups.
+
+    ---
+
+    ## Note: `DistributedEmbedding` is in Preview.
+
+    ---
+
+    `DistributedEmbedding` is a layer optimized for TPU chips with SparseCore
+    and can dramatically improve the speed of embedding lookups and embedding
+    training. It works by combining multiple lookups into one invocation, and by
+    sharding the embedding tables across the available chips. Note that one will
+    only see performance benefits for embedding tables that are large enough to
+    to require sharding because they don't fit on a single chip. More details
+    are provided in the "Placement" section below.
+
+    On other hardware, GPUs, CPUs and TPUs without SparseCore,
+    `DistributedEmbedding` provides the same API without any specific
+    acceleration. No particular distribution scheme is applied besides the one
+    set via `keras.distribution.set_distribution`.
+
+    `DistributedEmbedding` embeds sequences of inputs and reduces them to a
+    single embedding by applying a configurable combiner function.
+
+    ### Configuration
+
+    #### Features and tables
+
+    A `DistributedEmbedding` embedding layer is configured via a set of
+    `keras_rs.layers.FeatureConfig` objects, which themselves refer to
+    `keras_rs.layers.TableConfig` objects.
+
+    - `TableConfig` defines an embedding table with parameters such as its
+      vocabulary size, embedding dimension, as well as a combiner for reduction
+      and optimizer for training.
+    - `FeatureConfig` defines what input features the `DistributedEmbedding`
+      will handle and which embedding table to use. Note that multiple features
+      can use the same embedding table.
+
+    ```python
+    table1 = keras_rs.layers.TableConfig(
+        name="table1",
+        vocabulary_size=TABLE1_VOCABULARY_SIZE,
+        embedding_dim=TABLE1_EMBEDDING_SIZE,
+        placement="auto",
+    )
+    table2 = keras_rs.layers.TableConfig(
+        name="table2",
+        vocabulary_size=TABLE2_VOCABULARY_SIZE,
+        embedding_dim=TABLE2_EMBEDDING_SIZE,
+        placement="auto",
+    )
+
+    feature1 = keras_rs.layers.FeatureConfig(
+        name="feature1",
+        table=table1,
+        input_shape=(GLOBAL_BATCH_SIZE,),
+        output_shape=(GLOBAL_BATCH_SIZE, TABLE1_EMBEDDING_SIZE),
+    )
+    feature2 = keras_rs.layers.FeatureConfig(
+        name="feature2",
+        table=table2,
+        input_shape=(GLOBAL_BATCH_SIZE,),
+        output_shape=(GLOBAL_BATCH_SIZE, TABLE2_EMBEDDING_SIZE),
+    )
+
+    feature_configs = {
+        "feature1": feature1,
+        "feature2": feature2,
+    }
+
+    embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
+    ```
+
+    #### Optimizers
+
+    Each embedding table within `DistributedEmbedding` uses its own optimizer
+    for training, which is independent from the optimizer set on the model via
+    `model.compile()`.
+
+    Note that not all optimizers are supported. Currently, the following are
+    supported on all backends and accelerators:
+
+    - `keras.optimizers.Adagrad`
+    - `keras.optimizers.Adam`
+    - `keras.optimizers.Ftrl`
+    - `keras.optimizers.SGD`
+
+    Also, not all parameters of the optimizers are supported (e.g. the
+    `nesterov` option of `SGD`). An error is raised when an unsupported
+    optimizer or an unsupported optimizer parameter is used.
+
+    #### Placement
+
+    Each embedding table within `DistributedEmbedding` can be either placed on
+    the SparseCore chip or the default device placement for the accelerator
+    (e.g. HBM of the Tensor Cores on TPU). This is controlled by the `placement`
+    attribute of `keras_rs.layers.TableConfig`.
+
+    - A placement of `"sparsecore"` indicates that the table should be placed on
+      the SparseCore chips. An error is raised if this option is selected and
+      there are no SparseCore chips.
+    - A placement of `"default_device"` indicates that the table should not be
+      placed on SparseCore, even if available. Instead the table is placed on
+      the device where the model normally goes, i.e. the HBM on TPUs and GPUs.
+      In this case, if applicable, the table is distributed using the scheme set
+      via `keras.distribution.set_distribution`. On GPUs, CPUs and TPUs without
+      SparseCore, this is the only placement available, and is the one selected
+      by `"auto"`.
+    - A placement of `"auto"` indicates to use `"sparsecore"` if available, and
+      `"default_device"` otherwise. This is the default when not specified.
+
+    To optimize performance on TPU:
+
+    - Tables that are so large that they need to be sharded should use the
+      `"sparsecore"` placement.
+    - Tables that are small enough should use `"default_device"` and should
+      typically be replicated across TPUs by using the
+      `keras.distribution.DataParallel` distribution option.
+
+    ### Usage with TensorFlow on TPU with SpareCore
+
+    #### Inputs
+
+    In addition to `tf.Tensor`, `DistributedEmbedding` accepts `tf.RaggedTensor`
+    and `tf.SparseTensor` as inputs for the embedding lookups. Ragged tensors
+    must be ragged in the dimension with index 1. Note that if weights are
+    passed, each weight tensor must be of the same class as the inputs for that
+    particular feature and use the exact same ragged row lenghts for ragged
+    tensors, and the same indices for sparse tensors. All the output of
+    `DistributedEmbedding` are dense tensors.
+
+    #### Setup
+
+    To use `DistributedEmbedding` on TPUs with TensorFlow, one must use a
+    `tf.distribute.TPUStrategy`. The `DistributedEmbedding` layer must be
+    created under the `TPUStrategy`.
+
+    ```python
+    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
+    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
+    device_assignment = tf.tpu.experimental.DeviceAssignment.build(
+        topology, num_replicas=resolver.get_tpu_system_metadata().num_cores
+    )
+    strategy = tf.distribute.TPUStrategy(
+        resolver, experimental_device_assignment=device_assignment
+    )
+
+    with strategy.scope():
+        embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
+    ```
+
+    #### Usage in a Keras model
+
+    To use Keras' `model.fit()`, one must compile the model under the
+    `TPUStrategy`. Then, `model.fit()`, `model.evaluate()` or `model.predict()`
+    can be called directly. The Keras model takes care of running the model
+    using the strategy and also automatically distributes the dataset.
+
+    ```python
+    with strategy.scope():
+        embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
+        model = create_model(embedding)
+        model.compile(loss=keras.losses.MeanSquaredError(), optimizer="adam")
+
+    model.fit(dataset, epochs=10)
+    ```
+
+    #### Direct invocation
+
+    `DistributedEmbedding` must be invoked via a `strategy.run` call nested in a
+    `tf.function`.
+
+    ```python
+    @tf.function
+    def embedding_wrapper(tf_fn_inputs, tf_fn_weights=None):
+        def strategy_fn(st_fn_inputs, st_fn_weights):
+            return embedding(st_fn_inputs, st_fn_weights)
+
+        return strategy.run(strategy_fn, args=(tf_fn_inputs, tf_fn_weights)))
+
+    embedding_wrapper(my_inputs, my_weights)
+    ```
+
+    When using a dataset, the dataset must be distributed. The iterator can then
+    be passed to the `tf.function` that uses `strategy.run`.
+
+    ```python
+    dataset = strategy.experimental_distribute_dataset(dataset)
+
+    @tf.function
+    def run_loop(iterator):
+        def step(data):
+            (inputs, weights), labels = data
+            with tf.GradientTape() as tape:
+                result = embedding(inputs, weights)
+                loss = keras.losses.mean_squared_error(labels, result)
+            tape.gradient(loss, embedding.trainable_variables)
+            return result
+
+        for _ in tf.range(4):
+            result = strategy.run(step, args=(next(iterator),))
+
+    run_loop(iter(dataset))
+    ```
+
+    ### Usage with JAX on TPU with SpareCore
+
+    #### Setup
+
+    To use `DistributedEmbedding` on TPUs with JAX, one must create and set a
+    Keras `Distribution`.
+    ```python
+    distribution = keras.distribution.DataParallel(devices=jax.device("tpu"))
+    keras.distribution.set_distribution(distribution)
+    ```
+
+    #### Inputs
+
+    For JAX, inputs can either be dense tensors, or ragged (nested) NumPy
+    arrays. To enable `jit_compile = True`, one must explicitly call
+    `layer.preprocess(...)` on the inputs, and then feed the preprocessed
+    output to the model. See the next section on preprocessing for details.
+
+    Ragged input arrays must be ragged in the dimension with index 1. Note that
+    if weights are passed, each weight tensor must be of the same class as the
+    inputs for that particular feature and use the exact same ragged row lengths
+    for ragged tensors. All the output of `DistributedEmbedding` are dense
+    tensors.
+
+    #### Preprocessing
+
+    In JAX, SparseCore usage requires specially formatted data that depends
+    on properties of the available hardware. This data reformatting
+    currently does not support jit-compilation, so must be applied _prior_
+    to passing data into a model.
+
+    Preprocessing works on dense or ragged NumPy arrays, or on tensors that are
+    convertible to dense or ragged NumPy arrays like `tf.RaggedTensor`.
+
+    One simple way to add preprocessing is to append the function to an input
+    pipeline by using a python generator.
+    ```python
+    # Create the embedding layer.
+    embedding_layer = DistributedEmbedding(feature_configs)
+
+    # Add preprocessing to a data input pipeline.
+    def preprocessed_dataset_generator(dataset):
+        for (inputs, weights), labels in iter(dataset):
+            yield embedding_layer.preprocess(
+                inputs, weights, training=True
+            ), labels
+
+    preprocessed_train_dataset = preprocessed_dataset_generator(train_dataset)
+    ```
+    This explicit preprocessing stage combines the input and optional weights,
+    so the new data can be passed directly into the `inputs` argument of the
+    layer or model.
+
+    **NOTE**: When working in a multi-host setting with data parallelism, the
+    data needs to be sharded properly across hosts. If the original dataset is
+    of type `tf.data.Dataset`, it will need to be manually sharded _prior_ to
+    applying the preprocess generator:
+    ```python
+    # Manually shard the dataset across hosts.
+    train_dataset = distribution.distribute_dataset(train_dataset)
+    distribution.auto_shard_dataset = False # Dataset is already sharded.
+
+    # Add a preprocessing stage to the distributed data input pipeline.
+    train_dataset = preprocessed_dataset_generator(train_dataset)
+    ```
+    If the original dataset is _not_ a `tf.data.Dataset`, it must already be
+    pre-sharded across hosts.
+
+    #### Usage in a Keras model
+
+    Once the global distribution is set and the input preprocessing pipeline
+    is defined, model training can proceed as normal. For example:
+    ```python
+    # Construct, compile, and fit the model using the preprocessed data.
+    model = keras.Sequential(
+        [
+            embedding_layer,
+            keras.layers.Dense(2),
+            keras.layers.Dense(3),
+            keras.layers.Dense(4),
+        ]
+    )
+    model.compile(optimizer="adam", loss="mse", jit_compile=True)
+    model.fit(preprocessed_train_dataset, epochs=10)
+    ```
+
+    #### Direct invocation
+
+    The `DistributedEmbedding` layer can also be invoked directly. Explicit
+    preprocessing is required when used with JIT compilation.
+    ```python
+    # Call the layer directly.
+    activations = embedding_layer(my_inputs, my_weights)
+
+    # Call the layer with JIT compilation and explicitly preprocessed inputs.
+    embedding_layer_jit = jax.jit(embedding_layer)
+    preprocessed_inputs = embedding_layer.preprocess(my_inputs, my_weights)
+    activations = embedding_layer_jit(preprocessed_inputs)
+    ```
+
+    Similarly, for custom training loops, preprocessing must be applied prior
+    to passing the data to the JIT-compiled training step.
+    ```python
+    # Create an optimizer and loss function.
+    optimizer = keras.optimizers.Adam(learning_rate=1e-3)
+
+    def loss_and_updates(trainable_variables, non_trainable_variables, x, y):
+        y_pred, non_trainable_variables = model.stateless_call(
+            trainable_variables, non_trainable_variables, x, training=True
+        )
+        loss = keras.losses.mean_squared_error(y, y_pred)
+        return loss, non_trainable_variables
+
+    grad_fn = jax.value_and_grad(loss_and_updates, has_aux=True)
+
+    # Create a JIT-compiled training step.
+    @jax.jit
+    def train_step(state, x, y):
+        (
+            trainable_variables,
+            non_trainable_variables,
+            optimizer_variables,
+        ) = state
+        (loss, non_trainable_variables), grads = grad_fn(
+            trainable_variables, non_trainable_variables, x, y
+        )
+        trainable_variables, optimizer_variables = optimizer.stateless_apply(
+            optimizer_variables, grads, trainable_variables
+        )
+        return loss, (
+            trainable_variables,
+            non_trainable_variables,
+            optimizer_variables,
+        )
+
+    # Build optimizer variables.
+    optimizer.build(model.trainable_variables)
+
+    # Assemble the training state.
+    trainable_variables = model.trainable_variables
+    non_trainable_variables = model.non_trainable_variables
+    optimizer_variables = optimizer.variables
+    state = trainable_variables, non_trainable_variables, optimizer_variables
+
+    # Training loop.
+    for (inputs, weights), labels in train_dataset:
+        # Explicitly preprocess the data.
+        preprocessed_inputs = embedding_layer.preprocess(inputs, weights)
+        loss, state = train_step(state, preprocessed_inputs, labels)
+    ```
+
+    Args:
+        feature_configs: A nested structure of `keras_rs.layers.FeatureConfig`.
+        table_stacking: The table stacking to use. `None` means no table
+            stacking. `"auto"` means to stack tables automatically. A list of
+            table names or list of lists of table names means to stack the
+            tables in the inner lists together. Note that table stacking is not
+            supported on older TPUs, in which case the default value of `"auto"`
+            will be interpreted as no table stacking.
+        **kwargs: Additional arguments to pass to the layer base class.
+    """
+
+    def __init__(
+        self,
+        feature_configs: types.Nested[FeatureConfig],
+        *,
+        table_stacking: (
+            str | Sequence[str] | Sequence[Sequence[str]]
+        ) = "auto",
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self._init_feature_configs_structures(feature_configs)
+
+        # Initialize for features placed on "sparsecore".
+        if "sparsecore" in self._placement_to_path_to_feature_config:
+            self._sparsecore_init(
+                self._placement_to_path_to_feature_config["sparsecore"],
+                table_stacking,
+            )
+        # Initialize for features placed on "default_device".
+        if "default_device" in self._placement_to_path_to_feature_config:
+            self._default_device_init(
+                self._placement_to_path_to_feature_config["default_device"],
+                table_stacking,
+            )
+
+    @keras_utils.no_automatic_dependency_tracking
+    def _init_feature_configs_structures(
+        self,
+        feature_configs: types.Nested[FeatureConfig],
+    ) -> None:
+        """Initializations for efficiently transforming nested structures.
+
+        This layer handles arbitrarily nested structures for input features, and
+        therefore for outputs and feature configs. However, as an intermediary
+        format we use a two-level representation with nested dicts. the top
+        level dict is keyed by placement and the inner dict is keyed by path,
+        with the path representing the path in the original deeply nested
+        structure. Thanks to this intermediate representation, we can:
+        - dispatch the inputs by placement to overridden methods
+        - have backend specific implementations support only one level of
+          nesting.
+
+        This method is responsible for creating structures that allow this
+        conversion to happen in a few lines of code and efficiently. The
+        following attributes are created:
+        - self._feature_configs: the deeply nested `FeatureConfig` instances as
+          provided by user in `__init__`
+        - self._feature_deeply_nested_placement_and_paths: `PlacementAndPath`
+          instances in the same deeply nested structure as
+          `self._feature_configs`. Needed for `build` because flatten cannot be
+          used as it would expand the shape tuples.
+        - self._placement_to_path_to_feature_config: `FeatureConfig` instances
+          in the same two-level representation keyed by placement and then path.
+          Used to go from a flat representation to the intermediate
+          representation.
+
+        With these structures in place, the steps to:
+        - go from the deeply nested structure to the two-level structure are:
+          - `assert_same_struct` as `self._feature_configs`
+          - use `self._feature_deeply_nested_placement_and_paths` to map from
+            deeply nested to two-level
+        - go from the two-level structure to the deeply nested structure:
+          - `assert_same_struct` as `self._placement_to_path_to_feature_config`
+          - use `self._feature_deeply_nested_placement_and_paths` to locate each
+            output in the two-level dicts
+
+        Args:
+            feature_configs: The deeply nested structure of `FeatureConfig` or
+                `tf.tpu.experimental.embedding.FeatureConfig` as provided by the
+                user.
+        """
+        # Needs to be assigned with `no_automatic_dependency_tracking` to not
+        # alter the data structure types.
+        self._feature_configs = feature_configs
+
+        placement_and_paths: list[PlacementAndPath] = []
+        paths_and_feature_configs = keras.tree.flatten_with_path(
+            self._feature_configs
+        )
+        self._placement_to_path_to_feature_config: dict[
+            str, dict[str, FeatureConfig]
+        ] = {}
+
+        # Lazily initialized.
+        has_sparsecore = None
+
+        for path, feature_config in paths_and_feature_configs:
+            if isinstance(feature_config, FeatureConfig):
+                placement = feature_config.table.placement
+                # Resolve "auto" to an actual placement.
+                if placement == "auto":
+                    if has_sparsecore is None:
+                        has_sparsecore = self._has_sparsecore()
+                    placement = (
+                        "sparsecore" if has_sparsecore else "default_device"
+                    )
+            else:
+                # It's a `tf.tpu.experimental.embedding.FeatureConfig`.
+                placement = "sparsecore"
+
+            path = ".".join([str(e) for e in path])
+            if placement not in SUPPORTED_PLACEMENTS:
+                raise ValueError(
+                    f"Feature '{path}' with name '{feature_config.name}' has "
+                    f"unsupported placement '{placement}'."
+                )
+            placement_and_paths.append(PlacementAndPath(placement, path))
+            if placement not in self._placement_to_path_to_feature_config:
+                self._placement_to_path_to_feature_config[placement] = {}
+            self._placement_to_path_to_feature_config[placement][path] = (
+                feature_config
+            )
+
+        self._feature_deeply_nested_placement_and_paths = (
+            keras.tree.pack_sequence_as(
+                self._feature_configs, placement_and_paths
+            )
+        )
+
+    def build(self, input_shapes: types.Nested[types.Shape]) -> None:
+        if self.built:
+            return
+
+        self._verify_input_shapes(input_shapes)
+
+        # Go from deeply nested structure to placement -> path -> input shape.
+        placement_to_path_to_input_shape: collections.defaultdict[
+            str, dict[str, types.Shape]
+        ] = collections.defaultdict(dict)
+
+        def populate_placement_to_path_to_input_shape(
+            pp: PlacementAndPath, input_shape: types.Shape
+        ) -> None:
+            placement_to_path_to_input_shape[pp.placement][pp.path] = (
+                input_shape
+            )
+
+        keras.tree.map_structure_up_to(
+            self._feature_deeply_nested_placement_and_paths,
+            populate_placement_to_path_to_input_shape,
+            self._feature_deeply_nested_placement_and_paths,
+            input_shapes,
+        )
+
+        # Build for features placed on "sparsecore".
+        if "sparsecore" in placement_to_path_to_input_shape:
+            self._sparsecore_build(
+                placement_to_path_to_input_shape["sparsecore"]
+            )
+
+        # Build for features placed on "default_device".
+        if "default_device" in placement_to_path_to_input_shape:
+            self._default_device_build(
+                placement_to_path_to_input_shape["default_device"]
+            )
+
+        super().build(input_shapes)
+
+    def preprocess(
+        self,
+        inputs: types.Nested[types.Tensor],
+        weights: types.Nested[types.Tensor] | None = None,
+        training: bool = False,
+    ) -> types.Nested[types.Tensor]:
+        """Preprocesses and reformats the data for consumption by the model.
+
+        For the JAX backend, converts the input data to a hardward-dependent
+        format required for use with SparseCores. Calling `preprocess`
+        explicitly is only necessary to enable `jit_compile = True`.
+
+        For non-JAX backends, preprocessing will bundle together the inputs and
+        weights, and separate the inputs by device placement. This step is
+        entirely optional.
+
+        Args:
+            inputs: Ragged or dense set of sample IDs.
+            weights: Optional ragged or dense set of sample weights.
+            training: If true, will update internal parameters, such as
+                required buffer sizes for the preprocessed data.
+
+        Returns:
+            Set of preprocessed inputs that can be fed directly into the
+            `inputs` argument of the layer.
+        """
+        # Verify input structure.
+        keras.tree.assert_same_structure(self._feature_configs, inputs)
+        if weights is not None:
+            keras.tree.assert_same_structure(self._feature_configs, weights)
+
+        if not self.built:
+            input_shapes = keras.tree.map_structure(
+                lambda array: backend.standardize_shape(array.shape),
+                inputs,
+            )
+            self.build(input_shapes)
+
+        # Go from deeply nested to nested dict placement -> path -> input.
+        def to_placement_to_path(
+            tensors: types.Nested[types.Tensor],
+        ) -> dict[str, dict[str, types.Tensor]]:
+            result: dict[str, dict[str, types.Tensor]] = {
+                p: dict() for p in self._placement_to_path_to_feature_config
+            }
+
+            def populate(pp: PlacementAndPath, x: types.Tensor) -> None:
+                result[pp.placement][pp.path] = x
+
+            keras.tree.map_structure(
+                populate,
+                self._feature_deeply_nested_placement_and_paths,
+                tensors,
+            )
+            return result
+
+        placement_to_path_to_inputs = to_placement_to_path(inputs)
+
+        # Same for weights if present.
+        placement_to_path_to_weights = (
+            to_placement_to_path(weights) if weights is not None else None
+        )
+
+        placement_to_path_to_preprocessed: dict[
+            str, dict[str, dict[str, types.Nested[types.Tensor]]]
+        ] = {}
+
+        # Preprocess for features placed on "sparsecore".
+        if "sparsecore" in placement_to_path_to_inputs:
+            placement_to_path_to_preprocessed["sparsecore"] = (
+                self._sparsecore_preprocess(
+                    placement_to_path_to_inputs["sparsecore"],
+                    placement_to_path_to_weights["sparsecore"]
+                    if placement_to_path_to_weights is not None
+                    else None,
+                    training,
+                )
+            )
+
+        # Preprocess for features placed on "default_device".
+        if "default_device" in placement_to_path_to_inputs:
+            placement_to_path_to_preprocessed["default_device"] = (
+                self._default_device_preprocess(
+                    placement_to_path_to_inputs["default_device"],
+                    placement_to_path_to_weights["default_device"]
+                    if placement_to_path_to_weights is not None
+                    else None,
+                    training,
+                )
+            )
+
+        # Mark inputs as preprocessed using an extra level of nesting.
+        # This is necessary to detect whether inputs are already preprocessed
+        # in `call`.
+        output = {
+            "preprocessed_inputs_per_placement": (
+                placement_to_path_to_preprocessed
+            )
+        }
+        return output
+
+    def _is_preprocessed(
+        self, inputs: types.Nested[types.Tensor | types.Shape]
+    ) -> bool:
+        """Checks if the input is already preprocessed."""
+        return (
+            isinstance(inputs, dict)
+            and "preprocessed_inputs_per_placement" in inputs
+        )
+
+    def call(
+        self,
+        inputs: types.Nested[types.Tensor],
+        weights: types.Nested[types.Tensor] | None = None,
+        training: bool = False,
+    ) -> types.Nested[types.Tensor]:
+        """Lookup features in embedding tables and apply reduction.
+
+        Args:
+            inputs: A nested structure of 2D tensors to embed and reduce. The
+                structure must be the same as the `feature_configs` passed
+                during construction. Alternatively, may consist of already
+                preprocessed inputs (see `preprocess`).
+            weights: An optional nested structure of 2D tensors of weights to
+                apply before reduction. When present, the structure must be the
+                same as `inputs` and the shapes must match.
+            training: Whether we are training or evaluating the model.
+
+        Returns:
+            A nested structure of dense 2D tensors, which are the reduced
+            embeddings from the passed features. The structure is the same as
+            `inputs`.
+        """
+        preprocessed_inputs = inputs
+        # Preprocess if not already done.
+        if not self._is_preprocessed(inputs):
+            preprocessed_inputs = self.preprocess(inputs, weights, training)
+
+        preprocessed_inputs = typing.cast(
+            dict[str, dict[str, dict[str, types.Tensor]]], preprocessed_inputs
+        )
+        # Placement -> path -> preprocessed inputs.
+        preprocessed_inputs = preprocessed_inputs[
+            "preprocessed_inputs_per_placement"
+        ]
+
+        placement_to_path_to_outputs = {}
+
+        # Call for features placed on "sparsecore".
+        if "sparsecore" in preprocessed_inputs:
+            inputs_and_weights = preprocessed_inputs["sparsecore"]
+            placement_to_path_to_outputs["sparsecore"] = self._sparsecore_call(
+                **inputs_and_weights,
+                training=training,
+            )
+
+        # Call for features placed on "default_device".
+        if "default_device" in preprocessed_inputs:
+            inputs_and_weights = preprocessed_inputs["default_device"]
+            placement_to_path_to_outputs["default_device"] = (
+                self._default_device_call(
+                    **inputs_and_weights,
+                    training=training,
+                )
+            )
+
+        # Verify output structure.
+        keras.tree.assert_same_structure(
+            self._placement_to_path_to_feature_config,
+            placement_to_path_to_outputs,
+        )
+
+        # Go from placement -> path -> output to deeply nested structure.
+        def populate_output(pp: PlacementAndPath) -> types.Tensor:
+            return placement_to_path_to_outputs[pp.placement][pp.path]
+
+        return keras.tree.map_structure(
+            populate_output, self._feature_deeply_nested_placement_and_paths
+        )
+
+    def get_embedding_tables(self) -> dict[str, types.Tensor]:
+        """Return the content of the embedding tables by table name.
+
+        The tables are keyed by the name provided in each `TableConfig`. Note
+        that the returned tensors are not the actual embedding table variables
+        used internally by `DistributedEmbedding`.
+
+        Returns:
+            A dictionary of table name to tensor for the embedding tables.
+        """
+        tables = {}
+        if "sparsecore" in self._placement_to_path_to_feature_config:
+            tables.update(self._sparsecore_get_embedding_tables())
+        if "default_device" in self._placement_to_path_to_feature_config:
+            tables.update(self._default_device_get_embedding_tables())
+        return tables
+
+    def _default_device_init(
+        self,
+        feature_configs: dict[str, FeatureConfig],
+        table_stacking: str | Sequence[Sequence[str]],
+    ) -> None:
+        del table_stacking
+        table_config_id_to_embedding_layer: dict[int, EmbedReduce] = {}
+        self._default_device_embedding_layers: dict[str, EmbedReduce] = {}
+
+        for path, feature_config in feature_configs.items():
+            if id(feature_config.table) in table_config_id_to_embedding_layer:
+                self._default_device_embedding_layers[path] = (
+                    table_config_id_to_embedding_layer[id(feature_config.table)]
+                )
+            else:
+                embedding_layer = EmbedReduce(
+                    name=feature_config.table.name,
+                    input_dim=feature_config.table.vocabulary_size,
+                    output_dim=feature_config.table.embedding_dim,
+                    embeddings_initializer=feature_config.table.initializer,
+                    combiner=feature_config.table.combiner,
+                )
+                table_config_id_to_embedding_layer[id(feature_config.table)] = (
+                    embedding_layer
+                )
+                self._default_device_embedding_layers[path] = embedding_layer
+
+    def _default_device_build(
+        self, input_shapes: dict[str, types.Shape]
+    ) -> None:
+        for path, input_shape in input_shapes.items():
+            embedding_layer = self._default_device_embedding_layers[path]
+            if not embedding_layer.built:
+                embedding_layer.build(input_shape)
+
+    def _default_device_preprocess(
+        self,
+        inputs: dict[str, types.Tensor],
+        weights: dict[str, types.Tensor] | None,
+        training: bool = False,
+    ) -> dict[str, dict[str, types.Tensor]]:
+        del training
+
+        # NOTE: This JAX specialization is in the base layer so it is available
+        # on all platforms. The superclass jax.DistributedEmbedding layer
+        # is currently only imported in linux_x86_64.
+        if keras.backend.backend() == "jax":
+            feature_configs = self._placement_to_path_to_feature_config[
+                "default_device"
+            ]
+
+            # Potentially track new weights. For ragged inputs, if we
+            # densify, we will generate a dense weight tensor.
+            new_weights: dict[str, types.Tensor] = {}
+            use_weights = weights is not None
+
+            # Convert any ragged inputs to dense.
+            for path, config in feature_configs.items():
+                feature_inputs = inputs[path]
+                feature_weights = weights[path] if weights is not None else None
+
+                feature_valence = (
+                    None
+                    if len(config.input_shape) <= 1
+                    else config.input_shape[1]
+                )
+                feature_inputs, feature_weights = _ragged_to_dense_inputs(
+                    feature_inputs, feature_weights, feature_valence
+                )
+                # Converting to ragged may have introduced a weights array.
+                use_weights = use_weights or feature_weights is not None
+                inputs[path] = feature_inputs
+                new_weights[path] = feature_weights
+
+            if use_weights:
+                weights = new_weights
+
+        output: dict[str, types.Tensor] = {"inputs": inputs}
+        if weights is not None:
+            output["weights"] = weights
+
+        return output
+
+    def _default_device_call(
+        self,
+        inputs: dict[str, types.Tensor],
+        weights: dict[str, types.Tensor] | None = None,
+        training: bool = False,
+    ) -> dict[str, types.Tensor]:
+        del training # Unused by default.
+        if weights is None:
+            return {
+                path: self._default_device_embedding_layers[path](x)
+                for path, x in inputs.items()
+            }
+        else:
+            return {
+                path: self._default_device_embedding_layers[path](
+                    x, weights[path]
+                )
+                for path, x in inputs.items()
+            }
+
+    def _default_device_get_embedding_tables(self) -> dict[str, types.Tensor]:
+        tables = {}
+        for path, feature_config in self._placement_to_path_to_feature_config[
+            "default_device"
+        ].items():
+            tables[feature_config.table.name] = (
+                self._default_device_embedding_layers[path].embeddings.value
+            )
+        return tables
+
+    def _has_sparsecore(self) -> bool:
+        # Explicitly check for SparseCore availability.
+        # We need this check here rather than in jax/distributed_embedding.py
+        # so that we can warn the user about missing dependencies.
+        if keras.backend.backend() == "jax":
+            # Check if SparseCores are available.
+            try:
+                import jax
+
+                tpu_devices = jax.devices("tpu")
+            except RuntimeError:
+                # No TPUs available.
+                return False
+
+            if len(tpu_devices) > 0:
+                device_kind = tpu_devices[0].device_kind
+                if device_kind in ["TPU v5", "TPU v6 lite"]:
+                    return True
+
+        return False
+
+    def _sparsecore_init(
+        self,
+        feature_configs: dict[str, FeatureConfig],
+        table_stacking: str | Sequence[Sequence[str]],
+    ) -> None:
+        del feature_configs, table_stacking
+
+        if keras.backend.backend() == "jax":
+            jax_tpu_embedding_spec = importlib.util.find_spec(
+                "jax_tpu_embedding"
+            )
+            if jax_tpu_embedding_spec is None:
+                raise ImportError(
+                    "Please install jax-tpu-embedding to use "
+                    "DistributedEmbedding on sparsecore devices."
+                )
+
+        raise self._unsupported_placement_error("sparsecore")
+
+    def _sparsecore_build(self, input_shapes: dict[str, types.Shape]) -> None:
+        del input_shapes
+        raise self._unsupported_placement_error("sparsecore")
+
+    def _sparsecore_preprocess(
+        self,
+        inputs: dict[str, types.Tensor],
+        weights: dict[str, types.Tensor] | None,
+        training: bool = False,
+    ) -> dict[str, dict[str, types.Tensor]]:
+        del training
+        output: dict[str, types.Tensor] = {"inputs": inputs}
+        if weights is not None:
+            output["weights"] = weights
+
+        return output
+
+    def _sparsecore_call(
+        self,
+        inputs: dict[str, types.Tensor],
+        weights: dict[str, types.Tensor] | None = None,
+        training: bool = False,
+    ) -> dict[str, types.Tensor]:
+        del inputs, weights, training
+        raise self._unsupported_placement_error("sparsecore")
+
+    def _sparsecore_get_embedding_tables(self) -> dict[str, types.Tensor]:
+        raise self._unsupported_placement_error("sparsecore")
+
+    def compute_output_shape(
+        self, input_shapes: types.Nested[types.Shape]
+    ) -> types.Nested[types.Shape]:
+        self._verify_input_shapes(input_shapes)
+        output_shape: types.Nested[types.Shape] = keras.tree.map_structure(
+            lambda fc: fc.output_shape, self._feature_configs
+        )
+        return output_shape
+
+    def get_config(self) -> dict[str, Any]:
+        # Because the Keras serialization creates a tree of serialized objects,
+        # it does not directly support sharing tables between feature configs.
+        # We therefore serialize the tables config as a flat list and then refer
+        # to them by index in each feature config.
+
+        # The serialized `TableConfig` objects.
+        table_config_dicts: list[dict[str, Any]] = []
+        # Mapping from `TableConfig` id to index in `table_config_dicts`.
+        table_config_id_to_index: dict[int, int] = {}
+
+        def serialize_feature_config(
+            feature_config: FeatureConfig,
+        ) -> dict[str, Any]:
+            # Note that for consistency with the contract of `get_config`, the
+            # returned dict contains the serialized `TableConfig` in the "table"
+            # key.
+            feature_config_dict = feature_config.get_config()
+
+            if id(feature_config.table) not in table_config_id_to_index:
+                # Save the serialized `TableConfig` the first time we see it and
+                # remember its index.
+                table_config_id_to_index[id(feature_config.table)] = len(
+                    table_config_dicts
+                )
+                table_config_dicts.append(feature_config_dict["table"])
+
+            # Replace the serialized `TableConfig` with its index.
+            feature_config_dict["table"] = table_config_id_to_index[
+                id(feature_config.table)
+            ]
+            return feature_config_dict
+
+        config: dict[str, Any] = super().get_config()
+        config["feature_configs"] = keras.tree.map_structure(
+            serialize_feature_config, self._feature_configs
+        )
+        config["tables"] = table_config_dicts
+        if hasattr(self, "_table_stacking"):
+            config["table_stacking"] = self._table_stacking
+        return config
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "DistributedEmbedding":
+        config = config.copy()
+        # We need to reconnect the `TableConfig`s to the `FeatureConfig`s.
+
+        # The serialized `TableConfig` objects.
+        table_config_dicts: list[dict[str, Any]] = config.pop("tables")
+        # The deserialized `TableConfig` objects at the same indices.
+        table_configs: list[TableConfig | None] = [None] * len(
+            table_config_dicts
+        )
+
+        def deserialize_feature_config(
+            feature_config_dict: dict[str, Any],
+        ) -> FeatureConfig | None:
+            # Look for a "name" attribute which is a string to detect a
+            # `FeatureConfig` leaf node. If not, keep recursing.
+            if "name" not in feature_config_dict or not isinstance(
+                feature_config_dict["name"], str
+            ):
+                # Tell `traverse` to recurse.
+                return None
+
+            table_index = feature_config_dict["table"]
+            # Note that for consistency with the contract of `from_config`, the
+            # passed dict must contain the serialized `TableConfig` in the
+            # "table" key.
+            feature_config_dict["table"] = table_config_dicts[table_index]
+            feature_config = FeatureConfig.from_config(feature_config_dict)
+            # But then dedupe `TableConfig`s.
+            if table_configs[table_index] is None:
+                # Remember each new `TableConfig` we see.
+                table_configs[table_index] = feature_config.table
+            else:
+                # And swap duplicates for the original.
+                feature_config.table = table_configs[table_index]
+            return feature_config
+
+        # Because each `FeatureConfig` is serialized as a dict, we cannot use
+        # `map_structure` as it would recurse in the config itself. We use
+        # `traverse` instead with a function that detects leaf nodes.
+        config["feature_configs"] = keras.tree.traverse(
+            deserialize_feature_config, config["feature_configs"]
+        )
+        return cls(**config)
+
+    def _verify_input_shapes(
+        self, input_shapes: types.Nested[types.Shape]
+    ) -> None:
+        """Verifies that the input shapes match the ones in the feature configs.
+
+        Args:
+            input_shapes: The structure of input shapes to verify.
+        """
+        # Support preprocessing.
+        if self._is_preprocessed(input_shapes):
+            # Structure should be :
+            # {
+            #     placement: {
+            #         inputs: {path: Any},
+            #         weights: {path: Any}
+            #     }
+            # }
+            #
+            # But the `Any` values could be nested tensors with varying
+            # structure, depending on hardware constraints. This complicates
+            # checking shapes via keras.tree methods. So, assume the
+            # input is a result of explicitly calling the `preprocess(...)`
+            # function, in which case the structure has already been verified.
+            return
+
+        def _verify_input_shape(
+            feature_config: FeatureConfig,
+            input_shape: types.Shape,
+        ) -> None:
+            if not isinstance(input_shape, (tuple, list)) or not all(
+                isinstance(d, (int, type(None))) for d in input_shape
+            ):
+                raise ValueError(f"Received invalid input shape {input_shape}.")
+            if len(input_shape) < 1:
+                raise ValueError(
+                    f"Received input shape {input_shape}. Rank must be 1 or "
+                    "above."
+                )
+            keras_utils.check_shapes_compatible(
+                feature_config.input_shape, input_shape
+            )
+
+        keras.tree.map_structure_up_to(
+            self._feature_configs,
+            _verify_input_shape,
+            self._feature_configs,
+            input_shapes,
+        )
+
+    def _unsupported_placement_error(self, placement: str) -> Exception:
+        return NotImplementedError(
+            f"Backend '{keras.backend.backend()}' does not support the "
+            f"'{placement}' placement."
+        )