tpu-inference 0.11.1rc1__py3-none-any.whl → 0.11.1rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (50)
  1. tpu_inference/kernels/collectives/__init__.py +0 -0
  2. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  3. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  4. tpu_inference/kernels/collectives/util.py +47 -0
  5. tpu_inference/layers/__init__.py +0 -0
  6. tpu_inference/layers/common/__init__.py +0 -0
  7. tpu_inference/layers/common/attention_metadata.py +34 -0
  8. tpu_inference/layers/jax/__init__.py +0 -0
  9. tpu_inference/layers/jax/attention/__init__.py +0 -0
  10. tpu_inference/layers/jax/attention/attention.py +254 -0
  11. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  12. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  13. tpu_inference/layers/jax/attention_interface.py +356 -0
  14. tpu_inference/layers/jax/base.py +151 -0
  15. tpu_inference/layers/jax/binary_search.py +295 -0
  16. tpu_inference/layers/jax/constants.py +88 -0
  17. tpu_inference/layers/jax/layers.py +301 -0
  18. tpu_inference/layers/jax/misc.py +16 -0
  19. tpu_inference/layers/jax/moe/__init__.py +0 -0
  20. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  21. tpu_inference/layers/jax/moe/moe.py +209 -0
  22. tpu_inference/layers/jax/rope.py +172 -0
  23. tpu_inference/layers/jax/rope_interface.py +214 -0
  24. tpu_inference/layers/jax/sample/__init__.py +0 -0
  25. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  26. tpu_inference/layers/jax/sample/sampling.py +95 -0
  27. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  28. tpu_inference/layers/jax/sharding.py +406 -0
  29. tpu_inference/layers/jax/transformer_block.py +76 -0
  30. tpu_inference/layers/vllm/__init__.py +0 -0
  31. tpu_inference/layers/vllm/attention.py +184 -0
  32. tpu_inference/layers/vllm/fused_moe.py +399 -0
  33. tpu_inference/layers/vllm/linear_common.py +186 -0
  34. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  35. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  36. tpu_inference/layers/vllm/quantization/common.py +105 -0
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  38. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  39. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  40. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  41. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  42. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  43. tpu_inference/layers/vllm/sharding.py +151 -0
  44. tpu_inference/models/common/__init__.py +0 -0
  45. tpu_inference/models/common/model_loader.py +433 -0
  46. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/METADATA +6 -6
  47. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/RECORD +50 -5
  48. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/WHEEL +0 -0
  49. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/licenses/LICENSE +0 -0
  50. {tpu_inference-0.11.1rc1.dist-info → tpu_inference-0.11.1rc3.dist-info}/top_level.txt +0 -0
