keras-rs-nightly 0.3.1.dev202510170329.tar.gz → 0.3.1.dev202510180323.tar.gz
This diff shows the changes between two publicly released versions of this package, as they appear in the public registry. It is provided for informational purposes only.
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/PKG-INFO +1 -1
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/layers/embedding/jax/distributed_embedding.py +4 -4
- keras_rs_nightly-0.3.1.dev202510180323/keras_rs/src/layers/embedding/jax/embedding_utils.py +244 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/version.py +1 -1
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs_nightly.egg-info/PKG-INFO +1 -1
- keras_rs_nightly-0.3.1.dev202510170329/keras_rs/src/layers/embedding/jax/embedding_utils.py +0 -535
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/README.md +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/api/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/api/layers/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/api/losses/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/api/metrics/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/api_export.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/layers/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/layers/embedding/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/layers/embedding/base_distributed_embedding.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/layers/embedding/distributed_embedding.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/layers/embedding/distributed_embedding_config.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/layers/embedding/embed_reduce.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/layers/embedding/jax/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/layers/embedding/jax/checkpoint_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/layers/embedding/jax/config_conversion.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/layers/embedding/jax/embedding_lookup.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/layers/embedding/tensorflow/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/layers/embedding/tensorflow/config_conversion.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/layers/feature_interaction/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/layers/feature_interaction/dot_interaction.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/layers/feature_interaction/feature_cross.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/layers/retrieval/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/layers/retrieval/brute_force_retrieval.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/layers/retrieval/hard_negative_mining.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/layers/retrieval/remove_accidental_hits.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/layers/retrieval/retrieval.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/layers/retrieval/sampling_probability_correction.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/losses/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/losses/pairwise_hinge_loss.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/losses/pairwise_logistic_loss.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/losses/pairwise_loss.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/losses/pairwise_loss_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/losses/pairwise_mean_squared_error.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/losses/pairwise_soft_zero_one_loss.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/metrics/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/metrics/dcg.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/metrics/mean_average_precision.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/metrics/mean_reciprocal_rank.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/metrics/ndcg.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/metrics/precision_at_k.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/metrics/ranking_metric.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/metrics/ranking_metrics_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/metrics/recall_at_k.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/metrics/utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/types.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/utils/__init__.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/utils/doc_string_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs/src/utils/keras_utils.py +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs_nightly.egg-info/SOURCES.txt +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs_nightly.egg-info/dependency_links.txt +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs_nightly.egg-info/requires.txt +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/keras_rs_nightly.egg-info/top_level.txt +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/pyproject.toml +0 -0
- {keras_rs_nightly-0.3.1.dev202510170329 → keras_rs_nightly-0.3.1.dev202510180323}/setup.cfg +0 -0

keras_rs/src/layers/embedding/jax/distributed_embedding.py

@@ -442,7 +442,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
 
         # Collect all stacked tables.
         table_specs = embedding.get_table_specs(feature_specs)
-        table_stacks =
+        table_stacks = jte_table_stacking.get_table_stacks(table_specs)
 
         # Create variables for all stacked tables and slot variables.
         with sparsecore_distribution.scope():
@@ -516,7 +516,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
 
         # Each stacked-table gets a ShardedCooMatrix.
         table_specs = embedding.get_table_specs(self._config.feature_specs)
-        table_stacks =
+        table_stacks = jte_table_stacking.get_table_stacks(table_specs)
         stacked_table_specs = {
             stack_name: stack[0].stacked_table_spec
             for stack_name, stack in table_stacks.items()
@@ -720,7 +720,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
         config = self._config
         num_table_shards = config.mesh.devices.size * config.num_sc_per_device
         table_specs = embedding.get_table_specs(config.feature_specs)
-        sharded_tables =
+        sharded_tables = jte_table_stacking.stack_and_shard_tables(
             table_specs,
             tables,
             num_table_shards,
@@ -763,7 +763,7 @@ class DistributedEmbedding(base_distributed_embedding.DistributedEmbedding):
 
         return typing.cast(
             dict[str, ArrayLike],
-
+            jte_table_stacking.unshard_and_unstack_tables(
                 table_specs, table_variables, num_table_shards
             ),
         )
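
All four hunks above are the same mechanical change: the table-stacking helpers that the previous nightly pulled from the package-local embedding_utils module are now taken from the jax_tpu_embedding package instead (the removed right-hand sides are truncated in this diff view). The import hunk is not shown, so the module behind the `jte_table_stacking` alias is an assumption; a minimal sketch of the new call pattern, assuming the alias points at the upstream sparsecore table-stacking module:

    # Sketch only: `jte_table_stacking` is assumed to alias the upstream
    # table-stacking module; the diff does not show the import block.
    from jax_tpu_embedding.sparsecore.lib.nn import embedding
    from jax_tpu_embedding.sparsecore.lib.nn import table_stacking as jte_table_stacking

    def collect_table_stacks(feature_specs):
        # Group TableSpecs by the stack they belong to, as the new line 445 does.
        table_specs = embedding.get_table_specs(feature_specs)
        return jte_table_stacking.get_table_stacks(table_specs)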

keras_rs_nightly-0.3.1.dev202510180323/keras_rs/src/layers/embedding/jax/embedding_utils.py

@@ -0,0 +1,244 @@
+"""Utility functions for manipulating JAX embedding tables and inputs."""
+
+import collections
+from typing import Any, Mapping, NamedTuple, Sequence, TypeAlias, TypeVar
+
+import jax
+import numpy as np
+from jax_tpu_embedding.sparsecore.lib.nn import embedding
+from jax_tpu_embedding.sparsecore.lib.nn import table_stacking
+from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import FeatureSpec
+
+from keras_rs.src.types import Nested
+
+T = TypeVar("T")
+
+# Any to support tf.Ragged without needing an explicit TF dependency.
+ArrayLike: TypeAlias = jax.Array | np.ndarray | Any  # type: ignore
+Shape: TypeAlias = tuple[int, ...]
+
+
+class FeatureSamples(NamedTuple):
+    tokens: ArrayLike
+    weights: ArrayLike
+
+
+class ShardedCooMatrix(NamedTuple):
+    shard_starts: ArrayLike
+    shard_ends: ArrayLike
+    col_ids: ArrayLike
+    row_ids: ArrayLike
+    values: ArrayLike
+
+
+def convert_to_numpy(
+    ragged_or_dense: np.ndarray[Any, Any] | Sequence[Sequence[Any]] | Any,
+    dtype: Any,
+) -> np.ndarray[Any, Any]:
+    """Converts a ragged or dense list of inputs to a ragged/dense numpy array.
+
+    The output is adjusted to be 2D.
+
+    Args:
+        ragged_or_dense: Input that is either already a numpy array, or nested
+            sequence.
+        dtype: Numpy dtype of output array.
+
+    Returns:
+        Corresponding numpy array.
+    """
+    if hasattr(ragged_or_dense, "numpy"):
+        # Support tf.RaggedTensor and other TF input dtypes.
+        if callable(getattr(ragged_or_dense, "numpy")):
+            ragged_or_dense = ragged_or_dense.numpy()
+
+    if isinstance(ragged_or_dense, jax.Array):
+        ragged_or_dense = np.asarray(ragged_or_dense)
+
+    if isinstance(ragged_or_dense, np.ndarray):
+        # Convert 1D to 2D.
+        if ragged_or_dense.dtype != np.ndarray and ragged_or_dense.ndim == 1:
+            return ragged_or_dense.reshape(-1, 1).astype(dtype)
+
+        # If dense, return converted dense type.
+        if ragged_or_dense.dtype != np.ndarray:
+            return ragged_or_dense.astype(dtype)
+
+        # Ragged numpy array.
+        return ragged_or_dense
+
+    # Handle 1D sequence input.
+    if not isinstance(ragged_or_dense[0], collections.abc.Sequence):
+        return np.asarray(ragged_or_dense, dtype=dtype).reshape(-1, 1)
+
+    # Assemble elements into an nd-array.
+    counts = [len(vals) for vals in ragged_or_dense]
+    if all([count == counts[0] for count in counts]):
+        # Dense input.
+        return np.asarray(ragged_or_dense, dtype=dtype)
+    else:
+        # Ragged input, convert to ragged numpy arrays.
+        return np.array(
+            [np.array(row, dtype=dtype) for row in ragged_or_dense],
+            dtype=np.ndarray,
+        )
+
+
+def ones_like(
+    ragged_or_dense: np.ndarray[Any, Any], dtype: Any = None
+) -> np.ndarray[Any, Any]:
+    """Creates an array of ones the same as as the input.
+
+    This differs from traditional numpy in that a ragged input will lead to
+    a resulting ragged array of ones, whereas np.ones_like(...) will instead
+    only consider the outer array and return a 1D dense array of ones.
+
+    Args:
+        ragged_or_dense: The ragged or dense input whose shape and data-type
+            define these same attributes of the returned array.
+        dtype: The data-type of the returned array.
+
+    Returns:
+        An array of ones with the same shape as the input, and specified data
+        type.
+    """
+    dtype = dtype or ragged_or_dense.dtype
+    if ragged_or_dense.dtype == np.ndarray:
+        # Ragged.
+        return np.array(
+            [np.ones_like(row, dtype=dtype) for row in ragged_or_dense],
+            dtype=np.ndarray,
+        )
+    else:
+        # Dense.
+        return np.ones_like(ragged_or_dense, dtype=dtype)
+
+
+def create_feature_samples(
+    feature_structure: Nested[T],
+    feature_ids: Nested[ArrayLike | Sequence[int] | Sequence[Sequence[int]]],
+    feature_weights: None
+    | (Nested[ArrayLike | Sequence[float] | Sequence[Sequence[float]]]),
+) -> Nested[FeatureSamples]:
+    """Constructs a collection of sample tuples from provided IDs and weights.
+
+    Args:
+        feature_structure: The nested structure of the inputs (typically
+            `FeatureSpec`s).
+        feature_ids: The feature IDs to use for the samples.
+        feature_weights: The feature weights to use for the samples. Defaults
+            to ones if not provided.
+
+    Returns:
+        A nested collection of `FeatureSamples` corresponding to the input IDs
+        and weights, for use in embedding lookups.
+    """
+    # Create numpy arrays from inputs.
+    feature_ids = jax.tree.map(
+        lambda _, ids: convert_to_numpy(ids, np.int32),
+        feature_structure,
+        feature_ids,
+    )
+
+    if feature_weights is None:
+        # Make ragged or dense ones_like.
+        feature_weights = jax.tree.map(
+            lambda _, ids: ones_like(ids, np.float32),
+            feature_structure,
+            feature_ids,
+        )
+    else:
+        feature_weights = jax.tree.map(
+            lambda _, wgts: convert_to_numpy(wgts, np.float32),
+            feature_structure,
+            feature_weights,
+        )
+
+    # Assemble.
+    def _create_feature_samples(
+        sample_ids: np.ndarray[Any, Any],
+        sample_weights: np.ndarray[Any, Any],
+    ) -> FeatureSamples:
+        return FeatureSamples(sample_ids, sample_weights)
+
+    output: Nested[FeatureSamples] = jax.tree.map(
+        lambda _, sample_ids, sample_weights: _create_feature_samples(
+            sample_ids, sample_weights
+        ),
+        feature_structure,
+        feature_ids,
+        feature_weights,
+    )
+    return output
+
+
+def stack_and_shard_samples(
+    feature_specs: Nested[FeatureSpec],
+    feature_samples: Nested[FeatureSamples],
+    local_device_count: int,
+    global_device_count: int,
+    num_sc_per_device: int,
+    static_buffer_size: int | Mapping[str, int] | None = None,
+) -> tuple[dict[str, ShardedCooMatrix], embedding.SparseDenseMatmulInputStats]:
+    """Prepares input samples for use in embedding lookups.
+
+    Args:
+        feature_specs: Nested collection of feature specifications.
+        feature_samples: Nested collection of feature samples.
+        local_device_count: Number of local JAX devices.
+        global_device_count: Number of global JAX devices.
+        num_sc_per_device: Number of sparsecores per device.
+        static_buffer_size: The static buffer size to use for the samples.
+            Defaults to None, in which case an upper-bound for the buffer size
+            will be automatically determined.
+
+    Returns:
+        The preprocessed inputs, and statistics useful for updating FeatureSpecs
+        based on the provided input data.
+    """
+    del static_buffer_size  # Currently ignored.
+    flat_feature_specs, _ = jax.tree.flatten(feature_specs)
+
+    feature_tokens = []
+    feature_weights = []
+
+    def collect_tokens_and_weights(
+        feature_spec: FeatureSpec, samples: FeatureSamples
+    ) -> None:
+        del feature_spec
+        feature_tokens.append(samples.tokens)
+        feature_weights.append(samples.weights)
+
+    jax.tree.map(collect_tokens_and_weights, feature_specs, feature_samples)
+
+    preprocessed_inputs, stats = embedding.preprocess_sparse_dense_matmul_input(
+        feature_tokens,
+        feature_weights,
+        flat_feature_specs,
+        local_device_count=local_device_count,
+        global_device_count=global_device_count,
+        num_sc_per_device=num_sc_per_device,
+        sharding_strategy="MOD",
+        has_leading_dimension=False,
+        allow_id_dropping=True,
+    )
+
+    out: dict[str, ShardedCooMatrix] = {}
+    tables_names = preprocessed_inputs.lhs_row_pointers.keys()
+    for table_name in tables_names:
+        shard_ends = preprocessed_inputs.lhs_row_pointers[table_name]
+        shard_starts = np.concatenate(
+            [
+                np.asarray([0]),
+                table_stacking._next_largest_multiple(shard_ends[:-1], 8),
+            ]
+        )
+        out[table_name] = ShardedCooMatrix(
+            shard_starts=shard_starts,
+            shard_ends=shard_ends,
+            col_ids=preprocessed_inputs.lhs_embedding_ids[table_name],
+            row_ids=preprocessed_inputs.lhs_sample_ids[table_name],
+            values=preprocessed_inputs.lhs_gains[table_name],
+        )
+
+    return out, stats
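
The replacement embedding_utils.py above keeps only the input-side helpers; the one behavioral change relative to the copies removed below is that shard starts are now rounded up with the upstream `table_stacking._next_largest_multiple` instead of the deleted local `_round_up_to_multiple`. A usage sketch of the ragged-input handling, with made-up inputs (the string leaf stands in for a real FeatureSpec):

    import numpy as np

    from keras_rs.src.layers.embedding.jax import embedding_utils

    # Dense input: equal row lengths produce a plain 2D int32 array.
    dense = embedding_utils.convert_to_numpy([[1, 2], [3, 4]], np.int32)
    assert dense.shape == (2, 2)

    # Ragged input: unequal row lengths produce an object-dtype array of rows.
    ragged = embedding_utils.convert_to_numpy([[1, 2, 3], [4]], np.int32)
    assert ragged.dtype == np.dtype(object)

    # ones_like preserves raggedness, unlike np.ones_like on an object array.
    weights = embedding_utils.ones_like(ragged, np.float32)

    # create_feature_samples pairs IDs with weights (ones by default) across a
    # nested structure; "spec" is a placeholder leaf for a real FeatureSpec.
    samples = embedding_utils.create_feature_samples(
        {"movie_id": "spec"}, {"movie_id": [[1, 2, 3], [4]]}, None
    )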

keras_rs_nightly-0.3.1.dev202510170329/keras_rs/src/layers/embedding/jax/embedding_utils.py

@@ -1,535 +0,0 @@
-"""Utility functions for manipulating JAX embedding tables and inputs."""
-
-import collections
-import typing
-from typing import Any, Mapping, NamedTuple, Sequence, TypeAlias, TypeVar
-
-import jax
-import numpy as np
-from jax import numpy as jnp
-from jax_tpu_embedding.sparsecore.lib.nn import embedding
-from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import FeatureSpec
-from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import StackedTableSpec
-from jax_tpu_embedding.sparsecore.lib.nn.embedding_spec import TableSpec
-
-from keras_rs.src.types import Nested
-
-T = TypeVar("T")
-
-# Any to support tf.Ragged without needing an explicit TF dependency.
-ArrayLike: TypeAlias = jax.Array | np.ndarray | Any  # type: ignore
-Shape: TypeAlias = tuple[int, ...]
-
-
-class FeatureSamples(NamedTuple):
-    tokens: ArrayLike
-    weights: ArrayLike
-
-
-class ShardedCooMatrix(NamedTuple):
-    shard_starts: ArrayLike
-    shard_ends: ArrayLike
-    col_ids: ArrayLike
-    row_ids: ArrayLike
-    values: ArrayLike
-
-
-def _round_up_to_multiple(value: int, multiple: int) -> int:
-    return ((value + multiple - 1) // multiple) * multiple
-
-
-def _default_stacked_table_spec(
-    table_spec: TableSpec, num_shards: int, batch_size: int
-) -> StackedTableSpec:
-    return StackedTableSpec(
-        stack_name=table_spec.name,
-        stack_vocab_size=_round_up_to_multiple(
-            table_spec.vocabulary_size, 8 * num_shards
-        ),
-        stack_embedding_dim=_round_up_to_multiple(table_spec.embedding_dim, 8),
-        optimizer=table_spec.optimizer,
-        combiner=table_spec.combiner,
-        total_sample_count=batch_size,
-        max_ids_per_partition=table_spec.max_ids_per_partition,
-        max_unique_ids_per_partition=table_spec.max_unique_ids_per_partition,
-    )
-
-
-def _get_stacked_table_spec(
-    table_spec: TableSpec, num_shards: int, batch_size: int = 0
-) -> StackedTableSpec:
-    return table_spec.stacked_table_spec or _default_stacked_table_spec(
-        table_spec, num_shards, batch_size
-    )
-
-
-def pad_table(
-    table_spec: TableSpec,
-    table_values: jax.Array,
-    num_shards: int,
-    pad_value: jnp.float32 = jnp.nan,
-) -> jax.Array:
-    """Adds appropriate padding to a table to prepare for stacking.
-
-    Args:
-        table_spec: Table specification describing the table to pad.
-        table_values: Table values array to pad.
-        num_shards: Number of shards in the table (typically
-            `global_device_count * num_sc_per_device`).
-        pad_value: Value to use for padding.
-
-    Returns:
-        Padded table values.
-    """
-    vocabulary_size = table_spec.vocabulary_size
-    embedding_dim = table_spec.embedding_dim
-    padded_vocabulary_size = _round_up_to_multiple(
-        vocabulary_size, 8 * num_shards
-    )
-    stack_embedding_dim = _get_stacked_table_spec(
-        table_spec, num_shards
-    ).stack_embedding_dim
-    return jnp.pad(
-        table_values,
-        (
-            (0, padded_vocabulary_size - vocabulary_size),
-            (0, stack_embedding_dim - embedding_dim),
-        ),
-        constant_values=pad_value,
-    )
-
-
-def _stack_and_shard_table(
-    stacked_table: jax.Array,
-    table_spec: TableSpec,
-    table: jax.Array,
-    num_shards: int,
-    pad_value: jnp.float32,
-) -> jax.Array:
-    """Stacks and shards a single table for use in sparsecore lookups."""
-    padded_values = pad_table(table_spec, table, num_shards, pad_value)
-    sharded_padded_vocabulary_size = padded_values.shape[0] // num_shards
-    stack_embedding_dim = stacked_table.shape[-1]
-
-    # Mod-shard vocabulary across devices.
-    sharded_values = jnp.swapaxes(
-        padded_values.reshape(-1, num_shards, stack_embedding_dim),
-        0,
-        1,
-    )
-
-    # Rotate shards.
-    setting_in_stack = table_spec.setting_in_stack
-    rotated_values = jnp.roll(
-        sharded_values, setting_in_stack.shard_rotation, axis=0
-    )
-
-    # Insert table into the stack.
-    table_row = setting_in_stack.row_offset_in_shard
-    stacked_table = stacked_table.at[
-        :, table_row : (table_row + sharded_padded_vocabulary_size), :
-    ].set(rotated_values)
-
-    return stacked_table
-
-
-def stack_and_shard_tables(
-    table_specs: Nested[TableSpec],
-    tables: Nested[ArrayLike],
-    num_shards: int,
-    pad_value: jnp.float32 = jnp.nan,
-) -> dict[str, Nested[jax.Array]]:
-    """Stacks and shards tables for use in sparsecore lookups.
-
-    Args:
-        table_specs: Nested collection of unstacked table specifications.
-        tables: Table values corresponding to the table_specs.
-        num_shards: Number of shards in the table (typically
-            `global_device_count * num_sc_per_device`).
-        pad_value: Value to use for padding.
-
-    Returns:
-        A mapping of stacked table names to stacked table values.
-    """
-
-    # Gather stacked table information.
-    stacked_table_map: dict[
-        str,
-        tuple[StackedTableSpec, list[TableSpec]],
-    ] = {}
-
-    def collect_stacked_tables(table_spec: TableSpec) -> None:
-        stacked_table_spec = _get_stacked_table_spec(table_spec, num_shards)
-        stacked_table_name = stacked_table_spec.stack_name
-        if stacked_table_name not in stacked_table_map:
-            stacked_table_map[stacked_table_name] = (stacked_table_spec, [])
-        stacked_table_map[stacked_table_name][1].append(table_spec)
-
-    _ = jax.tree.map(collect_stacked_tables, table_specs)
-
-    table_map: dict[str, Nested[jax.Array]] = {}
-
-    def collect_tables(table_spec: TableSpec, table: Nested[jax.Array]) -> None:
-        table_map[table_spec.name] = table
-
-    _ = jax.tree.map(collect_tables, table_specs, tables)
-
-    stacked_tables: dict[str, Nested[jax.Array]] = {}
-    for (
-        stacked_table_spec,
-        table_specs,
-    ) in stacked_table_map.values():
-        stack_vocab_size = stacked_table_spec.stack_vocab_size
-        sharded_vocab_size = stack_vocab_size // num_shards
-        stack_embedding_dim = stacked_table_spec.stack_embedding_dim
-
-        # Allocate initial buffer. The stacked table will be divided among
-        # shards by splitting the vocabulary dimension:
-        #   [ v, e ] -> [s, v/s, e]
-        stacked_table_tree = jax.tree.map(
-            lambda _: jnp.zeros(
-                # pylint: disable-next=cell-var-from-loop, used only in loop body.
-                shape=(num_shards, sharded_vocab_size, stack_embedding_dim),
-                dtype=jnp.float32,
-            ),
-            table_map[table_specs[0].name],
-        )
-
-        for table_spec in table_specs:
-            table_tree = table_map[table_spec.name]
-            stacked_table_tree = jax.tree.map(
-                lambda stacked_table, table: _stack_and_shard_table(
-                    # pylint: disable-next=cell-var-from-loop, used only in loop body.
-                    stacked_table,
-                    # pylint: disable-next=cell-var-from-loop, used only in loop body.
-                    table_spec,
-                    table,
-                    num_shards,
-                    pad_value,
-                ),
-                stacked_table_tree,
-                table_tree,
-            )
-
-        stacked_tables[stacked_table_spec.stack_name] = stacked_table_tree
-
-    return stacked_tables
-
-
-def _unshard_and_unstack_table(
-    table_spec: TableSpec,
-    stacked_table_tree: Nested[jax.Array],
-    num_shards: int,
-) -> Nested[jax.Array]:
-    """Unshards and unstacks a single table."""
-    vocabulary_size = table_spec.vocabulary_size
-    embedding_dim = table_spec.embedding_dim
-
-    def _unshard_and_unstack_single_table(
-        table_spec: TableSpec, stacked_table: jax.Array
-    ) -> jax.Array:
-        stack_embedding_dim = stacked_table.shape[-1]
-
-        # Maybe re-shape in case it was flattened.
-        stacked_table = stacked_table.reshape(
-            num_shards, -1, stack_embedding_dim
-        )
-        sharded_vocabulary_size = (
-            _round_up_to_multiple(vocabulary_size, 8 * num_shards) // num_shards
-        )
-
-        # Extract padded values from the stacked table.
-        setting_in_stack = table_spec.setting_in_stack
-        row = setting_in_stack.row_offset_in_shard
-        padded_values = stacked_table[
-            :, row : (row + sharded_vocabulary_size), :
-        ]
-
-        # Un-rotate shards.
-        padded_values = jnp.roll(
-            padded_values, -setting_in_stack.shard_rotation, axis=0
-        )
-
-        # Un-mod-shard.
-        padded_values = jnp.swapaxes(padded_values, 0, 1).reshape(
-            -1, stack_embedding_dim
-        )
-
-        # Un-pad.
-        return padded_values[:vocabulary_size, :embedding_dim]
-
-    output: Nested[jax.Array] = jax.tree.map(
-        lambda stacked_table: _unshard_and_unstack_single_table(
-            table_spec, stacked_table
-        ),
-        stacked_table_tree,
-    )
-    return output
-
-
-def unshard_and_unstack_tables(
-    table_specs: Nested[TableSpec],
-    stacked_tables: Mapping[str, Nested[jax.Array]],
-    num_shards: int,
-) -> Nested[jax.Array]:
-    """Unshards and unstacks a collection of tables.
-
-    Args:
-        table_specs: Nested collection of unstacked table specifications.
-        stacked_tables: Mapping of stacked table names to stacked table values.
-        num_shards: Number of shards in the table (typically
-            `global_device_count * num_sc_per_device`).
-
-    Returns:
-        A mapping of table names to unstacked table values.
-    """
-    output: Nested[jax.Array] = jax.tree.map(
-        lambda table_spec: _unshard_and_unstack_table(
-            table_spec,
-            stacked_tables[
-                _get_stacked_table_spec(table_spec, num_shards=1).stack_name
-            ],
-            num_shards,
-        ),
-        table_specs,
-    )
-    return output
-
-
-def get_table_stacks(
-    table_specs: Nested[TableSpec],
-) -> dict[str, list[TableSpec]]:
-    """Extracts lists of tables that are stacked together.
-
-    Args:
-        table_specs: Nested collection of table specifications.
-
-    Returns:
-        A mapping of stacked table names to lists of table specifications for
-        each stack.
-    """
-    stacked_table_specs: dict[str, list[TableSpec]] = collections.defaultdict(
-        list
-    )
-    flat_table_specs, _ = jax.tree.flatten(table_specs)
-    for table_spec in flat_table_specs:
-        table_spec = typing.cast(TableSpec, table_spec)
-        stacked_table_spec = table_spec.stacked_table_spec
-        if stacked_table_spec is not None:
-            stacked_table_specs[stacked_table_spec.stack_name].append(
-                table_spec
-            )
-        else:
-            stacked_table_specs[table_spec.name].append(table_spec)
-
-    return stacked_table_specs
-
-
-def convert_to_numpy(
-    ragged_or_dense: np.ndarray[Any, Any] | Sequence[Sequence[Any]] | Any,
-    dtype: Any,
-) -> np.ndarray[Any, Any]:
-    """Converts a ragged or dense list of inputs to a ragged/dense numpy array.
-
-    The output is adjusted to be 2D.
-
-    Args:
-        ragged_or_dense: Input that is either already a numpy array, or nested
-            sequence.
-        dtype: Numpy dtype of output array.
-
-    Returns:
-        Corresponding numpy array.
-    """
-    if hasattr(ragged_or_dense, "numpy"):
-        # Support tf.RaggedTensor and other TF input dtypes.
-        if callable(getattr(ragged_or_dense, "numpy")):
-            ragged_or_dense = ragged_or_dense.numpy()
-
-    if isinstance(ragged_or_dense, jax.Array):
-        ragged_or_dense = np.asarray(ragged_or_dense)
-
-    if isinstance(ragged_or_dense, np.ndarray):
-        # Convert 1D to 2D.
-        if ragged_or_dense.dtype != np.ndarray and ragged_or_dense.ndim == 1:
-            return ragged_or_dense.reshape(-1, 1).astype(dtype)
-
-        # If dense, return converted dense type.
-        if ragged_or_dense.dtype != np.ndarray:
-            return ragged_or_dense.astype(dtype)
-
-        # Ragged numpy array.
-        return ragged_or_dense
-
-    # Handle 1D sequence input.
-    if not isinstance(ragged_or_dense[0], collections.abc.Sequence):
-        return np.asarray(ragged_or_dense, dtype=dtype).reshape(-1, 1)
-
-    # Assemble elements into an nd-array.
-    counts = [len(vals) for vals in ragged_or_dense]
-    if all([count == counts[0] for count in counts]):
-        # Dense input.
-        return np.asarray(ragged_or_dense, dtype=dtype)
-    else:
-        # Ragged input, convert to ragged numpy arrays.
-        return np.array(
-            [np.array(row, dtype=dtype) for row in ragged_or_dense],
-            dtype=np.ndarray,
-        )
-
-
-def ones_like(
-    ragged_or_dense: np.ndarray[Any, Any], dtype: Any = None
-) -> np.ndarray[Any, Any]:
-    """Creates an array of ones the same as as the input.
-
-    This differs from traditional numpy in that a ragged input will lead to
-    a resulting ragged array of ones, whereas np.ones_like(...) will instead
-    only consider the outer array and return a 1D dense array of ones.
-
-    Args:
-        ragged_or_dense: The ragged or dense input whose shape and data-type
-            define these same attributes of the returned array.
-        dtype: The data-type of the returned array.
-
-    Returns:
-        An array of ones with the same shape as the input, and specified data
-        type.
-    """
-    dtype = dtype or ragged_or_dense.dtype
-    if ragged_or_dense.dtype == np.ndarray:
-        # Ragged.
-        return np.array(
-            [np.ones_like(row, dtype=dtype) for row in ragged_or_dense],
-            dtype=np.ndarray,
-        )
-    else:
-        # Dense.
-        return np.ones_like(ragged_or_dense, dtype=dtype)
-
-
-def create_feature_samples(
-    feature_structure: Nested[T],
-    feature_ids: Nested[ArrayLike | Sequence[int] | Sequence[Sequence[int]]],
-    feature_weights: None
-    | (Nested[ArrayLike | Sequence[float] | Sequence[Sequence[float]]]),
-) -> Nested[FeatureSamples]:
-    """Constructs a collection of sample tuples from provided IDs and weights.
-
-    Args:
-        feature_structure: The nested structure of the inputs (typically
-            `FeatureSpec`s).
-        feature_ids: The feature IDs to use for the samples.
-        feature_weights: The feature weights to use for the samples. Defaults
-            to ones if not provided.
-
-    Returns:
-        A nested collection of `FeatureSamples` corresponding to the input IDs
-        and weights, for use in embedding lookups.
-    """
-    # Create numpy arrays from inputs.
-    feature_ids = jax.tree.map(
-        lambda _, ids: convert_to_numpy(ids, np.int32),
-        feature_structure,
-        feature_ids,
-    )
-
-    if feature_weights is None:
-        # Make ragged or dense ones_like.
-        feature_weights = jax.tree.map(
-            lambda _, ids: ones_like(ids, np.float32),
-            feature_structure,
-            feature_ids,
-        )
-    else:
-        feature_weights = jax.tree.map(
-            lambda _, wgts: convert_to_numpy(wgts, np.float32),
-            feature_structure,
-            feature_weights,
-        )
-
-    # Assemble.
-    def _create_feature_samples(
-        sample_ids: np.ndarray[Any, Any],
-        sample_weights: np.ndarray[Any, Any],
-    ) -> FeatureSamples:
-        return FeatureSamples(sample_ids, sample_weights)
-
-    output: Nested[FeatureSamples] = jax.tree.map(
-        lambda _, sample_ids, sample_weights: _create_feature_samples(
-            sample_ids, sample_weights
-        ),
-        feature_structure,
-        feature_ids,
-        feature_weights,
-    )
-    return output
-
-
-def stack_and_shard_samples(
-    feature_specs: Nested[FeatureSpec],
-    feature_samples: Nested[FeatureSamples],
-    local_device_count: int,
-    global_device_count: int,
-    num_sc_per_device: int,
-    static_buffer_size: int | Mapping[str, int] | None = None,
-) -> tuple[dict[str, ShardedCooMatrix], embedding.SparseDenseMatmulInputStats]:
-    """Prepares input samples for use in embedding lookups.
-
-    Args:
-        feature_specs: Nested collection of feature specifications.
-        feature_samples: Nested collection of feature samples.
-        local_device_count: Number of local JAX devices.
-        global_device_count: Number of global JAX devices.
-        num_sc_per_device: Number of sparsecores per device.
-        static_buffer_size: The static buffer size to use for the samples.
-            Defaults to None, in which case an upper-bound for the buffer size
-            will be automatically determined.
-
-    Returns:
-        The preprocessed inputs, and statistics useful for updating FeatureSpecs
-        based on the provided input data.
-    """
-    del static_buffer_size  # Currently ignored.
-    flat_feature_specs, _ = jax.tree.flatten(feature_specs)
-
-    feature_tokens = []
-    feature_weights = []
-
-    def collect_tokens_and_weights(
-        feature_spec: FeatureSpec, samples: FeatureSamples
-    ) -> None:
-        del feature_spec
-        feature_tokens.append(samples.tokens)
-        feature_weights.append(samples.weights)
-
-    jax.tree.map(collect_tokens_and_weights, feature_specs, feature_samples)
-
-    preprocessed_inputs, stats = embedding.preprocess_sparse_dense_matmul_input(
-        feature_tokens,
-        feature_weights,
-        flat_feature_specs,
-        local_device_count=local_device_count,
-        global_device_count=global_device_count,
-        num_sc_per_device=num_sc_per_device,
-        sharding_strategy="MOD",
-        has_leading_dimension=False,
-        allow_id_dropping=True,
-    )
-
-    out: dict[str, ShardedCooMatrix] = {}
-    tables_names = preprocessed_inputs.lhs_row_pointers.keys()
-    for table_name in tables_names:
-        shard_ends = preprocessed_inputs.lhs_row_pointers[table_name]
-        shard_starts = np.concatenate(
-            [np.asarray([0]), _round_up_to_multiple(shard_ends[:-1], 8)]
-        )
-        out[table_name] = ShardedCooMatrix(
-            shard_starts=shard_starts,
-            shard_ends=shard_ends,
-            col_ids=preprocessed_inputs.lhs_embedding_ids[table_name],
-            row_ids=preprocessed_inputs.lhs_sample_ids[table_name],
-            values=preprocessed_inputs.lhs_gains[table_name],
-        )
-
-    return out, stats
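
For reference, the mod-sharding that the removed _stack_and_shard_table applied, and _unshard_and_unstack_table inverted, is a reshape, a transpose, and a shard rotation. A pure-numpy round-trip sketch of that layout, with toy sizes not taken from the diff:

    import numpy as np

    num_shards, vocab, dim, rotation = 4, 16, 2, 1  # toy sizes
    table = np.arange(vocab * dim, dtype=np.float32).reshape(vocab, dim)

    # Shard: vocabulary row i lands in shard i % num_shards ("MOD" sharding),
    # then shards are rotated by the table's position within its stack.
    sharded = np.swapaxes(table.reshape(-1, num_shards, dim), 0, 1)
    rotated = np.roll(sharded, rotation, axis=0)

    # Unshard: undo the rotation, transpose back, re-flatten the vocabulary.
    unrotated = np.roll(rotated, -rotation, axis=0)
    restored = np.swapaxes(unrotated, 0, 1).reshape(-1, dim)
    assert np.array_equal(restored, table)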