keras-rs-nightly 0.2.2.dev202508190331__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,17 +1,13 @@
  """Utility functions for manipulating JAX embedding tables and inputs."""

  import collections
- import dataclasses
- import typing
  from typing import Any, Mapping, NamedTuple, Sequence, TypeAlias, TypeVar

  import jax
  import numpy as np
- from jax import numpy as jnp
  from jax_tpu_embedding.sparsecore.lib.nn import embedding
+ from jax_tpu_embedding.sparsecore.lib.nn import table_stacking
  from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import FeatureSpec
- from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import StackedTableSpec
- from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import TableSpec

  from keras_rs.src.types import Nested

@@ -24,7 +20,7 @@ Shape: TypeAlias = tuple[int, ...]

  class FeatureSamples(NamedTuple):
      tokens: ArrayLike
-     weights: ArrayLike
+     weights: ArrayLike | None


  class ShardedCooMatrix(NamedTuple):
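
Note on the change above: `FeatureSamples.weights` is now optional. A minimal standalone sketch of what callers can rely on (plain numpy stand-ins for `ArrayLike`, not the keras-rs types themselves):

    from typing import NamedTuple

    import numpy as np


    class FeatureSamples(NamedTuple):
        tokens: np.ndarray
        weights: np.ndarray | None  # None now means "unweighted"


    # Unweighted features no longer need an explicit array of ones.
    unweighted = FeatureSamples(tokens=np.array([3, 1, 4]), weights=None)
    weighted = FeatureSamples(
        tokens=np.array([3, 1, 4]), weights=np.array([0.5, 1.0, 2.0])
    )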
@@ -35,357 +31,6 @@ class ShardedCooMatrix(NamedTuple):
      values: ArrayLike


- def _round_up_to_multiple(value: int, multiple: int) -> int:
-     return ((value + multiple - 1) // multiple) * multiple
-
-
- def _default_stacked_table_spec(
-     table_spec: TableSpec, num_shards: int, batch_size: int
- ) -> StackedTableSpec:
-     return StackedTableSpec(
-         stack_name=table_spec.name,
-         stack_vocab_size=_round_up_to_multiple(
-             table_spec.vocabulary_size, 8 * num_shards
-         ),
-         stack_embedding_dim=_round_up_to_multiple(table_spec.embedding_dim, 8),
-         optimizer=table_spec.optimizer,
-         combiner=table_spec.combiner,
-         total_sample_count=batch_size,
-         max_ids_per_partition=table_spec.max_ids_per_partition,
-         max_unique_ids_per_partition=table_spec.max_unique_ids_per_partition,
-     )
-
-
- def _get_stacked_table_spec(
-     table_spec: TableSpec, num_shards: int, batch_size: int = 0
- ) -> StackedTableSpec:
-     return table_spec.stacked_table_spec or _default_stacked_table_spec(
-         table_spec, num_shards, batch_size
-     )
-
-
- def pad_table(
-     table_spec: TableSpec,
-     table_values: jax.Array,
-     num_shards: int,
-     pad_value: jnp.float32 = jnp.nan,
- ) -> jax.Array:
-     """Adds appropriate padding to a table to prepare for stacking.
-
-     Args:
-         table_spec: Table specification describing the table to pad.
-         table_values: Table values array to pad.
-         num_shards: Number of shards in the table (typically
-             `global_device_count * num_sc_per_device`).
-         pad_value: Value to use for padding.
-
-     Returns:
-         Padded table values.
-     """
-     vocabulary_size = table_spec.vocabulary_size
-     embedding_dim = table_spec.embedding_dim
-     padded_vocabulary_size = _round_up_to_multiple(
-         vocabulary_size, 8 * num_shards
-     )
-     stack_embedding_dim = _get_stacked_table_spec(
-         table_spec, num_shards
-     ).stack_embedding_dim
-     return jnp.pad(
-         table_values,
-         (
-             (0, padded_vocabulary_size - vocabulary_size),
-             (0, stack_embedding_dim - embedding_dim),
-         ),
-         constant_values=pad_value,
-     )
-
-
- def _stack_and_shard_table(
-     stacked_table: jax.Array,
-     table_spec: TableSpec,
-     table: jax.Array,
-     num_shards: int,
-     pad_value: jnp.float32,
- ) -> jax.Array:
-     """Stacks and shards a single table for use in sparsecore lookups."""
-     padded_values = pad_table(table_spec, table, num_shards, pad_value)
-     sharded_padded_vocabulary_size = padded_values.shape[0] // num_shards
-     stack_embedding_dim = stacked_table.shape[-1]
-
-     # Mod-shard vocabulary across devices.
-     sharded_values = jnp.swapaxes(
-         padded_values.reshape(-1, num_shards, stack_embedding_dim),
-         0,
-         1,
-     )
-
-     # Rotate shards.
-     setting_in_stack = table_spec.setting_in_stack
-     rotated_values = jnp.roll(
-         sharded_values, setting_in_stack.shard_rotation, axis=0
-     )
-
-     # Insert table into the stack.
-     table_row = setting_in_stack.row_offset_in_shard
-     stacked_table = stacked_table.at[
-         :, table_row : (table_row + sharded_padded_vocabulary_size), :
-     ].set(rotated_values)
-
-     return stacked_table
-
-
- def stack_and_shard_tables(
-     table_specs: Nested[TableSpec],
-     tables: Nested[ArrayLike],
-     num_shards: int,
-     pad_value: jnp.float32 = jnp.nan,
- ) -> dict[str, Nested[jax.Array]]:
-     """Stacks and shards tables for use in sparsecore lookups.
-
-     Args:
-         table_specs: Nested collection of unstacked table specifications.
-         tables: Table values corresponding to the table_specs.
-         num_shards: Number of shards in the table (typically
-             `global_device_count * num_sc_per_device`).
-         pad_value: Value to use for padding.
-
-     Returns:
-         A mapping of stacked table names to stacked table values.
-     """
-
-     # Gather stacked table information.
-     stacked_table_map: dict[
-         str,
-         tuple[StackedTableSpec, list[TableSpec]],
-     ] = {}
-
-     def collect_stacked_tables(table_spec: TableSpec) -> None:
-         stacked_table_spec = _get_stacked_table_spec(table_spec, num_shards)
-         stacked_table_name = stacked_table_spec.stack_name
-         if stacked_table_name not in stacked_table_map:
-             stacked_table_map[stacked_table_name] = (stacked_table_spec, [])
-         stacked_table_map[stacked_table_name][1].append(table_spec)
-
-     _ = jax.tree.map(collect_stacked_tables, table_specs)
-
-     table_map: dict[str, Nested[jax.Array]] = {}
-
-     def collect_tables(table_spec: TableSpec, table: Nested[jax.Array]) -> None:
-         table_map[table_spec.name] = table
-
-     _ = jax.tree.map(collect_tables, table_specs, tables)
-
-     stacked_tables: dict[str, Nested[jax.Array]] = {}
-     for (
-         stacked_table_spec,
-         table_specs,
-     ) in stacked_table_map.values():
-         stack_vocab_size = stacked_table_spec.stack_vocab_size
-         sharded_vocab_size = stack_vocab_size // num_shards
-         stack_embedding_dim = stacked_table_spec.stack_embedding_dim
-
-         # Allocate initial buffer. The stacked table will be divided among
-         # shards by splitting the vocabulary dimension:
-         # [ v, e ] -> [s, v/s, e]
-         stacked_table_tree = jax.tree.map(
-             lambda _: jnp.zeros(
-                 # pylint: disable-next=cell-var-from-loop, used only in loop body.
-                 shape=(num_shards, sharded_vocab_size, stack_embedding_dim),
-                 dtype=jnp.float32,
-             ),
-             table_map[table_specs[0].name],
-         )
-
-         for table_spec in table_specs:
-             table_tree = table_map[table_spec.name]
-             stacked_table_tree = jax.tree.map(
-                 lambda stacked_table, table: _stack_and_shard_table(
-                     # pylint: disable-next=cell-var-from-loop, used only in loop body.
-                     stacked_table,
-                     # pylint: disable-next=cell-var-from-loop, used only in loop body.
-                     table_spec,
-                     table,
-                     num_shards,
-                     pad_value,
-                 ),
-                 stacked_table_tree,
-                 table_tree,
-             )
-
-         stacked_tables[stacked_table_spec.stack_name] = stacked_table_tree
-
-     return stacked_tables
-
-
- def _unshard_and_unstack_table(
-     table_spec: TableSpec,
-     stacked_table_tree: Nested[jax.Array],
-     num_shards: int,
- ) -> Nested[jax.Array]:
-     """Unshards and unstacks a single table."""
-     vocabulary_size = table_spec.vocabulary_size
-     embedding_dim = table_spec.embedding_dim
-
-     def _unshard_and_unstack_single_table(
-         table_spec: TableSpec, stacked_table: jax.Array
-     ) -> jax.Array:
-         stack_embedding_dim = stacked_table.shape[-1]
-
-         # Maybe re-shape in case it was flattened.
-         stacked_table = stacked_table.reshape(
-             num_shards, -1, stack_embedding_dim
-         )
-         sharded_vocabulary_size = (
-             _round_up_to_multiple(vocabulary_size, 8 * num_shards) // num_shards
-         )
-
-         # Extract padded values from the stacked table.
-         setting_in_stack = table_spec.setting_in_stack
-         row = setting_in_stack.row_offset_in_shard
-         padded_values = stacked_table[
-             :, row : (row + sharded_vocabulary_size), :
-         ]
-
-         # Un-rotate shards.
-         padded_values = jnp.roll(
-             padded_values, -setting_in_stack.shard_rotation, axis=0
-         )
-
-         # Un-mod-shard.
-         padded_values = jnp.swapaxes(padded_values, 0, 1).reshape(
-             -1, stack_embedding_dim
-         )
-
-         # Un-pad.
-         return padded_values[:vocabulary_size, :embedding_dim]
-
-     output: Nested[jax.Array] = jax.tree.map(
-         lambda stacked_table: _unshard_and_unstack_single_table(
-             table_spec, stacked_table
-         ),
-         stacked_table_tree,
-     )
-     return output
-
-
- def unshard_and_unstack_tables(
-     table_specs: Nested[TableSpec],
-     stacked_tables: Mapping[str, Nested[jax.Array]],
-     num_shards: int,
- ) -> Nested[jax.Array]:
-     """Unshards and unstacks a collection of tables.
-
-     Args:
-         table_specs: Nested collection of unstacked table specifications.
-         stacked_tables: Mapping of stacked table names to stacked table values.
-         num_shards: Number of shards in the table (typically
-             `global_device_count * num_sc_per_device`).
-
-     Returns:
-         A mapping of table names to unstacked table values.
-     """
-     output: Nested[jax.Array] = jax.tree.map(
-         lambda table_spec: _unshard_and_unstack_table(
-             table_spec,
-             stacked_tables[
-                 _get_stacked_table_spec(table_spec, num_shards=1).stack_name
-             ],
-             num_shards,
-         ),
-         table_specs,
-     )
-     return output
-
-
- def get_table_specs(feature_specs: Nested[FeatureSpec]) -> dict[str, TableSpec]:
-     table_spec_map: dict[str, TableSpec] = {}
-     flat_feature_specs, _ = jax.tree.flatten(feature_specs)
-     for feature_spec in flat_feature_specs:
-         table_spec = feature_spec.table_spec
-         table_spec_map[table_spec.name] = table_spec
-     return table_spec_map
-
-
- def get_table_stacks(
-     table_specs: Nested[TableSpec],
- ) -> dict[str, list[TableSpec]]:
-     """Extracts lists of tables that are stacked together.
-
-     Args:
-         table_specs: Nested collection of table specifications.
-
-     Returns:
-         A mapping of stacked table names to lists of table specifications for
-         each stack.
-     """
-     stacked_table_specs: dict[str, list[TableSpec]] = collections.defaultdict(
-         list
-     )
-     flat_table_specs, _ = jax.tree.flatten(table_specs)
-     for table_spec in flat_table_specs:
-         table_spec = typing.cast(TableSpec, table_spec)
-         stacked_table_spec = table_spec.stacked_table_spec
-         if stacked_table_spec is not None:
-             stacked_table_specs[stacked_table_spec.stack_name].append(
-                 table_spec
-             )
-         else:
-             stacked_table_specs[table_spec.name].append(table_spec)
-
-     return stacked_table_specs
-
-
- def update_stacked_table_specs(
-     feature_specs: Nested[FeatureSpec],
-     max_ids_per_partition: Mapping[str, int],
-     max_unique_ids_per_partition: Mapping[str, int],
- ) -> None:
-     """Updates properties in the supplied feature specs.
-
-     Args:
-         feature_specs: Feature specs to update in-place.
-         max_ids_per_partition: Mapping of table stack name to
-             new `max_ids_per_partition` for the stack.
-         max_unique_ids_per_partition: Mapping of table stack name to
-             new `max_unique_ids_per_partition` for the stack.
-     """
-     # Collect table specs and stacked table specs.
-     table_specs: dict[str, TableSpec] = {}
-     for feature_spec in jax.tree.flatten(feature_specs)[0]:
-         feature_spec = typing.cast(FeatureSpec, feature_spec)
-         table_specs[feature_spec.table_spec.name] = feature_spec.table_spec
-
-     stacked_table_specs: dict[str, StackedTableSpec] = {}
-     for table_spec in table_specs.values():
-         stacked_table_spec = typing.cast(
-             StackedTableSpec, table_spec.stacked_table_spec
-         )
-         stacked_table_specs[stacked_table_spec.stack_name] = stacked_table_spec
-
-     # Replace fields in the stacked_table_specs.
-     stacked_table_specs = {
-         stack_name: dataclasses.replace(
-             stacked_table_spec,
-             max_ids_per_partition=max_ids_per_partition[
-                 stacked_table_spec.stack_name
-             ],
-             max_unique_ids_per_partition=max_unique_ids_per_partition[
-                 stacked_table_spec.stack_name
-             ],
-         )
-         for stack_name, stacked_table_spec in stacked_table_specs.items()
-     }
-
-     # Insert new stacked tables into tables.
-     for table_spec in table_specs.values():
-         stacked_table_spec = typing.cast(
-             StackedTableSpec, table_spec.stacked_table_spec
-         )
-         table_spec.stacked_table_spec = stacked_table_specs[
-             stacked_table_spec.stack_name
-         ]
-
-
  def convert_to_numpy(
      ragged_or_dense: np.ndarray[Any, Any] | Sequence[Sequence[Any]] | Any,
      dtype: Any,
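
The helpers removed above implemented the sparsecore table layout by hand: pad the vocabulary to a multiple of `8 * num_shards`, mod-shard rows across shards, then rotate shards by the table's `shard_rotation`; this responsibility is presumably delegated to the `jax_tpu_embedding` library going forward. A numpy-only sketch of the mod-sharding step on a toy table (illustrative names, not the keras-rs API):

    import numpy as np

    num_shards = 4
    vocab, dim = 16, 2  # vocab already padded to a multiple of num_shards
    table = np.arange(vocab * dim, dtype=np.float32).reshape(vocab, dim)

    # Mod-shard: row i lands on shard i % num_shards, mirroring the removed
    # reshape(-1, num_shards, dim) + swapaxes(0, 1) in _stack_and_shard_table.
    sharded = np.swapaxes(table.reshape(-1, num_shards, dim), 0, 1)
    assert sharded.shape == (num_shards, vocab // num_shards, dim)
    assert np.array_equal(sharded[1, 0], table[1])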
@@ -439,36 +84,6 @@ def convert_to_numpy(
      )


- def ones_like(
-     ragged_or_dense: np.ndarray[Any, Any], dtype: Any = None
- ) -> np.ndarray[Any, Any]:
-     """Creates an array of ones the same as the input.
-
-     This differs from traditional numpy in that a ragged input will lead to
-     a resulting ragged array of ones, whereas np.ones_like(...) will instead
-     only consider the outer array and return a 1D dense array of ones.
-
-     Args:
-         ragged_or_dense: The ragged or dense input whose shape and data-type
-             define these same attributes of the returned array.
-         dtype: The data-type of the returned array.
-
-     Returns:
-         An array of ones with the same shape as the input, and specified data
-         type.
-     """
-     dtype = dtype or ragged_or_dense.dtype
-     if ragged_or_dense.dtype == np.ndarray:
-         # Ragged.
-         return np.array(
-             [np.ones_like(row, dtype=dtype) for row in ragged_or_dense],
-             dtype=np.ndarray,
-         )
-     else:
-         # Dense.
-         return np.ones_like(ragged_or_dense, dtype=dtype)
-
-
  def create_feature_samples(
      feature_structure: Nested[T],
      feature_ids: Nested[ArrayLike | Sequence[int] | Sequence[Sequence[int]]],
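
For reference, the removed `ones_like` treated an object-dtype array of row arrays as ragged and built ones per row, whereas `np.ones_like` only sees the outer array. A small plain-numpy illustration of the difference:

    import numpy as np

    ragged = np.empty(2, dtype=np.ndarray)
    ragged[0] = np.array([1, 2, 3])
    ragged[1] = np.array([4])

    # np.ones_like on the outer object array: a flat (2,) array of ones.
    outer_ones = np.ones_like(ragged, dtype=np.float32)

    # The removed helper instead produced ragged ones: [[1., 1., 1.], [1.]].
    ragged_ones = np.array(
        [np.ones_like(row, dtype=np.float32) for row in ragged], dtype=np.ndarray
    )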
@@ -496,18 +111,17 @@ def create_feature_samples(
      )

      if feature_weights is None:
-         # Make ragged or dense ones_like.
-         feature_weights = jax.tree.map(
-             lambda _, ids: ones_like(ids, np.float32),
+         return jax.tree.map(  # type: ignore[no-any-return]
+             lambda _, ids: FeatureSamples(ids, None),
              feature_structure,
              feature_ids,
          )
-     else:
-         feature_weights = jax.tree.map(
-             lambda _, wgts: convert_to_numpy(wgts, np.float32),
-             feature_structure,
-             feature_weights,
-         )
+
+     feature_weights = jax.tree.map(
+         lambda _, wgts: convert_to_numpy(wgts, np.float32),
+         feature_structure,
+         feature_weights,
+     )

      # Assemble.
      def _create_feature_samples(
@@ -544,8 +158,8 @@ def stack_and_shard_samples(
          global_device_count: Number of global JAX devices.
          num_sc_per_device: Number of sparsecores per device.
          static_buffer_size: The static buffer size to use for the samples.
-         Defaults to None, in which case an upper-bound for the buffer size
-         will be automatically determined.
+             Defaults to None, in which case an upper-bound for the buffer size
+             will be automatically determined.

      Returns:
          The preprocessed inputs, and statistics useful for updating FeatureSpecs
@@ -555,17 +169,21 @@ def stack_and_shard_samples(
      flat_feature_specs, _ = jax.tree.flatten(feature_specs)

      feature_tokens = []
-     feature_weights = []
+     collected_weights = []

      def collect_tokens_and_weights(
          feature_spec: FeatureSpec, samples: FeatureSamples
      ) -> None:
          del feature_spec
          feature_tokens.append(samples.tokens)
-         feature_weights.append(samples.weights)
+         collected_weights.append(samples.weights)

      jax.tree.map(collect_tokens_and_weights, feature_specs, feature_samples)

+     feature_weights = (
+         None if all(w is None for w in collected_weights) else collected_weights
+     )
+
      preprocessed_inputs, stats = embedding.preprocess_sparse_dense_matmul_input(
          feature_tokens,
          feature_weights,
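
With per-feature weights now optional, the collected list is only forwarded when at least one feature actually supplied weights; otherwise `None` is passed, presumably leaving the default (unit) weighting to the preprocessing routine. The guard itself is plain Python:

    collected_weights = [None, None]
    feature_weights = (
        None if all(w is None for w in collected_weights) else collected_weights
    )
    # feature_weights is None: nothing was weighted.

    collected_weights = [None, [0.5, 2.0]]
    feature_weights = (
        None if all(w is None for w in collected_weights) else collected_weights
    )
    # feature_weights is the mixed list, including the remaining None entry.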
@@ -583,7 +201,10 @@ def stack_and_shard_samples(
      for table_name in tables_names:
          shard_ends = preprocessed_inputs.lhs_row_pointers[table_name]
          shard_starts = np.concatenate(
-             [np.asarray([0]), _round_up_to_multiple(shard_ends[:-1], 8)]
+             [
+                 np.asarray([0]),
+                 table_stacking._next_largest_multiple(shard_ends[:-1], 8),
+             ]
          )
          out[table_name] = ShardedCooMatrix(
              shard_starts=shard_starts,
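
The shard-start computation previously used the module-local `_round_up_to_multiple` and now calls `table_stacking._next_largest_multiple`, assumed here to be an element-wise equivalent for arrays. The underlying arithmetic is simply rounding up to the next multiple of 8:

    import numpy as np


    def round_up_to_multiple(value, multiple):
        # Same formula as the removed helper; works on scalars and numpy arrays.
        return ((value + multiple - 1) // multiple) * multiple


    shard_ends = np.array([3, 8, 17])
    print(round_up_to_multiple(shard_ends, 8))  # [ 8  8 24]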
@@ -53,7 +53,7 @@ OPTIMIZER_MAPPINGS = {
  # KerasRS to TensorFlow


- def translate_keras_rs_configuration(
+ def keras_to_tf_tpu_configuration(
      feature_configs: types.Nested[FeatureConfig],
      table_stacking: str | Sequence[str] | Sequence[Sequence[str]],
      num_replicas_in_sync: int,
@@ -66,14 +66,15 @@ def translate_keras_rs_configuration(
      Args:
          feature_configs: The nested Keras RS feature configs.
          table_stacking: The Keras RS table stacking.
+         num_replicas_in_sync: The number of replicas in sync from the strategy.

      Returns:
          A tuple containing the TensorFlow TPU feature configs and the TensorFlow
          TPU sparse core embedding config.
      """
-     tables: dict[TableConfig, tf.tpu.experimental.embedding.TableConfig] = {}
+     tables: dict[int, tf.tpu.experimental.embedding.TableConfig] = {}
      feature_configs = keras.tree.map_structure(
-         lambda f: translate_keras_rs_feature_config(
+         lambda f: keras_to_tf_tpu_feature_config(
              f, tables, num_replicas_in_sync
          ),
          feature_configs,
@@ -108,9 +109,9 @@ def translate_keras_rs_configuration(
      return feature_configs, sparse_core_embedding_config


- def translate_keras_rs_feature_config(
+ def keras_to_tf_tpu_feature_config(
      feature_config: FeatureConfig,
-     tables: dict[TableConfig, tf.tpu.experimental.embedding.TableConfig],
+     tables: dict[int, tf.tpu.experimental.embedding.TableConfig],
      num_replicas_in_sync: int,
  ) -> tf.tpu.experimental.embedding.FeatureConfig:
      """Translates a Keras RS feature config to a TensorFlow TPU feature config.
@@ -120,7 +121,8 @@ def translate_keras_rs_feature_config(

      Args:
          feature_config: The Keras RS feature config to translate.
-         tables: A mapping of KerasRS table configs to TF TPU table configs.
+         tables: A mapping of KerasRS table config ids to TF TPU table configs.
+         num_replicas_in_sync: The number of replicas in sync from the strategy.

      Returns:
          The TensorFlow TPU feature config.
@@ -131,10 +133,10 @@ def translate_keras_rs_feature_config(
              f"but got {num_replicas_in_sync}."
          )

-     table = tables.get(feature_config.table, None)
+     table = tables.get(id(feature_config.table), None)
      if table is None:
-         table = translate_keras_rs_table_config(feature_config.table)
-         tables[feature_config.table] = table
+         table = keras_to_tf_tpu_table_config(feature_config.table)
+         tables[id(feature_config.table)] = table

      if len(feature_config.output_shape) < 2:
          raise ValueError(
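
Caching by `id(feature_config.table)` deduplicates on object identity rather than on hashing the KerasRS `TableConfig`, so two features sharing one `TableConfig` instance still resolve to a single TF TPU table even if `TableConfig` is unhashable or compares by value. A minimal identity-cache sketch with placeholder classes:

    class KerasTable:  # placeholder for the keras_rs TableConfig
        def __init__(self, name: str) -> None:
            self.name = name


    class TfTable:  # placeholder for tf.tpu.experimental.embedding.TableConfig
        def __init__(self, name: str) -> None:
            self.name = name


    tables: dict[int, TfTable] = {}


    def get_or_create(table: KerasTable) -> TfTable:
        tf_table = tables.get(id(table), None)
        if tf_table is None:
            tf_table = TfTable(table.name)
            tables[id(table)] = tf_table
        return tf_table


    shared = KerasTable("movies")
    assert get_or_create(shared) is get_or_create(shared)  # one TF table per instance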
@@ -168,7 +170,7 @@ def translate_keras_rs_feature_config(
      )


- def translate_keras_rs_table_config(
+ def keras_to_tf_tpu_table_config(
      table_config: TableConfig,
  ) -> tf.tpu.experimental.embedding.TableConfig:
      initializer = table_config.initializer
@@ -179,13 +181,13 @@ def translate_keras_rs_table_config(
          vocabulary_size=table_config.vocabulary_size,
          dim=table_config.embedding_dim,
          initializer=initializer,
-         optimizer=translate_optimizer(table_config.optimizer),
+         optimizer=to_tf_tpu_optimizer(table_config.optimizer),
          combiner=table_config.combiner,
          name=table_config.name,
      )


- def translate_keras_optimizer(
+ def keras_to_tf_tpu_optimizer(
      optimizer: keras.optimizers.Optimizer,
  ) -> TfTpuOptimizer:
      """Translates a Keras optimizer to a TensorFlow TPU `_Optimizer`.
@@ -238,7 +240,12 @@ def translate_keras_optimizer(
              "Unsupported optimizer option `Optimizer.loss_scale_factor`."
          )

-     optimizer_mapping = OPTIMIZER_MAPPINGS.get(type(optimizer), None)
+     optimizer_mapping = None
+     for optimizer_class, mapping in OPTIMIZER_MAPPINGS.items():
+         # Handle subclasses of the main optimizer class.
+         if isinstance(optimizer, optimizer_class):
+             optimizer_mapping = mapping
+             break
      if optimizer_mapping is None:
          raise ValueError(
              f"Unsupported optimizer type {type(optimizer)}. Optimizer must be "
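
Resolving the mapping with `isinstance` instead of a `type(...)` dictionary lookup means user subclasses of the supported Keras optimizers are also translated. A stripped-down sketch of the lookup pattern (the mapping values below are placeholder labels, not the real `OPTIMIZER_MAPPINGS` entries):

    import keras

    # Placeholder mapping: Keras optimizer class -> descriptive label.
    SUPPORTED = {
        keras.optimizers.Adagrad: "adagrad",
        keras.optimizers.Adam: "adam",
        keras.optimizers.SGD: "sgd",
    }


    class WarmupAdam(keras.optimizers.Adam):
        """A user subclass; a plain type() lookup would miss it."""


    def resolve(optimizer: keras.optimizers.Optimizer) -> str | None:
        for optimizer_class, label in SUPPORTED.items():
            # isinstance also matches subclasses of the mapped optimizer class.
            if isinstance(optimizer, optimizer_class):
                return label
        return None


    assert resolve(WarmupAdam(learning_rate=0.01)) == "adam"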
@@ -258,7 +265,7 @@ def translate_keras_optimizer(
      return optimizer_mapping.tpu_optimizer_class(**tpu_optimizer_kwargs)


- def translate_optimizer(
+ def to_tf_tpu_optimizer(
      optimizer: str | keras.optimizers.Optimizer | TfTpuOptimizer | None,
  ) -> TfTpuOptimizer:
      """Translates a Keras optimizer into a TensorFlow TPU `_Optimizer`.
@@ -299,7 +306,7 @@ def translate_optimizer(
              "'sgd', 'adagrad', 'adam', or 'ftrl'"
          )
      elif isinstance(optimizer, keras.optimizers.Optimizer):
-         return translate_keras_optimizer(optimizer)
+         return keras_to_tf_tpu_optimizer(optimizer)
      else:
          raise ValueError(
              f"Unknown optimizer type {type(optimizer)}. Please pass an "
@@ -312,7 +319,7 @@ def translate_optimizer(
  # TensorFlow to TensorFlow


- def clone_tf_feature_configs(
+ def clone_tf_tpu_feature_configs(
      feature_configs: types.Nested[tf.tpu.experimental.embedding.FeatureConfig],
  ) -> types.Nested[tf.tpu.experimental.embedding.FeatureConfig]:
      """Clones and resolves TensorFlow TPU feature configs.
@@ -327,7 +334,7 @@ def clone_tf_feature_configs(
      """
      table_configs_dict = {}

-     def clone_and_resolve_tf_feature_config(
+     def clone_and_resolve_tf_tpu_feature_config(
          fc: tf.tpu.experimental.embedding.FeatureConfig,
      ) -> tf.tpu.experimental.embedding.FeatureConfig:
          if fc.table not in table_configs_dict:
@@ -336,7 +343,7 @@ def clone_tf_feature_configs(
                  vocabulary_size=fc.table.vocabulary_size,
                  dim=fc.table.dim,
                  initializer=fc.table.initializer,
-                 optimizer=translate_optimizer(fc.table.optimizer),
+                 optimizer=to_tf_tpu_optimizer(fc.table.optimizer),
                  combiner=fc.table.combiner,
                  name=fc.table.name,
                  quantization_config=fc.table.quantization_config,
@@ -352,5 +359,5 @@ def clone_tf_feature_configs(
          )

      return keras.tree.map_structure(
-         clone_and_resolve_tf_feature_config, feature_configs
+         clone_and_resolve_tf_tpu_feature_config, feature_configs
      )
@@ -106,7 +106,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
                  "for the configuration."
              )
          self._tpu_feature_configs, self._sparse_core_embedding_config = (
-             config_conversion.translate_keras_rs_configuration(
+             config_conversion.keras_to_tf_tpu_configuration(
                  feature_configs,
                  table_stacking,
                  strategy.num_replicas_in_sync,
@@ -135,10 +135,10 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
                  "supported with this TPU generation."
              )
          self._tpu_feature_configs = (
-             config_conversion.clone_tf_feature_configs(feature_configs)
+             config_conversion.clone_tf_tpu_feature_configs(feature_configs)
          )

-         self._tpu_optimizer = config_conversion.translate_optimizer(
+         self._tpu_optimizer = config_conversion.to_tf_tpu_optimizer(
              self._optimizer
          )

@@ -281,8 +281,18 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
      def _sparsecore_get_embedding_tables(self) -> dict[str, types.Tensor]:
          tables: dict[str, types.Tensor] = {}
          strategy = tf.distribute.get_strategy()
-         # 4 is the number of sparsecores per chip
-         num_shards = strategy.num_replicas_in_sync * 4
+         if not self._is_tpu_strategy(strategy):
+             raise RuntimeError(
+                 "`DistributedEmbedding.get_embedding_tables` needs to be "
+                 "called under the TPUStrategy that DistributedEmbedding was "
+                 f"created with, but is being called under strategy {strategy}. "
+                 "Please use `with strategy.scope()` when calling "
+                 "`get_embedding_tables`."
+             )
+
+         tpu_hardware = strategy.extended.tpu_hardware_feature
+         num_sc_per_device = tpu_hardware.num_embedding_devices_per_chip
+         num_shards = strategy.num_replicas_in_sync * num_sc_per_device

          def populate_table(
              feature_config: tf.tpu.experimental.embedding.FeatureConfig,
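
Rather than assuming 4 sparsecores per chip, the shard count is now taken from the strategy's reported TPU hardware feature. A hedged usage sketch (cluster setup is environment-specific; attribute names follow the diff above):

    import tensorflow as tf

    resolver = tf.distribute.cluster_resolver.TPUClusterResolver()  # environment-specific
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)

    with strategy.scope():
        tpu_hardware = strategy.extended.tpu_hardware_feature
        num_sc_per_device = tpu_hardware.num_embedding_devices_per_chip
        num_shards = strategy.num_replicas_in_sync * num_sc_per_device
        # On TPU generations without sparsecore this count may be 0, so callers
        # presumably need to guard before dividing work across shards.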