tpu-inference 0.11.1rc2__py3-none-any.whl → 0.11.1rc3__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (50)
  1. tpu_inference/kernels/collectives/__init__.py +0 -0
  2. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  3. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  4. tpu_inference/kernels/collectives/util.py +47 -0
  5. tpu_inference/layers/__init__.py +0 -0
  6. tpu_inference/layers/common/__init__.py +0 -0
  7. tpu_inference/layers/common/attention_metadata.py +34 -0
  8. tpu_inference/layers/jax/__init__.py +0 -0
  9. tpu_inference/layers/jax/attention/__init__.py +0 -0
  10. tpu_inference/layers/jax/attention/attention.py +254 -0
  11. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  12. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  13. tpu_inference/layers/jax/attention_interface.py +356 -0
  14. tpu_inference/layers/jax/base.py +151 -0
  15. tpu_inference/layers/jax/binary_search.py +295 -0
  16. tpu_inference/layers/jax/constants.py +88 -0
  17. tpu_inference/layers/jax/layers.py +301 -0
  18. tpu_inference/layers/jax/misc.py +16 -0
  19. tpu_inference/layers/jax/moe/__init__.py +0 -0
  20. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  21. tpu_inference/layers/jax/moe/moe.py +209 -0
  22. tpu_inference/layers/jax/rope.py +172 -0
  23. tpu_inference/layers/jax/rope_interface.py +214 -0
  24. tpu_inference/layers/jax/sample/__init__.py +0 -0
  25. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  26. tpu_inference/layers/jax/sample/sampling.py +95 -0
  27. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  28. tpu_inference/layers/jax/sharding.py +406 -0
  29. tpu_inference/layers/jax/transformer_block.py +76 -0
  30. tpu_inference/layers/vllm/__init__.py +0 -0
  31. tpu_inference/layers/vllm/attention.py +184 -0
  32. tpu_inference/layers/vllm/fused_moe.py +399 -0
  33. tpu_inference/layers/vllm/linear_common.py +186 -0
  34. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  35. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  36. tpu_inference/layers/vllm/quantization/common.py +105 -0
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  38. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  39. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  40. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  41. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  42. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  43. tpu_inference/layers/vllm/sharding.py +151 -0
  44. tpu_inference/models/common/__init__.py +0 -0
  45. tpu_inference/models/common/model_loader.py +433 -0
  46. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/METADATA +6 -6
  47. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/RECORD +50 -5
  48. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/WHEEL +1 -1
  49. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/licenses/LICENSE +0 -0
  50. {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,295 @@
+ # Copyright 2024 The T5X Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Binary search over float32 bits.
+
+ Includes fast algorithms for top-k masking and top-p masking on probability
+ distributions.
+ """
+
+ from typing import Callable, Sequence
+
+ import jax
+ from jax import lax
+ from jax import numpy as jnp
+
+
+ def int32_bsearch(batch_shape: Sequence[int],
+                   predicate: Callable[[jnp.ndarray], jnp.ndarray]):
+     """Batched binary search over int32 values.
+
+     For each element of the batch, search for the largest int32 (closest to
+     positive infinity) for which the predicate is False. If the predicate is
+     always True, returns the minimum int32 value.
+
+     Args:
+         batch_shape: Shape of the search that we're batching over.
+         predicate: the query we're searching for. For every batch element, this is
+             required to be a monotonic function from int32 to bool. In other words,
+             the predicate must return False for all numbers <= some threshold and then
+             return True for all numbers > that threshold. The threshold may be
+             different for different elements of the batch.
+
+     Returns:
+         For each element of the batch, the largest int32 for which the predicate
+         returns False. Shape: batch_shape.
+     """
+     current_bits = jnp.zeros(batch_shape, dtype=jnp.int32)
+
+     # bit 31 is special, because it compares in the opposite order of all other
+     # bits. we use uint32 due to numpy promotion/casting rules.
+     midpoint = current_bits
+     predicate_satisfied = predicate(midpoint)
+     current_bits = current_bits | jnp.where(predicate_satisfied,
+                                             jnp.uint32(1 << 31), jnp.uint32(0))
+     del midpoint, predicate_satisfied
+
+     def loop_body(i, current_bits):
+         bit_index = 30 - i
+         bit = jnp.int32(1 << bit_index)
+         midpoint = current_bits | bit
+         predicate_satisfied = predicate(midpoint)
+         current_bits = current_bits | jnp.where(predicate_satisfied,
+                                                 jnp.int32(0), bit)
+         return current_bits
+
+     current_bits = lax.fori_loop(0, 31, loop_body, current_bits)
+     return current_bits
+
+
+ def _monotonic_int32_to_float32_bit_pattern(x: int) -> int:
+     """Converts an int32 to a float32 bit pattern with consistent ordering.
+
+     This function is the unique function that is monotonic with respect to the
+     floating point total order, see
+     https://en.wikipedia.org/wiki/IEEE_754#Total-ordering_predicate. Note that
+     this function returns an int32, not a float32. For the function that returns
+     float32, see `_monotonic_int32_to_float32`.
+
+     Args:
+         x: int bit pattern.
+
+     Returns:
+         Bit pattern of a float32 number.
+     """
+     non_sign_bits = jnp.int32((1 << 31) - 1)
+     # See
+     # https://stackoverflow.com/questions/20097380/iee-754-total-order-in-standard-c11
+     # for the relationship between int32 order and f32 total order, including
+     # the "xor trick".
+
+     # Flip the sort order for numbers where the sign bit is set. On int32,
+     # the bit pattern with sign bit set and all other bits clear is the most
+     # negative bit pattern (it's int32::MIN), whereas on float32 it's the least
+     # negative bit pattern (it's -0.0). Flipping all the non-sign bits makes the
+     # int32 sort order consistent with the float32 sort order.
+     x = x ^ jnp.where(x < 0, non_sign_bits, jnp.int32(0))
+     return x
+
+
+ def _monotonic_int32_to_float32(x: int) -> jax.Array:
+     """Converts an int32 to a float32 with consistent ordering.
+
+     This function is the unique function that is monotonic with respect to the
+     floating point total order, see
+     https://en.wikipedia.org/wiki/IEEE_754#Total-ordering_predicate.
+
+     Args:
+         x: int bit pattern.
+
+     Returns:
+         float32 number with consistent ordering.
+     """
+     x = _monotonic_int32_to_float32_bit_pattern(x)
+     return lax.bitcast_convert_type(x, jnp.float32)
+
+
+ def float32_bsearch(batch_shape, predicate):
+     """Binary search on finite float32 numbers.
+
+     For each element of the batch, this function searches for the largest finite
+     non-NaN float32 for which the predicate is False.
+
+     Args:
+         batch_shape: Shape of the search that we're batching over.
+         predicate: the query we're searching for. This is required to be monotonic
+             with respect to the floating point order, i.e. it must be False for all
+             numbers <= a threshold, and then True for all numbers > the threshold. The
+             threshold may be different for different elements of the batch.
+
+     Returns:
+         For each element of the batch, the largest float32 for which the predicate
+         returns False. Shape: f32[batch_shape].
+     """
+     exponent_bits = jnp.int32((1 << 31) - (1 << (31 - 8)))
+
+     def int32_predicate(x):
+         x = _monotonic_int32_to_float32_bit_pattern(x)
+         is_finite = (x & exponent_bits) != exponent_bits
+
+         # Non-finite numbers (infinity and NaN) are at the very extremes of the
+         # int32 range, i.e. they include int32::MAX and int32::MIN, plus the numbers
+         # adjacent to them. For the nonfinite numbers touching int32::MIN, we
+         # arrange for them to return False from the predicate, and for the nonfinite
+         # numbers touching int32::MAX, we arrange for them to return True from the
+         # predicate. x >= 0 is an easy way to achieve that.
+         predicate_on_nonfinite = x >= 0
+         x_float32 = lax.bitcast_convert_type(x, jnp.float32)
+         return jnp.where(is_finite, predicate(x_float32),
+                          predicate_on_nonfinite)
+
+     # We search over bit patterns, which requires bit shifting and ordering of bit
+     # patterns. This is natively supported on int32 but not on float32.
+     # Additionally, it's more common to reason about int32 bit arithmetic and
+     # ordering than float32 bit arithmetic and ordering, so we do the core of our
+     # search in int32. Additionally, this allows us to test the underlying binary
+     # search on int32 values.
+     #
+     # The function _monotonic_int32_to_float32 encapsulates all of the knowledge
+     # we need about float32 bit patterns.
+     result = int32_bsearch(batch_shape, int32_predicate)
+     return _monotonic_int32_to_float32(result)
+
+
+ def topk_mask(x: jnp.ndarray, k: int, replace_val: jnp.ndarray) -> jnp.ndarray:
+     """Sets everything to replace_val, except the top k values per batch element.
+
+     Sharding considerations: this function does 32 reductions over the vocab_size
+     axis of the input array. To avoid excessive latency from these reductions, you
+     should ensure that the vocab_size axis is unsharded on input to this function.
+     Prefer to shard the batch axes instead.
+
+     Scratchpad memory considerations: this function is most efficient if the
+     entire input array can fit in a fast memory tier. To help ensure this, you may
+     wish to split the batch axes into microbatches and process the microbatches
+     in a sequential loop.
+
+     Args:
+         x: Values before masking. [batch..., vocab_size]
+         k: Number of values to keep unmasked. In the presence of ties, more than k
+             values might be kept.
+         replace_val: For the masked values of x, what to overwrite them with.
+
+     Returns:
+         Masked version of x. [batch..., vocab_size]
+     """
+     batch_shape = tuple(list(x.shape)[:-1])  # [batch...]
+
+     x_for_loop = x
+     reduce_axis = x.ndim - 1
+     if x.ndim > 1:
+         # We're going to be doing 32 reductions over 'reduce_axis'. Generally,
+         # reductions over the last dimension are the most expensive, because they
+         # involve reducing across vector lanes, which is often not efficient. So
+         # we transpose the reduce_axis to be the second-last dimension, to avoid
+         # this inefficiency.
+         #
+         # Normally the XLA compiler would automatically perform this optimization,
+         # but it doesn't yet see through loops to do so. So we do it ourselves.
+         x_for_loop = jnp.swapaxes(x_for_loop, -1, -2)
+         reduce_axis = x.ndim - 2
+
+     # x_for_loop: [batch..., vocab_size, last_batch]
+     def predicate(threshold):
+         # threshold: [batch...]
+
+         # Since we've negated, we now want a predicate that is True for small
+         # numbers and False for large numbers. The result of the bsearch is the
+         # smallest float32 for which the predicate is False.
+         threshold = -threshold
+
+         threshold = lax.expand_dims(threshold, (reduce_axis, ))
+         # threshold: [batch..., 1, last_batch]
+
+         # count_gt: [batch...]
+         count_gt = jnp.sum(x_for_loop > threshold, axis=reduce_axis)
+
+         return count_gt >= k
+
+     # cutoff: [batch...]
+     cutoff = float32_bsearch(batch_shape, predicate)
+     cutoff = -cutoff
+     # cutoff: [batch..., 1]
+     cutoff = lax.expand_dims(cutoff, (cutoff.ndim, ))
+     return jnp.where(x >= cutoff, x, jnp.full_like(x, replace_val))
+
+
+ def topp_mask(logits: jnp.ndarray, p: float,
+               replace_val: jnp.ndarray) -> jnp.ndarray:
+     """Applies top-p masking to logits.
+
+     Masks logits down to the smallest set of choices, such that the total
+     probability mass is >= p. Values in this set are left as they are. All other
+     values are set to `replace_val`.
+
+     Sharding considerations: this function does 33 reductions over the vocab_size
+     axis of the input array. To avoid excessive latency from these reductions, you
+     should ensure that the vocab_size axis is unsharded on input to this function.
+     Prefer to shard the batch axes instead.
+
+     Scratchpad memory considerations: this function is most efficient if the
+     entire input array can fit in a fast memory tier. To help ensure this, you may
+     wish to split the batch axes into microbatches and process the microbatches
+     in a sequential loop.
+
+     Args:
+         logits: Logits before masking. [batch..., vocab_size]
+         p: Minimum probability mass requested.
+         replace_val: For the masked values of logits, what to overwrite them with.
+
+     Returns:
+         Masked version of logits. [batch..., vocab_size]
+     """
+     batch_shape = tuple(list(logits.shape)[:-1])  # [batch...]
+
+     probs = jax.nn.softmax(logits, axis=-1)
+
+     probs_for_reduction = probs
+     reduce_axis = probs_for_reduction.ndim - 1
+     if probs_for_reduction.ndim > 1:
+         # We're going to be doing 33 reductions over 'reduce_axis'. Generally,
+         # reductions over the last dimension are the most expensive, because they
+         # involve reducing across vector lanes, which is often not efficient. So
+         # we transpose the reduce_axis to be the second-last dimension, to avoid
+         # this inefficiency.
+         probs_for_reduction = jnp.swapaxes(probs_for_reduction, -1, -2)
+         reduce_axis = probs_for_reduction.ndim - 2
+
+     # As we increase the threshold, the probability mass decreases, and the number
+     # selected decreases.
+     #
+     # We want the largest threshold with the probability mass >= p. Binary search
+     # searches for when the predicate is False, so we negate the output of the
+     # predicate, i.e. probability mass < p.
+
+     # probs_for_reduction: [batch..., vocab_size, last_batch]
+     def predicate(threshold):
+         # threshold: [batch...]
+         threshold = lax.expand_dims(threshold, (reduce_axis, ))
+         # threshold: [batch..., 1, last_batch]
+
+         # probability_mass: [batch...]
+         probability_mass = jnp.sum(
+             jnp.where(probs_for_reduction >= threshold, probs_for_reduction,
+                       0.0),
+             axis=reduce_axis,
+         )
+
+         return probability_mass < p
+
+     # threshold: [batch...]
+     threshold = float32_bsearch(batch_shape, predicate)
+     # threshold: [batch..., 1]
+     threshold = lax.expand_dims(threshold, (threshold.ndim, ))
+     return jnp.where(probs >= threshold, logits,
+                      jnp.full_like(logits, replace_val))
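
For orientation, here is a minimal sketch of how these new top-k/top-p helpers could be called. The import path follows the file listing above (tpu_inference/layers/jax/binary_search.py); the logits values and parameters are invented purely for illustration.

import jax.numpy as jnp
from tpu_inference.layers.jax.binary_search import topk_mask, topp_mask

# Hypothetical logits batch: [batch=2, vocab_size=8].
logits = jnp.array([[0.1, 2.0, -1.0, 3.5, 0.0, 1.2, -0.5, 2.2],
                    [1.0, 1.0, 0.5, -2.0, 4.0, 0.3, 0.9, -1.0]])
neg_inf = jnp.float32(-1e9)

# Keep only the 3 largest logits per row; everything else becomes neg_inf.
top3 = topk_mask(logits, 3, neg_inf)
# Keep the smallest set of logits whose softmax mass reaches 0.9.
nucleus = topp_mask(logits, 0.9, neg_inf)
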
@@ -0,0 +1,88 @@
+ """
+ Abbreviations currently used for tensor dimensions:
+ B: Batch size
+ T: Sequence Length (for Query tensors)
+ S: Sequence Length (for Key/Value tensors)
+ D: d_model, the embedding dimension of the model
+ F: d_ff, the hidden dimension of the feed-forward MLP layers
+ V: Vocab Size
+ H: Dimension of each attention head
+ N: Number of query heads in Attention
+ Q: Number of query heads (synonymous with N)
+ K: Number of Key/Value heads in Attention
+ C: Expert capacity in Mixture-of-Experts models
+ X: Number of activated experts per token in MoE
+ G: Number of groups in Grouped-Query Attention
+ E: Total number of experts in MoE
+ """
+
+ import enum
+ from typing import Tuple, TypeAlias
+
+ import jax
+
+ KVCacheType: TypeAlias = Tuple[jax.Array, jax.Array]
+
+
+ class RouterType(enum.Enum):
+     """Enum for router types."""
+     TOP_K = 'top_k'
+
+
+ class OPERATION_MODE(enum.Enum):
+     PREFILL = 1
+     DECODE = 2
+
+
+ class HuggingFaceArgNames(enum.Enum):
+     ## Modeling params
+     HIDDEN_ACT: str = "hidden_act"
+     HIDDEN_SIZE: str = "hidden_size"
+     NUM_HIDDEN_LAYERS: str = "num_hidden_layers"
+     RMS_NORM_EPS: str = "rms_norm_eps"
+     ROPE_SCALING: str = "rope_scaling"
+     ROPE_THETA: str = "rope_theta"
+     VOCAB_SIZE: str = "vocab_size"
+
+     # Block parameters
+     SHARED_EXPERTS: str = "shared_experts"
+
+     # FFW params
+     INTERMEDIATE_SIZE: str = "intermediate_size"
+
+     # Attention params
+     HEAD_DIM: str = "head_dim"
+     NUM_ATTENTION_HEADS: str = "num_attention_heads"
+     NUM_KEY_VALUE_HEADS: str = "num_key_value_heads"
+     ATTENTION_DROPOUT: str = "attention_dropout"
+     ATTENTION_BIAS: str = "attention_bias"
+     ATTENTION_CHUNK_SIZE: str = "attention_chunk_size"
+
+     ## Llama4 Attention Params
+     USE_QK_NORM: str = "use_qk_norm"
+     TEMPERATURE_TUNING: str = "temperature_tuning"
+     TEMPERATURE_TUNING_SCALE: str = "temperature_tuning_scale"
+     TEMPERATURE_TUNING_FLOOR_SCALE: str = "temperature_tuning_floor_scale"
+
+     # MLA params
+     KV_LORA_RANK: str = "kv_lora_rank"
+     Q_LORA_RANK: str = "q_lora_rank"
+     QK_NOPE_HEAD_DIM: str = "qk_nope_head_dim"
+     QK_ROPE_HEAD_DIM: str = "qk_rope_head_dim"
+     V_HEAD_DIM: str = "v_head_dim"
+
+     # MoE
+     INTERMEDIATE_SIZE_MOE: str = "intermediate_size_moe"
+     NUM_LOCAL_EXPERTS: str = "num_local_experts"  # Llama MoE
+     NUM_EXPERTS_PER_TOKEN: str = "num_experts_per_token"
+     NUM_ROUTED_EXPERTS: str = "n_routed_experts"  # DeepSeek MoE
+     NUM_SHARED_ROUTED_EXPERTS: str = "n_shared_experts"
+     NUM_GROUPS: str = "n_group"
+     ROUTED_SCALING_FACTOR: str = "routed_scaling_factor"
+     TOPK_GROUP: str = "topk_group"
+     NORM_TOPK_PROB: str = "norm_topk_prob"
+     SCORING_FUNCTION: str = "scoring_func"
+
+     ## Sampling params
+     BOS_TOKEN_ID: str = "bos_token_id"
+     EOS_TOKEN_ID: str = "eos_token_id"
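
The enum above appears to standardize Hugging Face config key names; a small illustrative sketch of the lookup pattern follows. The config dict below is invented, and only the enum members and their string values come from constants.py.

from tpu_inference.layers.jax.constants import HuggingFaceArgNames, RouterType

# Invented Hugging Face style config, purely for illustration.
hf_config = {"hidden_size": 4096, "num_attention_heads": 32, "rms_norm_eps": 1e-5}

hidden_size = hf_config[HuggingFaceArgNames.HIDDEN_SIZE.value]
num_heads = hf_config[HuggingFaceArgNames.NUM_ATTENTION_HEADS.value]
assert RouterType.TOP_K.value == 'top_k'
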
@@ -0,0 +1,301 @@
+ from dataclasses import InitVar, dataclass
+ from typing import Any
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from flax.typing import Sharding
+ from jaxtyping import Float, Int
+
+ from tpu_inference.layers.jax.base import create_param
+
+
+ # A stand-in for modeling_flax_utils, which might contain activation functions.
+ class FlaxUtils:
+     """A dummy class to namespace activation functions, mimicking external utilities."""
+     ACT2FN = {
+         'silu': nnx.silu,
+         'gelu': nnx.gelu,
+         'relu': nnx.relu,
+         'sigmoid': nnx.sigmoid,
+         'softmax': nnx.softmax
+     }
+
+
+ modeling_flax_utils = FlaxUtils()
+
+
+ @dataclass
+ class RuntimeParams:
+     """A container for runtime parameters needed by neural network blocks.
+
+     This dataclass acts as a flexible container to pass objects that are only
+     available at runtime (like a pre-allocated KV cache or dynamic sharding
+     configurations) into the initialization of stateful modules. This avoids
+     having to update the constructor signature of every module when a new
+     runtime dependency is introduced.
+
+     Attributes:
+         kv_cache: The key-value cache object for attention layers.
+         sharding_cfg: The configuration for tensor sharding.
+         quantization: Configuration for quantization schemes.
+     """
+     kv_cache: Any = None
+     sharding_cfg: Any = None
+     quantization: Any = None
+
+
+ @dataclass(kw_only=True)
+ class RMSNorm(nnx.Module):
+     """An implementation of Root Mean Square Layer Normalization.
+
+     Attributes:
+         dims: The feature dimension to normalize over.
+         epsilon: A small float added to the variance to avoid division by zero.
+         with_scale: If True, learns a multiplicative scale parameter.
+         dtype: The data type for computations.
+     """
+     dims: int
+     activation_ffw_td: Sharding = ()
+     random_init: bool = False
+     epsilon: float = 1e-6
+     with_scale: bool = True
+     dtype: Any = jnp.float32
+
+     rngs: InitVar[nnx.Rngs]
+
+     def __call__(self, x_TD: Float, op_mode='generate') -> Float:
+         """Applies RMS Normalization to the input tensor.
+
+         Args:
+             x_TD: The input tensor. The normalization is applied over the last dimension.
+
+         Returns:
+             The normalized tensor with the same shape as the input.
+         """
+         x_TD = jnp.asarray(x_TD, self.dtype)
+         x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
+
+         with jax.named_scope("rms_norm_variance"):
+             var_T1 = jnp.mean(jnp.square(x_TD), axis=-1, keepdims=True)
+         with jax.named_scope("rms_norm_rsqrt"):
+             normed_x_TD = x_TD * jax.lax.rsqrt(var_T1 + self.epsilon)
+
+         with jax.named_scope("rms_norm_scale_apply"):
+             normed_x_TD *= self.scale.value
+             normed_x_TD = nnx.with_sharding_constraint(normed_x_TD,
+                                                        self.activation_ffw_td)
+         return normed_x_TD.astype(self.dtype)
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         self.scale = create_param(rngs,
+                                   shape=(self.dims, ),
+                                   dtype=self.dtype,
+                                   random_init=self.random_init)
+
+
+ @dataclass(kw_only=True)
+ class DenseFFW(nnx.Module):
+     """A Gated Feed-Forward Network (FFN) layer.
+
+     This module consists of two linear projections (gating and up-projection),
+     an element-wise multiplication of the activated gating projection and the
+     up-projection, followed by a final down-projection.
+
+     Attributes:
+         sharding_cfg: The configuration for tensor sharding.
+     """
+     dtype: jnp.dtype
+     hidden_act: str
+     hidden_size: int
+     intermediate_size: int
+     df_sharding: Sharding = ()
+     fd_sharding: Sharding = ()
+     activation_ffw_td: Sharding = ()
+     random_init: bool = False
+
+     rngs: InitVar[nnx.Rngs]
+
+     def __call__(self, x_TD):
+         """Performs the forward pass of the FFW layer.
+
+         Args:
+             x_TD: The input tensor of shape `(sequence, d_model)`.
+
+         Returns:
+             The output tensor of shape `(sequence, d_model)`.
+         """
+         # TODO: consider creating factories for einsum(?)
+         x_TD = jnp.asarray(x_TD, self.dtype)
+         x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
+         with jax.named_scope("wi_0"):
+             gating_TF = jnp.einsum('TD,DF -> TF', x_TD,
+                                    self.kernel_gating_DF.value)
+             activated_gating_TF = modeling_flax_utils.ACT2FN[self.hidden_act](
+                 gating_TF)
+         with jax.named_scope("wi_1"):
+             up_proj_TF = jnp.einsum('TD,DF -> TF', x_TD,
+                                     self.kernel_up_proj_DF.value)
+         fuse_TF = activated_gating_TF * up_proj_TF
+         with jax.named_scope("wo"):
+             output_TD = jnp.einsum('TF,FD -> TD', fuse_TF,
+                                    self.kernel_down_proj_FD.value)
+
+         return output_TD
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         D = self.hidden_size
+         F = self.intermediate_size
+
+         self.kernel_gating_DF = create_param(rngs,
+                                              shape=(D, F),
+                                              dtype=self.dtype,
+                                              sharding=self.df_sharding,
+                                              random_init=self.random_init)
+         self.kernel_up_proj_DF = create_param(rngs,
+                                               shape=(D, F),
+                                               dtype=self.dtype,
+                                               sharding=self.df_sharding,
+                                               random_init=self.random_init)
+         self.kernel_down_proj_FD = create_param(rngs,
+                                                 shape=(F, D),
+                                                 dtype=self.dtype,
+                                                 sharding=self.fd_sharding,
+                                                 random_init=self.random_init)
+
+
+ @dataclass(kw_only=True)
+ class Embedder(nnx.Module):
+     """A module for token embedding and, optionally, decoding (tied embeddings).
+
+     This class handles both the "encoding" step of converting token IDs to dense
+     vectors and the "decoding" step of projecting model outputs back to logits
+     over the vocabulary.
+     """
+     vocab_size: int
+     hidden_size: int
+     dtype: jnp.dtype
+     prelogit_td: Sharding = ()
+     vd_sharding: Sharding = ()
+     random_init: bool = False
+     normalize_embeddings: bool = False
+
+     rngs: InitVar[nnx.Rngs]
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         self.input_embedding_table_VD = create_param(
+             rngs,
+             shape=(self.vocab_size, self.hidden_size),
+             sharding=self.vd_sharding,
+             dtype=self.dtype,
+             random_init=self.random_init)
+
+     def __call__(self, x, decode=False):
+         """Dispatches to either the encode or decode method.
+
+         Args:
+             x: The input tensor. Either token IDs for encoding or hidden states
+                 for decoding.
+             decode: A boolean flag. If False (default), performs encoding. If
+                 True, performs decoding.
+
+         Returns:
+             Either embedding vectors or logit scores.
+         """
+         if decode:
+             return self.decode(x)
+         else:
+             return self.encode(x)
+
+     def decode(self, x_TD: Float) -> Float:
+         """Projects hidden states to vocabulary logits.
+
+         Args:
+             x_TD: The input tensor of hidden states from the model backbone, with
+                 shape `(sequence, d_model)`.
+
+         Returns:
+             The output logits over the vocabulary, with shape
+             `(sequence, vocab_size)`.
+         """
+         x_TD = jnp.asarray(x_TD, self.dtype)
+         x_TD = nnx.with_sharding_constraint(x_TD, self.prelogit_td)
+
+         with jax.named_scope("embedder_decode_projection"):
+             logits_TV = jnp.einsum('VD,TD -> TV',
+                                    self.input_embedding_table_VD.value, x_TD)
+         return logits_TV
+
+     def encode(self, x_T: Int) -> Float:
+         """Converts integer token IDs to dense embedding vectors.
+
+         Args:
+             x_T: The input tensor of token IDs, with shape `(sequence, )`.
+
+         Returns:
+             The corresponding embedding vectors, with shape
+             `(sequence, d_model)`.
+         """
+         with jax.named_scope("embedder_encode_lookup"):
+             embedding_TD = jnp.take(self.input_embedding_table_VD.value,
+                                     x_T,
+                                     axis=0)
+
+         if self.normalize_embeddings:
+             with jax.named_scope("embedder_normalize_embeddings"):
+                 embedding_TD *= jnp.sqrt(self.hidden_size).astype(self.dtype)
+         return embedding_TD
+
+
+ @dataclass(kw_only=True)
+ class LMhead(Embedder):
+     """
+     An Embedder that uses a (D, V) shaped embedding table, inheriting from
+     the base Embedder class.
+
+     This implementation overrides the kernel generation, encoding, and decoding
+     methods to work with the transposed embedding matrix layout.
+     """
+     dv_sharding: Sharding
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         self.input_embedding_table_DV = create_param(
+             rngs,
+             shape=(self.hidden_size, self.vocab_size),
+             sharding=self.dv_sharding,
+             dtype=self.dtype,
+             random_init=self.random_init)
+
+     def __call__(self, x):
+         """Dispatches to the decode method.
+
+         Args:
+             x: The input tensor of hidden states for decoding.
+
+         Returns:
+             Logit scores over the vocabulary.
+         """
+         return self.decode(x)
+
+     def decode(self, x_TD: Float) -> Float:
+         """Projects hidden states to vocabulary logits.
+
+         Args:
+             x_TD: The input tensor of hidden states from the model backbone, with
+                 shape `(sequence, d_model)`.
+
+         Returns:
+             The output logits over the vocabulary, with shape
+             `(sequence, vocab_size)`.
+         """
+         x_TD = jnp.asarray(x_TD, self.dtype)
+         x_TD = nnx.with_sharding_constraint(x_TD, self.prelogit_td)
+
+         with jax.named_scope("lmhead_decode_projection"):
+             logits_TV = jnp.einsum('DV,TD -> TV',
+                                    self.input_embedding_table_DV.value, x_TD)
+         return logits_TV
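
A rough usage sketch of the new layer modules, under assumptions not confirmed by the diff: default (empty) shardings on a single device, and create_param materializing weights when random_init=True. The sizes below are invented.

import jax.numpy as jnp
from flax import nnx
from tpu_inference.layers.jax.layers import DenseFFW, RMSNorm

rngs = nnx.Rngs(params=0)
norm = RMSNorm(dims=1024, dtype=jnp.bfloat16, random_init=True, rngs=rngs)
ffw = DenseFFW(dtype=jnp.bfloat16, hidden_act='silu', hidden_size=1024,
               intermediate_size=4096, random_init=True, rngs=rngs)

x_TD = jnp.ones((16, 1024), dtype=jnp.bfloat16)  # [T=16, D=1024]
y_TD = ffw(norm(x_TD))  # shape stays [16, 1024]
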
@@ -0,0 +1,16 @@
+ import math
+ from typing import Tuple
+
+ import jax
+ from jax.sharding import NamedSharding
+ from jax.sharding import PartitionSpec as P
+
+
+ # TODO(xiang): move this to weight_utils.py
+ def shard_put(x: jax.Array, sharding_names: Tuple[str, ...] | P,
+               mesh: jax.sharding.Mesh) -> jax.Array:
+     # Single device sharding requires this special handling
+     # to avoid the recursive jit error.
+     if math.prod(mesh.axis_sizes) == 1:
+         return jax.device_put(x, mesh.devices.flatten()[0])
+     return jax.device_put(x, NamedSharding(mesh, P(*sharding_names)))
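
A brief usage sketch for shard_put; the mesh layout and array shape are invented, and it assumes the sharded dimension is divisible by the number of devices.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh
from tpu_inference.layers.jax.misc import shard_put

# 1D mesh over whatever devices are available (a single CPU device also works,
# which exercises the single-device branch above).
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

x = jnp.ones((8, 128))
# Replicate the first axis, shard the second over the "model" mesh axis.
x_sharded = shard_put(x, (None, "model"), mesh)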