File without changes
@@ -0,0 +1,608 @@
1
+ import enum
2
+ from dataclasses import InitVar, dataclass
3
+ from functools import partial
4
+ from typing import Optional, Tuple
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ from flax import nnx
9
+ from flax.typing import Sharding
10
+ from jax.sharding import PartitionSpec
11
+ from jaxtyping import Float
12
+ from qwix._src.core.ragged_dot import ragged_dot as qwix_ragged_dot
13
+ from qwix._src.providers import ptq
14
+
15
+ from tpu_inference.layers.jax.base import create_param
16
+ from tpu_inference.layers.jax.layers import FlaxUtils
17
+ from tpu_inference.layers.jax.moe.moe import MoE
18
+ from tpu_inference.models.jax.utils.quantization.quantization_utils import (
19
+ manually_quantize_qwix_activation, manually_quantize_qwix_weight)
20
+
21
# Shared FlaxUtils helper object; SparseMoE looks up its expert activation
# function by name in this object's ``ACT2FN`` table.
modeling_flax_utils: FlaxUtils = FlaxUtils()
22
+
23
+
24
@dataclass
class DeepSeekV3Router(nnx.Module):
    """Router module for DeepSeek-V3 style Mixture-of-Experts (MoE) layers.

    Scores every token against every expert with a sigmoid-activated linear
    projection. Optionally restricts candidate experts to the best
    ``topk_groups`` expert groups (group-limited routing). Returns the
    ``num_experts_per_tok`` selected experts per token together with their
    routing weights.
    """

    hidden_size: int
    num_experts: int
    num_experts_per_tok: int
    n_groups: int
    topk_groups: int
    norm_topk_prob: bool
    routed_scaling_factor: float
    dtype: jnp.dtype
    rngs: InitVar[nnx.Rngs]

    # Sharding Attributes
    activation_ffw_td: Sharding = ()
    ed_sharding: Sharding = ()
    e_sharding: Sharding = ()

    random_init: bool = False

    router_bias_dtype: jnp.dtype = jnp.float32

    def get_topk_indices(self, scores_TE: Float) -> jax.Array:
        """Get the top-k expert indices for each token.

        The per-expert bias ``bias_E`` is added before selection, so it only
        influences *which* experts are picked. ``__call__`` gathers the
        routing weights from the unbiased scores.

        Args:
            scores_TE: Router scores. Shape (sequence, num_experts).

        Returns:
            The top-k expert indices. Shape (sequence, num_experts_per_tok).
        """
        scores_TE = scores_TE + self.bias_E
        if self.n_groups > 1:
            experts_per_group = self.num_experts // self.n_groups
            group_scores_TGM = jnp.reshape(
                scores_TE, (-1, self.n_groups, experts_per_group))
            # A group's score is the sum of its two highest expert scores.
            group_scores_TG2 = jax.lax.top_k(group_scores_TGM, k=2)[0]
            group_scores_TG = jnp.sum(group_scores_TG2, axis=-1)
            indices = jax.lax.top_k(group_scores_TG, k=self.topk_groups)[1]

            # mask_TG[t, g] is True iff group g is among token t's top groups.
            mask_TG = jnp.any(jnp.arange(self.n_groups)[:, None] ==
                              indices[..., None, :],
                              axis=-1)
            mask_TE = jnp.repeat(mask_TG,
                                 scores_TE.shape[-1] // mask_TG.shape[-1], -1)
            # Fill with -inf, not 0.0. Bias-corrected scores can be negative,
            # and a 0.0 fill would let experts from discarded groups outrank
            # experts of the selected groups.
            scores_TE = jnp.where(mask_TE, scores_TE, -jnp.inf)

        indices_TX = jax.lax.top_k(scores_TE, k=self.num_experts_per_tok)[1]

        return indices_TX

    def __call__(self, x_TD: Float) -> Tuple[Float, Float]:
        """Routes tokens to top k experts.

        Args:
            x_TD: Input array of shape (sequence, d_model).

        Returns:
            A tuple containing:
            - weights: Weights for the selected experts (normalized when
              ``norm_topk_prob`` is set, then multiplied by
              ``routed_scaling_factor``). Shape (sequence, num_experts_per_tok).
            - indices: Indices of the selected experts.
              Shape (sequence, num_experts_per_tok).
        """
        x_TD = jnp.asarray(x_TD, self.dtype)
        x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)

        scores_TE = jnp.einsum("TD,DE -> TE", x_TD, self.kernel_DE.value)
        scores_TE = nnx.sigmoid(scores_TE)

        # Selection uses bias-corrected scores inside get_topk_indices, but the
        # weights are gathered from the unbiased sigmoid scores here.
        topk_indices_TX = self.get_topk_indices(scores_TE)
        weights_TX = jnp.take_along_axis(scores_TE, topk_indices_TX, axis=-1)

        if self.norm_topk_prob:
            # The tiny epsilon guards against division by zero.
            weights_TX /= jnp.sum(weights_TX, axis=-1)[..., None] + 1e-20

        weights_TX *= self.routed_scaling_factor

        return weights_TX, topk_indices_TX

    def __post_init__(self, rngs: nnx.Rngs):
        """Generates the router kernel (weights and bias) for routing."""
        D = self.hidden_size
        E = self.num_experts
        self.kernel_DE = create_param(rngs,
                                      shape=(D, E),
                                      dtype=self.dtype,
                                      sharding=self.ed_sharding,
                                      random_init=self.random_init)
        self.bias_E = create_param(rngs,
                                   shape=(E, ),
                                   dtype=self.router_bias_dtype,
                                   sharding=self.e_sharding,
                                   random_init=self.random_init)
