keras-rs-nightly 0.0.1.dev2025021903__py3-none-any.whl → 0.3.1.dev202512130338__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keras_rs/__init__.py +9 -28
- keras_rs/layers/__init__.py +37 -0
- keras_rs/losses/__init__.py +19 -0
- keras_rs/metrics/__init__.py +16 -0
- keras_rs/src/layers/embedding/base_distributed_embedding.py +1151 -0
- keras_rs/src/layers/embedding/distributed_embedding.py +33 -0
- keras_rs/src/layers/embedding/distributed_embedding_config.py +132 -0
- keras_rs/src/layers/embedding/embed_reduce.py +309 -0
- keras_rs/src/layers/embedding/jax/__init__.py +0 -0
- keras_rs/src/layers/embedding/jax/checkpoint_utils.py +104 -0
- keras_rs/src/layers/embedding/jax/config_conversion.py +468 -0
- keras_rs/src/layers/embedding/jax/distributed_embedding.py +829 -0
- keras_rs/src/layers/embedding/jax/embedding_lookup.py +276 -0
- keras_rs/src/layers/embedding/jax/embedding_utils.py +217 -0
- keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
- keras_rs/src/layers/embedding/tensorflow/config_conversion.py +363 -0
- keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +436 -0
- keras_rs/src/layers/feature_interaction/__init__.py +0 -0
- keras_rs/src/layers/{modeling → feature_interaction}/dot_interaction.py +116 -25
- keras_rs/src/layers/{modeling → feature_interaction}/feature_cross.py +40 -22
- keras_rs/src/layers/retrieval/brute_force_retrieval.py +16 -65
- keras_rs/src/layers/retrieval/hard_negative_mining.py +94 -0
- keras_rs/src/layers/retrieval/remove_accidental_hits.py +97 -0
- keras_rs/src/layers/retrieval/retrieval.py +127 -0
- keras_rs/src/layers/retrieval/sampling_probability_correction.py +63 -0
- keras_rs/src/losses/__init__.py +0 -0
- keras_rs/src/losses/list_mle_loss.py +212 -0
- keras_rs/src/losses/pairwise_hinge_loss.py +90 -0
- keras_rs/src/losses/pairwise_logistic_loss.py +99 -0
- keras_rs/src/losses/pairwise_loss.py +165 -0
- keras_rs/src/losses/pairwise_loss_utils.py +39 -0
- keras_rs/src/losses/pairwise_mean_squared_error.py +133 -0
- keras_rs/src/losses/pairwise_soft_zero_one_loss.py +98 -0
- keras_rs/src/metrics/__init__.py +0 -0
- keras_rs/src/metrics/dcg.py +161 -0
- keras_rs/src/metrics/mean_average_precision.py +130 -0
- keras_rs/src/metrics/mean_reciprocal_rank.py +121 -0
- keras_rs/src/metrics/ndcg.py +197 -0
- keras_rs/src/metrics/precision_at_k.py +117 -0
- keras_rs/src/metrics/ranking_metric.py +260 -0
- keras_rs/src/metrics/ranking_metrics_utils.py +257 -0
- keras_rs/src/metrics/recall_at_k.py +108 -0
- keras_rs/src/metrics/utils.py +70 -0
- keras_rs/src/types.py +43 -14
- keras_rs/src/utils/doc_string_utils.py +53 -0
- keras_rs/src/utils/keras_utils.py +52 -3
- keras_rs/src/utils/tpu_test_utils.py +120 -0
- keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/METADATA +88 -8
- keras_rs_nightly-0.3.1.dev202512130338.dist-info/RECORD +58 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/WHEEL +1 -1
- keras_rs/api/__init__.py +0 -9
- keras_rs/api/layers/__init__.py +0 -11
- keras_rs_nightly-0.0.1.dev2025021903.dist-info/RECORD +0 -19
- /keras_rs/src/layers/{modeling → embedding}/__init__.py +0 -0
- {keras_rs_nightly-0.0.1.dev2025021903.dist-info → keras_rs_nightly-0.3.1.dev202512130338.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,829 @@
|
|
|
1
|
+
"""JAX implementation of the TPU embedding layer."""
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
import typing
|
|
5
|
+
from typing import Any, Mapping, Sequence, Union
|
|
6
|
+
|
|
7
|
+
import jax
|
|
8
|
+
import keras
|
|
9
|
+
import numpy as np
|
|
10
|
+
from jax import numpy as jnp
|
|
11
|
+
from jax.experimental import layout as jax_layout
|
|
12
|
+
from jax.experimental import multihost_utils
|
|
13
|
+
from jax_tpu_embedding.sparsecore.lib.nn import embedding
|
|
14
|
+
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
|
|
15
|
+
from jax_tpu_embedding.sparsecore.lib.nn import (
|
|
16
|
+
table_stacking as jte_table_stacking,
|
|
17
|
+
)
|
|
18
|
+
from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
|
|
19
|
+
|
|
20
|
+
from keras_rs.src import types
|
|
21
|
+
from keras_rs.src.layers.embedding import base_distributed_embedding
|
|
22
|
+
from keras_rs.src.layers.embedding import distributed_embedding_config as config
|
|
23
|
+
from keras_rs.src.layers.embedding.jax import config_conversion
|
|
24
|
+
from keras_rs.src.layers.embedding.jax import (
|
|
25
|
+
embedding_lookup as jte_embedding_lookup,
|
|
26
|
+
)
|
|
27
|
+
from keras_rs.src.layers.embedding.jax import embedding_utils
|
|
28
|
+
from keras_rs.src.types import Nested
|
|
29
|
+
from keras_rs.src.utils import keras_utils
|
|
30
|
+
|
|
31
|
+
if jax.__version_info__ >= (0, 8, 0):
|
|
32
|
+
from jax import shard_map
|
|
33
|
+
else:
|
|
34
|
+
from jax.experimental.shard_map import shard_map # type: ignore[assignment]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# Table values accepted by `set_embedding_tables`: either a host-side NumPy
# array or a JAX array.
ArrayLike = typing.Union[np.ndarray[Any, Any], jax.Array]
# Short alias for the backend-agnostic feature configuration class.
FeatureConfig = config.FeatureConfig
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _get_partition_spec(
    layout: (
        keras.distribution.TensorLayout
        | jax_layout.Format
        | jax.sharding.NamedSharding
        | jax.sharding.PartitionSpec
    ),
) -> Any:
    """Unwraps a layout or sharding object down to its partition spec.

    Wrappers are peeled in a fixed order:
    `TensorLayout` -> `.backend_layout`, `Format` -> `.sharding`,
    `NamedSharding` -> `.spec`. Each step only applies if the current object
    is an instance of that wrapper type; anything else (e.g. an existing
    `PartitionSpec`) is passed through unchanged.

    Args:
        layout: A Keras `TensorLayout`, a JAX `Format`, a `NamedSharding`,
            or a `PartitionSpec`.

    Returns:
        The innermost object reached after unwrapping; for the expected
        inputs this should be a `jax.sharding.PartitionSpec`.
    """
    unwrap_steps = (
        (keras.distribution.TensorLayout, "backend_layout"),
        (jax_layout.Format, "sharding"),
        (jax.sharding.NamedSharding, "spec"),
    )
    current = layout
    for wrapper_type, inner_attr in unwrap_steps:
        if isinstance(current, wrapper_type):
            current = getattr(current, inner_attr)
    return current
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class ShardedInitializer(keras.initializers.Initializer):
    """Wraps an initializer so embedding-table values are created sharded.

    When a layout is given, the wrapped initializer is jit-compiled with
    `out_shardings` set to the layout's backend layout, allowing the initial
    value to be produced on device with that sharding. Without a layout the
    wrapped initializer is simply called.
    """

    def __init__(
        self,
        initializer: keras.initializers.Initializer | str,
        layout: keras.distribution.TensorLayout | None,
    ):
        """Stores the resolved initializer and the target layout.

        Args:
            initializer: A Keras initializer, or an identifier string that is
                resolved with `keras.initializers.get`.
            layout: Target layout of the produced array, or `None` to call
                the initializer without sharded compilation.
        """
        self._initializer = (
            keras.initializers.get(initializer)
            if isinstance(initializer, str)
            else initializer
        )
        self._layout = layout

    def __call__(
        self, shape: types.Shape, dtype: types.DType | None = None
    ) -> jax.Array:
        """Returns the initial value, sharded per the layout when one is set."""
        if self._layout is None:
            return self._initializer(shape, dtype)

        # `shape` and `dtype` describe the output structure, so they are
        # marked static rather than traced.
        jitted_initializer = jax.jit(
            self._initializer,
            out_shardings=self._layout.backend_layout,
            static_argnames=["shape", "dtype"],
        )
        result: jax.Array = jitted_initializer(shape, dtype)
        return result
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class StackedTableInitializer(keras.initializers.Initializer):
    """Initializes one stacked table by combining its tables' initializers.

    Every table in the stack keeps its own initializer. On each device, the
    device-local slice of the stacked table is built by calling each table's
    initializer for its per-shard slice and concatenating the pieces along
    the row axis. Tables are ordered by `row_offset_in_shard`.
    """

    def __init__(
        self,
        table_specs: Nested[embedding_spec.TableSpec],
        num_shards: int,
        layout: keras.distribution.TensorLayout,
        seed: int | keras.random.SeedGenerator | jax.Array = 0,
    ):
        """Captures the table specs, sharding info and base PRNG key.

        Args:
            table_specs: (Possibly nested) specs of the tables in one stack.
            num_shards: Total number of table shards across all devices.
            layout: Layout of the stacked-table variable; its first axis name
                is used to look up how many devices the table is split over.
            seed: Seed for the per-(shard, table) keys of tables whose
                initializer does not provide its own key.
        """

        def _row_offset(table_spec):
            return table_spec.setting_in_stack.row_offset_in_shard

        # Ordering by in-shard row offset lets per-table slices simply be
        # concatenated when assembling the stacked table.
        self._table_specs = sorted(
            keras.tree.flatten(table_specs), key=_row_offset
        )
        self._num_shards = num_shards
        self._layout = layout
        self._key = keras.src.backend.jax.random.jax_draw_seed(seed)

    def _initialize_shard(
        self,
        keys: jax.Array,
        shape: tuple[int, int],
        dtype: Any,
        num_shards_per_device: int,
    ) -> jax.Array:
        """Builds the device-local slice of the stacked table.

        Args:
            keys: Device-local PRNG keys, indexed as `keys[shard, table]`.
            shape: Unused; the slice shape follows from the table specs.
            dtype: Dtype forwarded to every table initializer.
            num_shards_per_device: Number of table shards held per device.

        Returns:
            The per-shard slices of every table, for every shard on this
            device (shard-major, then table order), concatenated on axis 0.
        """
        del shape  # Unused.
        # NOTE: padding, rotations in shard, and mod-sharding are ignored
        # here; this assumes all initializers are shard-independent.
        pieces: list[jax.Array] = [
            spec.initializer(
                keys[shard_idx, table_idx],
                (
                    spec.setting_in_stack.padded_vocab_size
                    // self._num_shards,
                    spec.setting_in_stack.padded_embedding_dim,
                ),
                dtype,
            )
            for shard_idx in range(num_shards_per_device)
            for table_idx, spec in enumerate(self._table_specs)
        ]
        return jnp.concatenate(pieces, axis=0)

    def __call__(
        self, shape: types.Shape, dtype: types.DType | None = None
    ) -> jax.Array:
        """Initializes the full stacked table with the layout's sharding.

        Per-(shard, table) keys are split from the base key. For tables whose
        initializer is a `config_conversion.WrappedKerasInitializer` with a
        non-`None` `key()`, that key (split once per shard) is used instead.
        The keys are distributed along the layout's first mesh axis with
        `shard_map`, and each device builds its slice via
        `_initialize_shard`.

        Args:
            shape: Must equal `(stack_vocab_size, stack_embedding_dim)` of
                the stacked table spec.
            dtype: Dtype forwarded to every table initializer.

        Returns:
            The stacked table, with `out_shardings` set to the layout's
            backend layout.
        """
        stack_spec = typing.cast(
            embedding_spec.StackedTableSpec,
            self._table_specs[0].stacked_table_spec,
        )

        # The requested shape is dictated by the stacked table spec.
        assert shape == (
            stack_spec.stack_vocab_size,
            stack_spec.stack_embedding_dim,
        )

        # Base keys, one per (shard, table); optionally overridden below.
        keys = jax.random.split(
            self._key, (self._num_shards, len(self._table_specs))
        )
        for column, spec in enumerate(self._table_specs):
            table_init = spec.initializer
            if not isinstance(
                table_init, config_conversion.WrappedKerasInitializer
            ):
                continue
            own_key = table_init.key()
            if own_key is None:
                continue
            keys = keys.at[:, column].set(
                jax.random.split(own_key, self._num_shards)
            )

        layout = self._layout
        backend_layout = layout.backend_layout
        backend_mesh = layout.device_mesh.backend_mesh
        devices_on_axis = backend_mesh.shape[layout.axes[0]]
        shards_per_device = self._num_shards // devices_on_axis
        per_device_shape = (
            stack_spec.stack_vocab_size // devices_on_axis,
            stack_spec.stack_embedding_dim,
        )
        partition_spec = _get_partition_spec(backend_layout)

        sharded_initializer = jax.jit(
            shard_map(
                lambda local_keys: self._initialize_shard(
                    local_keys, per_device_shape, dtype, shards_per_device
                ),
                mesh=backend_mesh,
                in_specs=partition_spec,
                out_specs=partition_spec,
            ),
            out_shardings=backend_layout,
        )
        result: jax.Array = sharded_initializer(keys)
        return result
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
|
|
199
|
+
"""JAX implementation of the TPU embedding layer."""
|
|
200
|
+
|
|
201
|
+
def _create_sparsecore_distribution(
    self, sparsecore_axis_name: str = "sparsecore"
) -> tuple[
    keras.distribution.ModelParallel, keras.distribution.TensorLayout
]:
    """Creates the distribution and tensor layout SparseCore requires.

    SparseCore requires a specific layout: the mesh must be 1D, must use all
    TPUs available, and must shard all tables across all devices.

    Args:
        sparsecore_axis_name: The name of the single (sparsecore) mesh axis.

    Returns:
        A `(distribution, layout)` tuple:
        - `distribution`: a `keras.distribution.ModelParallel` whose layout
          map assigns `layout` to the key `"<layer path>/var"` (the layer
          name is used when `self.path` is not yet set).
        - `layout`: a `keras.distribution.TensorLayout` sharded along the
          sparsecore axis, whose backend layout is replaced by a JAX
          `Format` combining a tiled device-local layout with a
          `NamedSharding` over that axis.
    """
    all_devices = jax.devices()
    axes = [sparsecore_axis_name]
    device_mesh = keras.distribution.DeviceMesh(
        (len(all_devices),), axes, all_devices
    )
    sparsecore_layout = keras.distribution.TensorLayout(axes, device_mesh)
    # Custom sparsecore layout with tiling. The layout class and its tiling
    # keyword differ between JAX versions, hence the version checks.
    LayoutClass = (
        jax_layout.Layout
        if jax.__version_info__ >= (0, 6, 3)
        else jax_layout.DeviceLocalLayout  # type: ignore
    )
    layout = (
        LayoutClass(major_to_minor=(0, 1), tiling=((8,),))  # type: ignore
        if jax.__version_info__ >= (0, 7, 1)
        else LayoutClass(major_to_minor=(0, 1), _tiling=((8,),))  # type: ignore
    )
    # Override the (private) backend layout so consumers of
    # `backend_layout` see the tiled layout plus the named sharding.
    # pylint: disable-next=protected-access
    sparsecore_layout._backend_layout = jax_layout.Format(
        layout,  # type: ignore
        jax.sharding.NamedSharding(
            device_mesh.backend_mesh,
            jax.sharding.PartitionSpec(
                axes  # type: ignore[no-untyped-call]
            ),
        ),
    )
    layout_map = keras.distribution.LayoutMap(device_mesh=device_mesh)
    path = self.path
    if path is None:
        # Layer hasn't been properly built yet. Use current layer name.
        path = self.name
    layout_map[path + "/var"] = sparsecore_layout
    sparsecore_distribution = keras.distribution.ModelParallel(
        layout_map=layout_map
    )
    return sparsecore_distribution, sparsecore_layout
|
|
254
|
+
|
|
255
|
+
def _add_sparsecore_weight(
    self,
    name: str,
    shape: tuple[int, int],
    initializer: (
        keras.initializers.Initializer | jax.nn.initializers.Initializer
    ),
    dtype: Any,
    overwrite_with_gradient: bool,
) -> keras.Variable:
    """Adds a SparseCore-managed weight to the layer.

    Args:
        name: Variable name.
        shape: Variable shape.
        initializer: Initializer forwarded to `add_weight`. Callers in this
            module pass Keras initializer wrappers such as
            `ShardedInitializer` or `StackedTableInitializer`, so the
            annotation accepts Keras initializers as well.
        dtype: Variable dtype.
        overwrite_with_gradient: Value stored on the variable's
            `overwrite_with_gradient` attribute; when true, Keras optimizers
            replace the variable with the value returned as its "gradient"
            instead of applying an update.

    Returns:
        The created variable.
    """
    var = self.add_weight(
        name=name, shape=shape, initializer=initializer, dtype=dtype
    )
    var.overwrite_with_gradient = overwrite_with_gradient
    return var
|
|
268
|
+
|
|
269
|
+
def _add_table_variable(
    self,
    table_specs: Sequence[embedding_spec.TableSpec],
    num_shards: int,
    add_slot_variables: bool,
) -> embedding.EmbeddingVariables:
    """Creates the variable of one stacked table and, optionally, its slots.

    Args:
        table_specs: Specs of all tables stacked into this variable; the
            stacked spec and optimizer are read from the first one.
        num_shards: Total number of table shards, forwarded to
            `StackedTableInitializer`.
        add_slot_variables: Whether to also create optimizer slot variables.

    Returns:
        `EmbeddingVariables(table_variable, slot_variables)`, where
        `slot_variables` is `None` when slots were not requested.
    """
    stack_spec = typing.cast(
        embedding_spec.StackedTableSpec, table_specs[0].stacked_table_spec
    )
    stack_optimizer = stack_spec.optimizer
    slot_count = stack_optimizer.slot_variables_count()
    full_shape = (
        stack_spec.stack_vocab_size,
        stack_spec.stack_embedding_dim,
    )
    layout = self._sparsecore_layout

    # Convert every table's JAX initializer into a Keras initializer.
    keras_initializers = [
        config_conversion.jax_to_keras_initializer(spec.initializer)
        for spec in table_specs
    ]
    first_initializer = keras_initializers[0]
    if all(init == first_initializer for init in keras_initializers):
        # A single shared initializer can fill the whole stacked table.
        table_initializer = ShardedInitializer(first_initializer, layout)
    else:
        # Mixed initializers: build each shard table by table.
        table_initializer = StackedTableInitializer(
            table_specs, num_shards, layout
        )

    table_name = f"var:{stack_spec.stack_name}:table"
    table_variable = self._add_sparsecore_weight(
        name=table_name,
        shape=full_shape,
        initializer=table_initializer,
        dtype="float32",
        overwrite_with_gradient=True,
    )

    if not add_slot_variables:
        return embedding.EmbeddingVariables(table_variable, None)

    # The optimizer is identical for every table in the stack, so each slot
    # uses one sharded initializer for the entire stacked table.
    slot_inits = stack_optimizer.slot_variables_initializers()
    # Use namedtuple-style field names when present, otherwise indices.
    slot_names = getattr(slot_inits, "_fields", range(slot_count))

    def _make_slot(slot_name, slot_init):
        return self._add_sparsecore_weight(
            name=f"{table_name}:slot:{slot_name}",
            shape=full_shape,
            initializer=ShardedInitializer(
                config_conversion.jax_to_keras_initializer(slot_init),
                layout,
            ),
            dtype=jnp.float32,
            overwrite_with_gradient=True,
        )

    created_slots = tuple(
        _make_slot(slot_name, slot_init)
        for slot_name, slot_init in zip(slot_names, slot_inits)
    )
    slot_variables = keras.tree.pack_sequence_as(slot_inits, created_slots)
    return embedding.EmbeddingVariables(table_variable, slot_variables)
|
|
344
|
+
|
|
345
|
+
@keras_utils.no_automatic_dependency_tracking
def _sparsecore_init(
    self,
    feature_configs: dict[str, FeatureConfig],
    table_stacking: str | Sequence[str] | Sequence[Sequence[str]],
) -> None:
    """Records the SparseCore feature configs; variables are created later.

    Tables whose config has no initializer get a default
    `TruncatedNormal(mean=0.0, stddev=1 / sqrt(embedding_dim))`. This
    assignment mutates the table configs that were passed in.

    Args:
        feature_configs: Feature configs placed on SparseCore.
        table_stacking: `None`, `"auto"`, or table names to stack. Stored
            as-is; it is validated and applied in `sparsecore_build`.

    Raises:
        ValueError: If no SparseCores are available.
    """
    if not self._has_sparsecore():
        raise ValueError(
            "No sparse cores available, cannot use explicit sparsecore"
            " placement."
        )

    self._sc_feature_configs = feature_configs
    self._sparsecore_built = False
    # Fill in any empty default settings.
    for feature_config in keras.tree.flatten(self._sc_feature_configs):
        if feature_config.table.initializer is None:
            table = feature_config.table
            table.initializer = keras.initializers.TruncatedNormal(
                mean=0.0, stddev=1.0 / math.sqrt(float(table.embedding_dim))
            )

    # Actual stacking of tables is done in build() to ensure the
    # distribution is set up correctly.
    self._table_stacking = table_stacking
|
|
370
|
+
|
|
371
|
+
def _sparsecore_build(
    self, input_shapes: Nested[types.Shape] | None = None
) -> None:
    """Build hook for SparseCore state.

    Args:
        input_shapes: Forwarded unchanged to `sparsecore_build`.
    """
    # All the work (and the "already built" short-circuit) lives in the
    # public method.
    build_sparsecore = self.sparsecore_build
    build_sparsecore(input_shapes)
|
|
375
|
+
|
|
376
|
+
@keras_utils.no_automatic_dependency_tracking
def sparsecore_build(
    self, input_shapes: Nested[types.Shape] | None = None
) -> None:
    """Stacks tables and creates all SparseCore state; no-op once built.

    Converts the Keras feature configs to JAX TPU embedding feature specs,
    creates the SparseCore distribution and layout, optionally stacks
    tables, seeds each stacked table's preprocessing limits from its member
    tables, creates the table/slot variables and the iteration counter, and
    builds the embedding lookup configuration.

    Args:
        input_shapes: Unused; shapes are dictated by the feature configs.

    Raises:
        ValueError: If `table_stacking` is not `None`, `"auto"`, a sequence
            (list or tuple) of table names, or a sequence of such sequences.
    """
    del input_shapes  # Unused.

    if self._sparsecore_built:
        return

    feature_specs = config_conversion.keras_to_jte_feature_configs(
        self._sc_feature_configs
    )

    # Distribution for sparsecore operations.
    sparsecore_distribution, sparsecore_layout = (
        self._create_sparsecore_distribution()
    )
    self._sparsecore_layout = sparsecore_layout
    self._sparsecore_distribution = sparsecore_distribution

    mesh = sparsecore_distribution.device_mesh.backend_mesh
    global_device_count = mesh.devices.size
    num_sc_per_device = jte_utils.num_sparsecores_per_device(
        mesh.devices.item(0)
    )
    # One table shard per global sparsecore.
    num_variable_shards = global_device_count * num_sc_per_device

    # Maybe stack tables.
    table_stacking = self._table_stacking
    if table_stacking is not None:
        unsupported_msg = (
            f"Unsupported table stacking {table_stacking}, "
            "must be None, 'auto', or sequences of table names "
            "to stack."
        )
        if isinstance(table_stacking, str):
            if table_stacking != "auto":
                raise ValueError(unsupported_msg)
            jte_table_stacking.auto_stack_tables(
                feature_specs, global_device_count, num_sc_per_device
            )
        elif isinstance(table_stacking, (list, tuple)):
            # Lists and tuples are both valid `Sequence`s; an empty sequence
            # means no explicit stacking.
            if table_stacking:
                first = table_stacking[0]
                if isinstance(first, (list, tuple)):
                    # Sequence of groups of table names: stack each group.
                    for table_names in table_stacking:
                        jte_table_stacking.stack_tables(
                            feature_specs,
                            list(table_names),
                            global_device_count,
                            num_sc_per_device,
                        )
                elif isinstance(first, str):
                    # Single group of table names.
                    jte_table_stacking.stack_tables(
                        feature_specs,
                        list(table_stacking),
                        global_device_count,
                        num_sc_per_device,
                    )
                else:
                    raise ValueError(unsupported_msg)
        else:
            raise ValueError(unsupported_msg)

    # Adjust any non-stacked tables to prepare for training.
    embedding.prepare_feature_specs_for_training(
        feature_specs, global_device_count, num_sc_per_device
    )

    # Collect all stacked tables.
    table_specs = embedding.get_table_specs(feature_specs)
    table_stacks = jte_table_stacking.get_table_stacks(table_specs)

    # Update stacked table stats to max of values across involved tables.
    max_ids_per_partition = {}
    max_unique_ids_per_partition = {}
    required_buffer_size_per_sc = {}
    id_drop_counters = {}
    for stack_name, stack in table_stacks.items():
        max_ids_per_partition[stack_name] = np.max(
            np.asarray(
                [s.max_ids_per_partition for s in stack], dtype=np.int32
            )
        )
        max_unique_ids_per_partition[stack_name] = np.max(
            np.asarray(
                [s.max_unique_ids_per_partition for s in stack],
                dtype=np.int32,
            )
        )

        # Only set the suggested buffer size if set on any individual table.
        valid_buffer_sizes = [
            s.suggested_coo_buffer_size_per_device
            for s in stack
            if s.suggested_coo_buffer_size_per_device is not None
        ]
        if valid_buffer_sizes:
            max_buffer_size_per_device = np.max(
                np.asarray(valid_buffer_sizes, dtype=np.int32)
            )
            # Table specs carry a per-device suggestion, but this stats
            # field is per SparseCore. Convert with a ceiling division, the
            # same conversion applied when merging stats during training
            # preprocessing. (Presumably `update_preprocessing_parameters`
            # scales it back up by `num_sc_per_device` — verify against
            # jax_tpu_embedding.)
            required_buffer_size_per_sc[stack_name] = (
                max_buffer_size_per_device + (num_sc_per_device - 1)
            ) // num_sc_per_device

        id_drop_counters[stack_name] = 0

    aggregated_stats = embedding.SparseDenseMatmulInputStats(
        max_ids_per_partition=max_ids_per_partition,
        max_unique_ids_per_partition=max_unique_ids_per_partition,
        required_buffer_size_per_sc=required_buffer_size_per_sc,
        id_drop_counters=id_drop_counters,
    )
    embedding.update_preprocessing_parameters(
        feature_specs,
        aggregated_stats,
        num_sc_per_device,
    )

    # Create variables for all stacked tables and slot variables.
    with sparsecore_distribution.scope():
        self._table_and_slot_variables = {
            table_name: self._add_table_variable(
                table_stack,
                add_slot_variables=self.trainable,
                num_shards=num_variable_shards,
            )
            for table_name, table_stack in table_stacks.items()
        }

        # Create a step-counter variable for use in custom table gradients.
        # This must be a floating-point type so we can get a real gradient
        # for it. It will automatically be updated with each application of
        # the optimizer, since the next iteration is returned in the
        # gradient.
        sharded_zero_initializer = ShardedInitializer(
            "zeros",
            keras.distribution.TensorLayout(
                [], sparsecore_layout.device_mesh
            ),
        )
        self._iterations = self.add_weight(
            shape=(),
            name="iteration",
            initializer=sharded_zero_initializer,
            dtype="float32",
            trainable=True,
        )
        self._iterations.overwrite_with_gradient = True

    self._config = jte_embedding_lookup.EmbeddingLookupConfiguration(
        feature_specs,
        mesh=mesh,
        table_partition=_get_partition_spec(sparsecore_layout),
        samples_partition=_get_partition_spec(sparsecore_layout),
        table_layout=sparsecore_layout.backend_layout,
    )

    self._sparsecore_built = True
|
|
536
|
+
|
|
537
|
+
def _sparsecore_symbolic_preprocess(
    self,
    inputs: dict[str, types.Tensor],
    weights: dict[str, types.Tensor] | None,
    training: bool = False,
) -> dict[str, dict[str, embedding_utils.ShardedCooMatrix]]:
    """Symbolic counterpart of `preprocess(...)` for `keras.Input`s.

    Makes it possible to build functional models such as:

    ```python
    inputs = keras.Input(shape=(None,), dtype="int32")
    weights = keras.Input(shape=(None,), dtype="float32")
    preprocessed_inputs = distributed_embedding.preprocess(inputs, weights)
    outputs = distributed_embedding(preprocessed_inputs)
    model = keras.Model(inputs=preprocessed_inputs, outputs=outputs)
    ```

    Args:
        inputs: SparseCore path->tensor input ID's tensors (ignored).
        weights: Optional SparseCore path->tensor input weights (ignored).
        training: Whether the layer is training or not (ignored).

    Returns:
        `{"inputs": {stack_name: ShardedCooMatrix}}` with one entry per
        stacked table; every component of each matrix is a fresh
        `keras.Input(shape=())`.
    """
    # The preprocessed structure is dictated by the stacked-table
    # configuration, so the actual arguments play no role.
    del inputs, weights, training

    def _symbolic_coo_matrix(
        stacked_table_spec: embedding_spec.StackedTableSpec,
    ) -> embedding_utils.ShardedCooMatrix:
        # Real component sizes depend on the hardware (devices, sparsecores),
        # the input statistics and buffer-size hints, but for input/output
        # specs they can be treated as dynamic 1D tensors.
        del stacked_table_spec
        # Marked as `Input`s because that is how they are used when
        # constructing a functional Keras model.
        return embedding_utils.ShardedCooMatrix(
            shard_starts=keras.Input(shape=(), dtype="int32"),
            shard_ends=keras.Input(shape=(), dtype="int32"),
            col_ids=keras.Input(shape=(), dtype="int32"),
            row_ids=keras.Input(shape=(), dtype="int32"),
            values=keras.Input(shape=(), dtype="float32"),
        )

    # One ShardedCooMatrix per stacked table.
    stacks = jte_table_stacking.get_table_stacks(
        embedding.get_table_specs(self._config.feature_specs)
    )
    spec_by_stack = {
        name: tables[0].stacked_table_spec for name, tables in stacks.items()
    }
    return {
        "inputs": keras.tree.map_structure(
            _symbolic_coo_matrix, spec_by_stack
        )
    }
|
|
601
|
+
|
|
602
|
+
def _sparsecore_preprocess(
    self,
    inputs: dict[str, types.Tensor],
    weights: dict[str, types.Tensor] | None,
    training: bool = False,
) -> dict[str, dict[str, embedding_utils.ShardedCooMatrix]]:
    """Turns raw SparseCore inputs into per-stack sharded COO matrices.

    Must run eagerly (not under `jax.jit`). Symbolic `keras.Input`s are
    delegated to `_sparsecore_symbolic_preprocess`. When `training` is true,
    the input statistics are all-gathered across processes and reduced with
    `np.max`; if any stacked table's max IDs, max unique IDs, or buffer size
    exceeds its current limit, the limits are raised through
    `embedding.update_preprocessing_parameters` and the samples are
    preprocessed again with the consistent limits.

    Args:
        inputs: SparseCore path->tensor input ID tensors.
        weights: Optional SparseCore path->tensor input weight tensors.
        training: Whether statistics should be synchronized and applied.

    Returns:
        `{"inputs": preprocessed}`, where `preprocessed` comes from
        `embedding_utils.stack_and_shard_samples`.

    Raises:
        ValueError: If any input is a JAX tracer (jit-compilation).
    """
    flat_inputs = keras.tree.flatten(inputs)
    if any(isinstance(x, jax.core.Tracer) for x in flat_inputs):
        raise ValueError(
            "DistributedEmbedding.preprocess(...) does not support"
            " jit-compilation"
        )

    if not self._sparsecore_built:
        self._sparsecore_build()

    # Support symbolic KerasTensors (i.e. keras.Input).
    if any(isinstance(x, keras.KerasTensor) for x in flat_inputs):
        return self._sparsecore_symbolic_preprocess(
            inputs, weights, training
        )

    samples = embedding_utils.create_feature_samples(
        self._config.feature_specs, inputs, weights
    )

    backend_mesh = self._sparsecore_layout.device_mesh.backend_mesh
    global_device_count = backend_mesh.devices.size
    local_device_count = backend_mesh.local_mesh.devices.size
    num_sc_per_device = jte_utils.num_sparsecores_per_device(
        backend_mesh.devices.item(0)
    )

    def _stack_and_shard():
        # Reads the feature specs afresh so updated limits are picked up.
        return embedding_utils.stack_and_shard_samples(
            self._config.feature_specs,
            samples,
            local_device_count,
            global_device_count,
            num_sc_per_device,
        )

    preprocessed, stats = _stack_and_shard()

    if not training:
        return {"inputs": preprocessed}

    # Gather stats across all processes/devices and take the maximum.
    all_stats = jax.tree.map(
        np.max, multihost_utils.process_allgather(stats)
    )
    stacked_specs = embedding.get_stacked_table_specs(
        self._config.feature_specs
    )

    def _exceeds_limits(stack_name, spec) -> bool:
        return (
            all_stats.max_ids_per_partition[stack_name]
            > spec.max_ids_per_partition
            or all_stats.max_unique_ids_per_partition[stack_name]
            > spec.max_unique_ids_per_partition
            or all_stats.required_buffer_size_per_sc[stack_name]
            * num_sc_per_device
            > (spec.suggested_coo_buffer_size_per_device or 0)
        )

    if any(
        _exceeds_limits(stack_name, spec)
        for stack_name, spec in stacked_specs.items()
    ):
        for stack_name, spec in stacked_specs.items():
            # Per-device suggestion expressed per SparseCore (rounded up).
            suggested_per_sc = (
                (spec.suggested_coo_buffer_size_per_device or 0)
                + (num_sc_per_device - 1)
            ) // num_sc_per_device
            all_stats.max_ids_per_partition[stack_name] = np.max(
                [
                    all_stats.max_ids_per_partition[stack_name],
                    spec.max_ids_per_partition,
                ]
            )
            all_stats.max_unique_ids_per_partition[stack_name] = np.max(
                [
                    all_stats.max_unique_ids_per_partition[stack_name],
                    spec.max_unique_ids_per_partition,
                ]
            )
            all_stats.required_buffer_size_per_sc[stack_name] = np.max(
                [
                    all_stats.required_buffer_size_per_sc[stack_name],
                    suggested_per_sc,
                ]
            )

        embedding.update_preprocessing_parameters(
            self._config.feature_specs, all_stats, num_sc_per_device
        )

        # Re-execute preprocessing with consistent input statistics.
        preprocessed, _ = _stack_and_shard()

    return {"inputs": preprocessed}
|
|
710
|
+
|
|
711
|
+
def _sparsecore_call(
    self,
    inputs: dict[str, types.Tensor],
    weights: dict[str, types.Tensor] | None = None,
    training: bool = False,
    **kwargs: Any,
) -> dict[str, types.Tensor]:
    """Runs the jitted SparseCore embedding lookup.

    Lazily builds the SparseCore state, reads the current values of the
    table/slot variables, and invokes `embedding_lookup` under the
    SparseCore distribution scope.

    Args:
        inputs: SparseCore inputs, forwarded unchanged to the lookup.
        weights: Must be `None`; this path does not take per-input weights.
        training: Not used by this method.
        **kwargs: Not used by this method.

    Returns:
        Whatever `jte_embedding_lookup.embedding_lookup` returns, typed
        here as a mapping of name to tensor.
    """
    assert weights is None

    if not self._sparsecore_built:
        self._sparsecore_build()

    # Pass plain array values (not Keras variables) into the jitted call.
    current_values = keras.tree.map_structure(
        lambda variable: variable.value, self._table_and_slot_variables
    )
    with self._sparsecore_distribution.scope():
        # `config` is static so it participates in jit's trace cache key
        # rather than being traced as an array argument.
        jitted_lookup = jax.jit(
            jte_embedding_lookup.embedding_lookup, static_argnames="config"
        )
        return jitted_lookup(
            self._config, inputs, current_values, self._iterations.value
        )
|
|
734
|
+
|
|
735
|
+
def set_embedding_tables(self, tables: Mapping[str, ArrayLike]) -> None:
    """Sets the embedding tables to specific (unsharded) values.

    The same `tables` mapping is handed to the default-device setter and
    then to the SparseCore setter, each only when that placement appears
    in `self._placement_to_path_to_feature_config`.

    Args:
        tables: Mapping of table name -> table values.
    """
    placement_handlers = (
        ("default_device", self._default_device_set_tables),
        ("sparsecore", self._sparsecore_set_tables),
    )
    for placement, set_tables_for_placement in placement_handlers:
        if placement in self._placement_to_path_to_feature_config:
            set_tables_for_placement(tables)
|
|
746
|
+
|
|
747
|
+
def _default_device_set_tables(
    self, tables: Mapping[str, ArrayLike]
) -> None:
    """Assigns unsharded values to the default-device embedding tables.

    Only tables whose name maps to a non-`None` value in `tables` are
    updated; all other tables are left untouched. Every selected table is
    validated before any assignment happens, so a LoRA-enabled layer can
    no longer leave earlier tables half-updated.

    Args:
        tables: Mapping of table name -> table values.

    Raises:
        ValueError: If the layer has not been built, or if a table to be
            set belongs to an embedding layer with LoRA enabled.
    """
    if not self.built:
        raise ValueError("Layer must first be built before setting tables.")

    if "default_device" not in self._placement_to_path_to_feature_config:
        return

    # If several features reference the same table name, the layer of the
    # last such feature (in iteration order) is kept.
    table_name_to_embedding_layer = {
        feature_config.table.name: self._default_device_embedding_layers[path]
        for path, feature_config in self._placement_to_path_to_feature_config[
            "default_device"
        ].items()
    }

    # Validation pass: collect (layer, value) pairs and fail before any
    # variable is mutated.
    pending_assignments = []
    for table_name, embedding_layer in table_name_to_embedding_layer.items():
        table_values = tables.get(table_name, None)
        if table_values is None:
            continue
        if embedding_layer.lora_enabled:
            raise ValueError("Cannot set table if LoRA is enabled.")
        pending_assignments.append((embedding_layer, table_values))

    # Assignment pass.
    for embedding_layer, table_values in pending_assignments:
        # pylint: disable-next=protected-access
        embedding_layer._embeddings.assign(table_values)
|
|
775
|
+
|
|
776
|
+
def _sparsecore_set_tables(self, tables: Mapping[str, ArrayLike]) -> None:
    """Shards unsharded tables and writes them into the SparseCore variables.

    Builds the SparseCore state on first use. The tables are stacked and
    sharded across `mesh.devices.size * num_sc_per_device` shards, then
    placed on device with the SparseCore backend layout. Each stacked table
    variable is then assigned its matching device array.

    Args:
        tables: Mapping of table name -> unsharded table values.
    """
    if not self._sparsecore_built:
        self._sparsecore_build()

    cfg = self._config
    shard_count = cfg.mesh.devices.size * cfg.num_sc_per_device
    stacked_and_sharded = jte_table_stacking.stack_and_shard_tables(
        embedding.get_table_specs(cfg.feature_specs),
        tables,
        shard_count,
    )

    def _collapse_leading_dims(table):
        # Flatten the shard dimension into rows so that auto-sharding under
        # the backend layout can split the array itself.
        return table.reshape((-1, table.shape[-1]))

    on_device = jax.device_put(
        jax.tree.map(_collapse_leading_dims, stacked_and_sharded),
        self._sparsecore_layout.backend_layout,
    )

    def _assign_table(embedding_variables, table_value):
        return embedding_variables.table.assign(table_value)

    # Map only down to the structure of `on_device`, pairing each entry of
    # the table/slot variables with its device array; only `.table` is set.
    keras.tree.map_structure_up_to(
        on_device,
        _assign_table,
        self._table_and_slot_variables,
        on_device,
    )
|
|
807
|
+
|
|
808
|
+
def _sparsecore_get_embedding_tables(self) -> dict[str, ArrayLike]:
    """Returns the SparseCore tables as unsharded, unstacked host arrays.

    Builds the SparseCore state on first use. Each stacked table variable
    (not the optimizer slot variables) is copied to host, and the stacking
    and sharding over `mesh.devices.size * num_sc_per_device` shards is
    then undone.

    Returns:
        Mapping of table name -> unsharded table values.
    """
    if not self._sparsecore_built:
        # Same private builder the other SparseCore entry points call; the
        # previous `self.sparsecore_build()` referenced a different name.
        self._sparsecore_build()

    config = self._config
    num_table_shards = config.mesh.devices.size * config.num_sc_per_device
    table_specs = embedding.get_table_specs(config.feature_specs)

    # Extract only the table variables, not the gradient slot variables.
    table_variables = {
        name: jax.device_get(embedding_variables.table.value)
        for name, embedding_variables in (
            self._table_and_slot_variables.items()
        )
    }

    return typing.cast(
        dict[str, ArrayLike],
        jte_table_stacking.unshard_and_unstack_tables(
            table_specs, table_variables, num_table_shards
        ),
    )
|