125
+
126
+
127
@dataclass(kw_only=True)
class SparseMoE(MoE):
    """Mixture-of-Experts (MoE) routed MLP layer.

    The router picks ``num_experts_per_tok`` experts per token. Tokens are
    then permuted so that rows for the same expert are contiguous, and the
    expert MLPs are applied with a grouped matrix multiply (``ragged_dot``).
    Finally the results are un-permuted and combined using the router
    weights. With expert parallelism, rows are exchanged between expert
    shards via ``jax.lax.ragged_all_to_all``.

    Attributes:
        num_experts_per_tok: The number of experts each token is routed to.
        tile_size: A tuple (batch, activation_dim, weight_dim) for GMM tiling.
            Only ``tile_size[0]`` is used here, to pad the GMM row count.
        use_megablox: If True, uses the MegaBlox GMM kernel (not implemented
            yet; ``_gmm`` raises ``NotImplementedError``).
        mesh: The device mesh used by ``shard_map``.
        quantized_dtype: Set if and only if the model was quantized via Qwix.

    Derived in ``__post_init__``:
        expert_axis_name: Mesh axis of the expert dimension
            (``edf_sharding[0]``), or None.
        num_expert_parallelism: Size of that mesh axis (1 when None).
        data_axis_name: Mesh axis of the token dimension
            (``activation_ffw_td[0]``).
        is_batch_sharded_by_expert: True if tokens are sharded over the same
            mesh axis as the experts.
    """
    num_experts_per_tok: int
    # TODO: tile size is (tile_batch_seq, tile_activation_dim, tile_weight_dim)
    # from MaxText.
    tile_size: tuple[int, int, int] = (128, 64, 128)
    use_megablox: bool = False
    mesh: jax.sharding.Mesh
    # This should be set if and only if you have quantized your model (via Qwix)
    quantized_dtype: Optional[jnp.dtype] = None

    def __post_init__(self, rngs: nnx.Rngs):
        super().__post_init__(rngs)

        # Derive the expert sharding.
        # TODO: this I/O may need a redesign for more general parallelism.
        self.expert_axis_name = self.edf_sharding[0]
        if self.expert_axis_name is None:
            self.num_expert_parallelism = 1
        else:
            self.num_expert_parallelism = self.mesh.shape[
                self.expert_axis_name]

        # Derive whether the tokens are sharded along the expert axis.
        self.data_axis_name = self.activation_ffw_td[0]
        self.is_batch_sharded_by_expert = (
            self.expert_axis_name is not None) and (self.expert_axis_name
                                                    == self.data_axis_name)

    def _sort_activations(self, inputs: jax.Array,
                          sort_indices: jax.Array) -> jax.Array:
        """Reorders the rows (axis 0) of ``inputs`` by ``sort_indices``."""
        return inputs[sort_indices, ...]

    @staticmethod
    def get_all_to_all_params(
        all_shards_group_sizes,
        shard_id,
        num_expert_parallelism,
        is_batch_sharded=True,
    ):
        """Generates params for ``jax.lax.ragged_all_to_all`` communication.

        Args:
            all_shards_group_sizes: If ``is_batch_sharded``, a 2-D matrix where
                entry ``[i, j]`` is the number of rows shard ``i`` sends to
                shard ``j``. Otherwise a 1-D array of per-shard row counts.
            shard_id: Index of the current shard on the communication axis.
            num_expert_parallelism: Number of shards on that axis.
            is_batch_sharded: True when every shard holds distinct tokens.
                False means tokens are replicated; then each shard sends its
                own chunk to every peer, at that chunk's global offset.

        Returns:
            ``(input_offsets, send_sizes, output_offsets, recv_sizes)``.
        """

        class TransformStrategy(enum.Enum):
            INPUT_OFFSET = enum.auto()
            SEND_SIZE = enum.auto()
            OUTPUT_OFFSET = enum.auto()
            RECV_SIZE = enum.auto()

        def transform_array(input_array, shard_id, strategy, is_batch_sharded):
            if is_batch_sharded:
                if strategy == TransformStrategy.INPUT_OFFSET:
                    # Exclusive cumsum of what this shard sends to each peer.
                    local_array = input_array[shard_id]
                    return jnp.concatenate(
                        (jnp.array([0]), jnp.cumsum(local_array)[:-1]))
                elif strategy == TransformStrategy.SEND_SIZE:
                    return input_array[shard_id]
                elif strategy == TransformStrategy.OUTPUT_OFFSET:
                    # Row shard_id of the exclusive cumsum over senders: where
                    # this shard's data lands in each receiver's buffer.
                    zero_row = jnp.zeros((1, ) + input_array.shape[1:],
                                         dtype=input_array.dtype)
                    array_with_zeros = jnp.concatenate((zero_row, input_array),
                                                       axis=0)
                    cumulated_array = jnp.cumsum(array_with_zeros,
                                                 axis=0,
                                                 dtype=input_array.dtype)
                    return cumulated_array[shard_id]
                elif strategy == TransformStrategy.RECV_SIZE:
                    return input_array[:, shard_id]
                else:
                    raise ValueError(
                        f"Unknown transform array strategy: {strategy}")
            else:
                if strategy == TransformStrategy.INPUT_OFFSET:
                    return jnp.zeros(num_expert_parallelism,
                                     dtype=input_array.dtype)
                elif strategy == TransformStrategy.SEND_SIZE:
                    return jnp.repeat(input_array[shard_id],
                                      num_expert_parallelism)
                elif strategy == TransformStrategy.OUTPUT_OFFSET:
                    output_offset = jnp.concatenate(
                        (jnp.array([0]),
                         jnp.cumsum(input_array[:-1])))[shard_id]
                    return jnp.repeat(output_offset, num_expert_parallelism)
                elif strategy == TransformStrategy.RECV_SIZE:
                    return input_array
                else:
                    raise ValueError(
                        f"Unknown transform array strategy: {strategy}")

        input_offsets = transform_array(all_shards_group_sizes, shard_id,
                                        TransformStrategy.INPUT_OFFSET,
                                        is_batch_sharded)
        send_sizes = transform_array(all_shards_group_sizes, shard_id,
                                     TransformStrategy.SEND_SIZE,
                                     is_batch_sharded)
        output_offsets = transform_array(all_shards_group_sizes, shard_id,
                                         TransformStrategy.OUTPUT_OFFSET,
                                         is_batch_sharded)
        recv_sizes = transform_array(all_shards_group_sizes, shard_id,
                                     TransformStrategy.RECV_SIZE,
                                     is_batch_sharded)
        return input_offsets, send_sizes, output_offsets, recv_sizes

    def _local_permute(
        self,
        inputs,
        global_group_sizes,
        local_expert_size,
        shard_index,
        is_offset=False,
        global_sorted_experts=None,
    ):
        """Permutes rows locally so each local expert's rows are contiguous.

        Args:
            inputs: Rows to permute (indexed along axis 0).
            global_group_sizes: (token shards, num total experts) row counts.
            local_expert_size: Number of experts held by each expert shard.
            shard_index: Index of this expert shard.
            is_offset: True when tokens are replicated across shards. Rows
                belonging to other shards get id ``local_expert_size`` and sort
                to the end.
            global_sorted_experts: Global expert id of every row of ``inputs``;
                required when ``is_offset`` is True.

        Returns:
            ``(sorted_inputs, sorted_indices, local_group_size,
            sorted_experts_ids)``. ``local_group_size`` has shape
            (local_expert_size,).
        """
        # all_shard_local_sizes: (token shards, local experts in this shard)
        all_shard_local_sizes = jax.lax.dynamic_slice_in_dim(
            global_group_sizes,
            shard_index * local_expert_size,
            local_expert_size,
            axis=1,
        )
        local_sizes = all_shard_local_sizes.reshape(-1)

        # local_group_size: (local_expert_size,), summed over token shards.
        local_group_size = jnp.sum(all_shard_local_sizes, axis=0)

        if is_offset:
            # Tokens replicated on all devices: rows for other shards get the
            # out-of-range id `local_expert_size`.
            global_sorted_shard_assignments = jnp.floor_divide(
                global_sorted_experts, local_expert_size)
            expert_indices = jnp.where(
                global_sorted_shard_assignments == shard_index,
                jnp.mod(global_sorted_experts, local_expert_size),
                local_expert_size,
            )
        else:
            # Tokens sharded across devices: the received buffer is laid out
            # as (source shard, local expert) chunks, matching `local_sizes`.
            base_indices = jnp.mod(jnp.arange(local_sizes.shape[0]),
                                   local_expert_size)
            expert_indices = jnp.repeat(base_indices,
                                        local_sizes,
                                        total_repeat_length=inputs.shape[0])

        sorted_indices = jnp.argsort(expert_indices)
        # Sort the inputs by their local expert id.
        sorted_inputs = self._sort_activations(inputs, sorted_indices)
        # Sorted local expert ids, from 0 to local_expert_size.
        sorted_experts_ids = expert_indices[sorted_indices]
        return (
            sorted_inputs,
            sorted_indices,
            local_group_size,
            sorted_experts_ids,
        )

    def _permute(self, inputs_TD: Float, selected_experts_TX: jax.Array):
        """Global permute: sorts token assignments by assigned expert.

        Suffix ``t`` = T * X = total assignments for the tokens (T) on this
        device.

        Returns:
            ``(sorted_inputs_tD, sort_indices_t, group_sizes_E,
            sorted_expert_assignments_t)``.
        """
        total_tokens = inputs_TD.shape[0]
        flat_expert_indices = selected_experts_TX.flatten()
        sort_indices_t = jnp.argsort(flat_expert_indices)

        # Each token is repeated once per selected expert, matching the
        # row-major flattening of selected_experts_TX.
        replicated_inputs_tD = jnp.repeat(inputs_TD,
                                          self.num_experts_per_tok,
                                          axis=0)
        sorted_inputs_tD = self._sort_activations(replicated_inputs_tD,
                                                  sort_indices_t)

        # Number of assignments routed to each expert.
        group_sizes_E = jnp.bincount(flat_expert_indices,
                                     length=self.num_local_experts)

        expert_ids = jnp.arange(self.num_local_experts)
        total_assignments = total_tokens * self.num_experts_per_tok
        sorted_expert_assignments_t = jnp.repeat(
            expert_ids,
            repeats=group_sizes_E,
            total_repeat_length=total_assignments)

        return (
            sorted_inputs_tD,
            sort_indices_t,
            group_sizes_E,
            sorted_expert_assignments_t,
        )

    def _unpermute(self, processed_tokens: jax.Array, sort_indices: jax.Array,
                   router_weights_TX: jax.Array):
        """Restores original token order and combines expert outputs.

        The per-expert outputs of each token are summed with the router
        weights in float32, then cast back to ``self.dtype``.
        """
        with jax.named_scope("unpermute"):
            unsorted_tokens_tD = self._sort_activations(
                processed_tokens, jnp.argsort(sort_indices))
            reshaped_tokens_TXD = unsorted_tokens_tD.reshape(
                -1, self.num_experts_per_tok, self.hidden_size)
        with jax.named_scope("combine_weights"):
            output_TD = jnp.einsum(
                "TXD,TX -> TD",
                reshaped_tokens_TXD.astype(jnp.float32),
                router_weights_TX.astype(jnp.float32),
                precision='float32',
            )

        return output_TD.astype(self.dtype)

    def _gmm(self, inputs, kernel, group_sizes):
        """Performs a grouped matrix multiply over contiguous expert groups.

        Rows are zero-padded up to a multiple of ``tile_size[0]``, and the
        padding is sliced off the result.
        """
        num_rows = inputs.shape[0]
        pad_amount = (self.tile_size[0] -
                      num_rows % self.tile_size[0]) % self.tile_size[0]
        if pad_amount > 0:
            inputs = jnp.pad(inputs, ((0, pad_amount), (0, 0)))

        if self.use_megablox:
            # TODO: MegaBlox is used in MaxText; placeholder for a future
            # implementation.
            raise NotImplementedError(
                "MegaBlox kernel call is not implemented.")
        else:
            # NOTE(review): activations are always quantized to float8_e4m3fn
            # here, independent of quantized_dtype (which only selects the
            # weight dtype in __call__). Confirm this is intended.
            inputs = manually_quantize_qwix_activation(
                inputs, "ragged_dot", jnp.float8_e4m3fn, [0], {},
                "absmax") if self.quantized_dtype else inputs
            ragged_dot_func = qwix_ragged_dot if self.quantized_dtype else jax.lax.ragged_dot
            output = ragged_dot_func(
                lhs=inputs,
                rhs=kernel,
                group_sizes=group_sizes,
                preferred_element_type=self.dtype,
            )

        if pad_amount > 0:
            output = output[:num_rows, :]
        return output

    @staticmethod
    def _distributed_sparse_moe_fwd(
        self,
        x_TD: jax.Array,
        router_weights_TX: jax.Array,
        selected_experts_TX: jax.Array,
        kernel_gating: jax.Array,
        kernel_up_proj: jax.Array,
        kernel_down_proj: jax.Array,
    ):
        """Sparse MoE forward pass with fully distributed logic.

        Runs per device inside ``shard_map``. It is a staticmethod that takes
        ``self`` explicitly so ``__call__`` can pass the module through
        ``shard_map`` as an ordinary argument.
        """

        # 1. Global permute: sort this shard's token assignments by expert.
        (
            sorted_inputs,
            global_sort_indices,
            global_group_sizes,
            global_sorted_experts,
        ) = self._permute(x_TD, selected_experts_TX)

        if self.num_expert_parallelism > 1:
            # Query the shard index only when experts are actually sharded.
            # Without expert parallelism, expert_axis_name may be None and
            # there is no mesh axis to look up.
            expert_shard_id = jax.lax.axis_index(self.expert_axis_name)
            local_expert_size = (self.num_local_experts //
                                 self.num_expert_parallelism)

            if self.is_batch_sharded_by_expert:
                # Tokens are sharded over the expert axis
                # (data_axis_name == expert_axis_name).

                # 2a. Send tokens to experts (all-to-all).
                # all_shards_group_sizes:
                #   (data parallelism = expert parallelism, num total experts)
                all_shards_group_sizes = jax.lax.all_gather(
                    global_group_sizes, axis_name=self.data_axis_name)

                # [i][j] = number of rows token shard i sends to expert shard j.
                all_shards_group_sizes_per_expert_shard = jnp.sum(
                    all_shards_group_sizes.reshape(
                        self.num_expert_parallelism,  # data parallelism
                        self.num_expert_parallelism,  # expert parallelism
                        local_expert_size  # experts per shard
                    ),
                    axis=2)
                input_offsets, send_sizes, output_offsets, recv_sizes = self.get_all_to_all_params(
                    all_shards_group_sizes_per_expert_shard, expert_shard_id,
                    self.num_expert_parallelism)

                # Worst-case receive buffer: every shard's assignments may
                # land on this shard.
                local_total_assignments = x_TD.shape[
                    0] * self.num_experts_per_tok
                global_total_assignments = local_total_assignments * self.num_expert_parallelism
                output_shape_est = jnp.zeros(
                    (global_total_assignments, self.hidden_size),
                    dtype=sorted_inputs.dtype)

                inputs_after_all2all = jax.lax.ragged_all_to_all(
                    sorted_inputs,
                    output_shape_est,
                    input_offsets,
                    send_sizes,
                    output_offsets,
                    recv_sizes,
                    axis_name=self.expert_axis_name)

                # 3a. Local permute. Here the data axis is the expert axis, so
                # all_shards_group_sizes already holds every shard's
                # per-expert sizes. Reuse it instead of all-gathering again.
                (
                    compute_inputs,
                    local_sorted_indices,
                    compute_group_sizes,
                    compute_expert_ids,
                ) = self._local_permute(
                    inputs_after_all2all,
                    all_shards_group_sizes,
                    local_expert_size,
                    shard_index=expert_shard_id,
                    is_offset=False,
                )

            else:
                # Tokens are replicated on all devices.

                # 2b. No send all-to-all is needed: the tokens are sorted and
                # replicated on every device.
                # 3b. Local "permute".
                (
                    compute_inputs,
                    local_sorted_indices,
                    compute_group_sizes,
                    compute_expert_ids,
                ) = self._local_permute(
                    sorted_inputs,
                    global_group_sizes[None, :],
                    local_expert_size,
                    shard_index=expert_shard_id,
                    is_offset=True,
                    global_sorted_experts=global_sorted_experts,
                )

                # Per-expert-shard row counts for the return all-to-all.
                reshaped_group_sizes = jnp.sum(global_group_sizes.reshape(
                    -1, local_expert_size),
                                               axis=1)
                # Zero out rows that belong to other shards' experts.
                mask = compute_expert_ids < local_expert_size
                compute_inputs = compute_inputs * mask[..., None]

        else:
            # --- NO EXPERT PARALLELISM ---
            compute_inputs = sorted_inputs
            compute_group_sizes = global_group_sizes

        # 4. Compute: apply the experts using grouped matrix multiplies.
        with jax.named_scope("gating"):
            # compute_inputs: (local total assignments, D)
            gating_TEF = self._gmm(compute_inputs, kernel_gating,
                                   compute_group_sizes)
            activated_gating_TEF = modeling_flax_utils.ACT2FN[self.hidden_act](
                gating_TEF)

        with jax.named_scope("up_projection"):
            up_proj_TEF = self._gmm(compute_inputs, kernel_up_proj,
                                    compute_group_sizes)

        fuse_TEF = activated_gating_TEF * up_proj_TEF

        with jax.named_scope("down_projection"):
            # intermediate_output: (local total assignments, D)
            intermediate_output = self._gmm(fuse_TEF, kernel_down_proj,
                                            compute_group_sizes)

        # 5. Return results (all-to-all).
        if self.num_expert_parallelism > 1:
            local_total_assignments = x_TD.shape[0] * self.num_experts_per_tok
            output_shape = jnp.zeros(
                (local_total_assignments, self.hidden_size),
                dtype=intermediate_output.dtype)

            if self.is_batch_sharded_by_expert:
                # Undo the local permute, so rows are back in receive order
                # before they are sent home.
                local_output = self._sort_activations(
                    intermediate_output, jnp.argsort(local_sorted_indices))

                # Reverse of step 2a, so the send matrix is transposed. This
                # must be the per-expert-shard (ep, ep) matrix. The
                # per-expert (ep, num_experts) matrix yields wrong sizes and
                # offsets whenever a shard holds more than one expert.
                input_offsets, send_sizes, output_offsets, recv_sizes = self.get_all_to_all_params(
                    jnp.transpose(all_shards_group_sizes_per_expert_shard),
                    expert_shard_id,
                    self.num_expert_parallelism,
                )
                final_intermediate_output = jax.lax.ragged_all_to_all(
                    local_output,
                    output_shape,
                    input_offsets,
                    send_sizes,
                    output_offsets,
                    recv_sizes,
                    axis_name=self.expert_axis_name)
            else:
                # Tokens replicated on all devices.
                input_offsets, send_sizes, output_offsets, recv_sizes = self.get_all_to_all_params(
                    reshaped_group_sizes,
                    expert_shard_id,
                    self.num_expert_parallelism,
                    is_batch_sharded=False,
                )
                final_intermediate_output = jax.lax.ragged_all_to_all(
                    intermediate_output,
                    output_shape,
                    input_offsets,
                    send_sizes,
                    output_offsets,
                    recv_sizes,
                    axis_name=self.expert_axis_name)
        else:
            final_intermediate_output = intermediate_output

        # 6. Global unpermute (on the data shard).
        with jax.named_scope("unpermute"):
            output_TD = self._unpermute(final_intermediate_output,
                                        global_sort_indices, router_weights_TX)

        return output_TD

    def __call__(self, x_TD: Float):
        """Performs the forward pass of the Sparse MoE layer.

        Args:
            x_TD: Input activations, (tokens, hidden_size).

        Returns:
            Output activations, sharded per ``activation_ffw_td``.
        """
        x_TD = jnp.asarray(x_TD, self.dtype)
        x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
        router_weights_TX, selected_experts_TX = self.router(x_TD)

        # The router outputs are consumed row-for-row with x_TD inside
        # shard_map (_permute / _unpermute), so they must be split along the
        # token axis exactly like x_TD. If they were replicated, each shard
        # would pair its local slice of x_TD with the full token set. When the
        # token axis is unsharded, this spec is equivalent to replication.
        token_spec_TX = PartitionSpec(self.activation_ffw_td[0], None)
        in_specs = (
            PartitionSpec(),  # Replicated `self`
            PartitionSpec(*self.activation_ffw_td),  # Sharded x_TD
            token_spec_TX,  # router_weights_TX, sharded like x_TD's tokens
            token_spec_TX,  # selected_experts_TX, sharded like x_TD's tokens
            PartitionSpec(*self.edf_sharding),  # Sharded gating kernel
            PartitionSpec(*self.edf_sharding),  # Sharded up-projection kernel
            PartitionSpec(
                *self.efd_sharding),  # Sharded down-projection kernel
        )
        out_specs = PartitionSpec(*self.activation_ffw_td)

        mapped_moe_fwd = partial(jax.experimental.shard_map.shard_map,
                                 mesh=self.mesh,
                                 in_specs=in_specs,
                                 out_specs=out_specs,
                                 check_rep=False)(
                                     SparseMoE._distributed_sparse_moe_fwd)

        kernel_gating_EDF = self.kernel_gating_EDF.value
        kernel_up_proj_EDF = self.kernel_up_proj_EDF.value
        kernel_down_proj_EFD = self.kernel_down_proj_EFD.value

        if self.quantized_dtype:
            # Quantize on the fly unless the kernels already carry Qwix
            # quantization aux data; then unwrap to the underlying array.
            if not isinstance(kernel_gating_EDF, ptq.WithAux):
                kernel_gating_EDF = manually_quantize_qwix_weight(
                    kernel_gating_EDF, self.quantized_dtype, [0, 2], {},
                    "absmax")
            if not isinstance(kernel_up_proj_EDF, ptq.WithAux):
                kernel_up_proj_EDF = manually_quantize_qwix_weight(
                    kernel_up_proj_EDF, self.quantized_dtype, [0, 2], {},
                    "absmax")
            if not isinstance(kernel_down_proj_EFD, ptq.WithAux):
                kernel_down_proj_EFD = manually_quantize_qwix_weight(
                    kernel_down_proj_EFD, self.quantized_dtype, [0, 1], {},
                    "absmax")
            kernel_gating_EDF = kernel_gating_EDF.array
            kernel_up_proj_EDF = kernel_up_proj_EDF.array
            kernel_down_proj_EFD = kernel_down_proj_EFD.array

        return mapped_moe_fwd(self, x_TD, router_weights_TX,
                              selected_experts_TX, kernel_gating_EDF,
                              kernel_up_proj_EDF, kernel_down_proj_EFD)