tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (56)
  1. tests/kernels/fused_moe_v1_test.py +34 -303
  2. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
  3. tests/lora/test_layers.py +6 -0
  4. tests/lora/utils.py +8 -0
  5. tests/test_envs.py +11 -32
  6. tests/test_utils.py +2 -1
  7. tpu_inference/__init__.py +3 -22
  8. tpu_inference/core/disagg_utils.py +8 -6
  9. tpu_inference/distributed/tpu_connector.py +4 -3
  10. tpu_inference/distributed/utils.py +2 -3
  11. tpu_inference/envs.py +8 -61
  12. tpu_inference/executors/ray_distributed_executor.py +2 -9
  13. tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
  15. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +145 -266
  16. tpu_inference/layers/common/attention_interface.py +1 -7
  17. tpu_inference/layers/common/sharding.py +5 -5
  18. tpu_inference/layers/vllm/fused_moe.py +208 -170
  19. tpu_inference/layers/vllm/quantization/common.py +1 -6
  20. tpu_inference/layers/vllm/quantization/mxfp4.py +73 -138
  21. tpu_inference/layers/vllm/quantization/unquantized.py +64 -58
  22. tpu_inference/layers/vllm/sharding.py +2 -2
  23. tpu_inference/lora/torch_punica_tpu.py +2 -1
  24. tpu_inference/mock/__init__.py +0 -0
  25. tpu_inference/mock/vllm_config_utils.py +28 -0
  26. tpu_inference/mock/vllm_envs.py +1219 -0
  27. tpu_inference/mock/vllm_logger.py +212 -0
  28. tpu_inference/mock/vllm_logging_utils.py +15 -0
  29. tpu_inference/models/common/model_loader.py +10 -43
  30. tpu_inference/models/jax/llama3.py +1 -2
  31. tpu_inference/models/jax/llama_eagle3.py +5 -8
  32. tpu_inference/models/jax/phi3.py +376 -0
  33. tpu_inference/models/jax/qwen2.py +1 -2
  34. tpu_inference/models/jax/qwen2_5_vl.py +48 -163
  35. tpu_inference/models/jax/qwen3.py +1 -2
  36. tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
  37. tpu_inference/models/jax/utils/weight_utils.py +143 -198
  38. tpu_inference/models/vllm/vllm_model_wrapper.py +8 -14
  39. tpu_inference/platforms/tpu_platform.py +31 -37
  40. tpu_inference/runner/compilation_manager.py +58 -141
  41. tpu_inference/runner/kv_cache.py +1 -1
  42. tpu_inference/runner/kv_cache_manager.py +18 -17
  43. tpu_inference/runner/persistent_batch_manager.py +2 -40
  44. tpu_inference/runner/structured_decoding_manager.py +3 -2
  45. tpu_inference/runner/tpu_runner.py +147 -271
  46. tpu_inference/runner/utils.py +2 -2
  47. tpu_inference/spec_decode/jax/eagle3.py +21 -71
  48. tpu_inference/tpu_info.py +3 -4
  49. tpu_inference/utils.py +13 -36
  50. tpu_inference/worker/tpu_worker.py +25 -162
  51. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +3 -4
  52. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +55 -50
  53. tpu_inference/models/jax/llama_guard_4.py +0 -361
  54. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
  55. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
  56. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
@@ -2,16 +2,17 @@ import functools
 
 import jax
 from jax import numpy as jnp
-from jax import shard_map
 from jax.experimental.pallas.ops.tpu.megablox.gmm import gmm
-from jax.sharding import Mesh
-from jax.sharding import PartitionSpec as P
+from jax.experimental.shard_map import shard_map
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
 
 from tpu_inference.layers.vllm.linear_common import \
     slice_sharded_tensor_for_concatenation
 
+P = PartitionSpec
 
-def activation_fn(activation: str, x1: jax.Array, x2: jax.Array) -> jax.Array:
+
+def activation_fn(activation: str, x1, x2):
     match activation:
         case "silu":
             return jax.nn.silu(x1) * x2
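
For orientation, a minimal sketch (not from the package) of how the newly imported shard_map / PartitionSpec pair is used: the wrapped function runs once per shard over a named mesh axis, and collectives such as psum reduce across that axis. It assumes a 1-D "model" mesh built from whatever local JAX devices are available.

import numpy as np
import jax
from jax import numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# One mesh axis named "model" spanning all local devices.
mesh = Mesh(np.array(jax.devices()), ("model", ))

def _local_fn(x):
    # Each shard sees only its slice of the last axis; psum sums across shards.
    return jax.lax.psum(x.sum(axis=-1), axis_name="model")

x = jnp.ones((4, 8 * len(jax.devices())))
y = shard_map(_local_fn, mesh=mesh, in_specs=(P(None, "model"), ),
              out_specs=P())(x)   # y.shape == (4,)
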
@@ -22,10 +23,7 @@ def activation_fn(activation: str, x1: jax.Array, x2: jax.Array) -> jax.Array:
                 f"FusedMoE does not support {activation} activation")
 
 
-def _swigluoai(x1: jax.Array,
-               x2: jax.Array,
-               alpha=1.702,
-               limit=7.0) -> jax.Array:
+def _swigluoai(x1, x2, alpha=1.702, limit=7.0):
     x1 = jnp.clip(x1, a_max=limit)
     x2 = jnp.clip(x2, a_min=-limit, a_max=limit)
 
@@ -105,53 +103,40 @@ def tensor_sharded_gmm_merged_column_parallel(
     rhs: jax.Array,
     rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
+    transpose_rhs: bool,
     mesh: Mesh,
-) -> tuple[jax.Array, jax.Array]:
-
-    def _gmm(lhs, rhs, group_sizes):
-        m, g, n, k = lhs.shape[0], *rhs.shape
-        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
-        return gmm(
-            lhs,
-            rhs,
-            group_sizes,
-            preferred_element_type=lhs.dtype,
-            tiling=(tm, tk, tn),
-            transpose_rhs=True,
-            group_offset=jnp.array(0),
-        )
+    intermediate_size: int,
+) -> jax.Array:
+    # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
+    m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
+    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+
+    _gmm = functools.partial(
+        gmm,
+        preferred_element_type=lhs.dtype,
+        tiling=(tm, tk, tn),
+        transpose_rhs=transpose_rhs,
+        group_offset=jnp.array(0),
+    )
 
     gmm_result = shard_map(
         _gmm,
         mesh=mesh,
-        in_specs=(P("data", None), P(None, "model", None), P("data")),
-        out_specs=(P("data", "model")),
-        check_vma=False,
+        in_specs=(P(), P(None, "model", None), P()),
+        out_specs=(P(None, "model")),
+        check_rep=False,
     )(lhs, rhs, group_sizes)
 
     if rhs_bias is not None:
+        rhs_bis = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
+        gmm_result = (gmm_result + rhs_bis).astype(gmm_result.dtype)
 
-        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
-            rhs_bias = jnp.repeat(
-                rhs_bias_local,
-                group_sizes_global,
-                0,
-                total_repeat_length=gmm_result_local.shape[0])
-            return gmm_result_local + rhs_bias
-
-        gmm_result = shard_map(
-            _add_bias,
-            mesh=mesh,
-            in_specs=(P("data", "model"), P(None, "model"), P("data")),
-            out_specs=(P("data", "model")),
-        )(gmm_result, rhs_bias, group_sizes)
-        gmm_result = gmm_result.astype(lhs.dtype)
-
-    tp_size = mesh.shape["model"]
-    intermediate_size = gmm_result.shape[-1] // 2
+    n_shards = mesh.shape["model"]
     output_sizes = [intermediate_size, intermediate_size]
+
     return slice_sharded_tensor_for_concatenation(gmm_result, output_sizes,
-                                                  tp_size)
+                                                  n_shards)
 
 
 def tensor_sharded_gmm_row_parallel(
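
The new bias path above replaces a second shard_map with a plain jnp.repeat: each expert's bias row is repeated group_sizes[i] times so it lines up row-for-row with the grouped matmul output. A tiny worked illustration (values are made up):

import jax.numpy as jnp

rhs_bias = jnp.array([[10.0], [20.0], [30.0]])   # per-expert bias, [num_experts, n]
group_sizes = jnp.array([2, 1, 3])               # rows per expert, sums to m = 6
expanded = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=6)
# expanded[:, 0] -> [10., 10., 20., 30., 30., 30.], one bias row per output row
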
@@ -159,75 +144,74 @@ def tensor_sharded_gmm_row_parallel(
     rhs: jax.Array,
     rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
+    transpose_rhs: bool,
     mesh: Mesh,
 ) -> jax.Array:
+    # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
+    m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
+    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+
+    _gmm = functools.partial(
+        gmm,
+        preferred_element_type=lhs.dtype,
+        tiling=(tm, tk, tn),
+        transpose_rhs=transpose_rhs,
+        group_offset=jnp.array(0),
+    )
 
     def _gmm_all_reduce(lhs, rhs, group_sizes):
-        m, g, n, k = lhs.shape[0], *rhs.shape
-        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
-        out = gmm(
-            lhs,
-            rhs,
-            group_sizes,
-            preferred_element_type=lhs.dtype,
-            tiling=(tm, tk, tn),
-            transpose_rhs=True,
-            group_offset=jnp.array(0),
-        )
-        return jax.lax.psum(out, axis_name="model")
+        r = _gmm(lhs, rhs, group_sizes)
+        return jax.lax.psum(r, axis_name="model")
 
     gmm_result = shard_map(
         _gmm_all_reduce,
         mesh=mesh,
-        in_specs=(P("data", "model"), P(None, None, "model"), P("data")),
-        out_specs=(P("data")),
-        check_vma=False,
+        in_specs=(P(None, "model"), P(None, None, "model"), P()),
+        out_specs=(P()),
+        check_rep=False,
     )(lhs, rhs, group_sizes)
 
     if rhs_bias is not None:
+        rhs_bias = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
+        gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)
 
-        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
-            rhs_bias = jnp.repeat(
-                rhs_bias_local,
-                group_sizes_global,
-                0,
-                total_repeat_length=gmm_result_local.shape[0])
-            return gmm_result_local + rhs_bias
-
-        gmm_result = shard_map(
-            _add_bias,
-            mesh=mesh,
-            in_specs=(P("data"), P(), P("data")),
-            out_specs=(P("data")),
-        )(gmm_result, rhs_bias, group_sizes)
-
-    return gmm_result.astype(lhs.dtype)
+    return gmm_result
 
 
 def expert_sharded_gmm(
     lhs: jax.Array,
     rhs: jax.Array,
     group_sizes: jax.Array,
+    transpose_rhs: bool,
     mesh: Mesh,
+    num_experts: int,
+    ep_size: int,
 ) -> jax.Array:
-    ep_size = mesh.shape["model"]
+    # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
+    m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
+    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
 
-    num_experts = rhs.shape[0]
     num_experts_per_shard = num_experts // ep_size
     group_offset = jnp.arange(0, num_experts, num_experts_per_shard)
+    group_offset = jax.lax.with_sharding_constraint(
+        group_offset, NamedSharding(mesh, P("model")))
 
     def _gmm(lhs, rhs, group_sizes, group_offset):
-        m, g, n, k = lhs.shape[0], *rhs.shape
-        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
-
+        # Group offset for this shard. `group_offset` is sharded, and in this
+        # sharded function, it has only 1 element and `group_offset.shape` is
+        # (1,) but gmm kernel requires the group_offset to be a ()-shaped array,
+        # so we group_offset[0].
+        group_offset_of_shard = group_offset[0]
        gmm_res = gmm(
             lhs=lhs,
             rhs=rhs,
             group_sizes=group_sizes,
             preferred_element_type=lhs.dtype,
             tiling=(tm, tk, tn),
-            transpose_rhs=True,
-            group_offset=group_offset[0],
+            transpose_rhs=transpose_rhs,
+            group_offset=group_offset_of_shard,
         )
         return gmm_res
 
@@ -254,24 +238,30 @@ def expert_sharded_gmm(
         mesh=mesh,
         in_specs=(P(), P("model", None, None), P(), P("model")),
         out_specs=(P("model", None)),
-        check_vma=False,
+        check_rep=False,
     )(lhs, rhs, group_sizes, group_offset)
 
     # For i-th shard, it is responsible groups (AKA experts) from
     # i*num_experts_per_shard to (i+1)*num_experts_per_shard We sum them up to
     # get total rows in that shard, and that is the size for shard to send to
     # its peers. This is also the number of non-zero rows from the gmm results.
-    # In the working example, send_sizes would be [3, 2, 5, 4].
-
-    # group_sizes has shape of [num_tokens_per_shard * num_experts_per_shard].
-    # So reshaping to [num_tokens_per_shard, num_experts_per_shard] and applying
-    # sum(axis=1) will get desired send_sizes shaped [num_tokens_per_shard].
-    send_sizes = group_sizes.reshape(-1, num_experts_per_shard).sum(axis=1)
+    # In the working example, send_sizes would be [3, 2, 5, 4]
+    send_sizes = jnp.array([
+        group_sizes[i * num_experts_per_shard:(i + 1) *
+                    num_experts_per_shard].sum() for i in range(ep_size)
+    ])
     # In the working example, input_offsets would be [0, 3, 5, 10]
     input_offsets = jnp.concatenate((jnp.array([0]), send_sizes.cumsum()[:-1]))
     output_offsets = input_offsets
     recv_sizes = send_sizes
 
+    input_offsets = jax.lax.with_sharding_constraint(
+        input_offsets, NamedSharding(mesh, P("model")))
+    send_sizes = jax.lax.with_sharding_constraint(
+        send_sizes, NamedSharding(mesh, P("model")))
+    output_offsets = jax.lax.with_sharding_constraint(
+        output_offsets, NamedSharding(mesh, P("model")))
+
     def _ragged_all_to_all(operand, input_offsets, send_sizes, output_offsets,
                            recv_sizes):
         output = jnp.zeros_like(operand)
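
A small worked example of the expert-parallel bookkeeping above, with illustrative numbers chosen to reproduce the [3, 2, 5, 4] and [0, 3, 5, 10] values mentioned in the comments (num_experts = 8, ep_size = 4; the concrete group_sizes values are made up):

import jax.numpy as jnp

num_experts, ep_size = 8, 4
num_experts_per_shard = num_experts // ep_size            # 2
group_offset = jnp.arange(0, num_experts, num_experts_per_shard)
# group_offset -> [0, 2, 4, 6]; sharded over "model", shard i keeps one entry
# and passes group_offset[0] (a scalar) to the gmm kernel.

group_sizes = jnp.array([1, 2, 2, 0, 4, 1, 3, 1])         # rows routed to each expert
send_sizes = jnp.array([
    group_sizes[i * num_experts_per_shard:(i + 1) *
                num_experts_per_shard].sum() for i in range(ep_size)
])
# send_sizes -> [3, 2, 5, 4]
input_offsets = jnp.concatenate((jnp.array([0]), send_sizes.cumsum()[:-1]))
# input_offsets -> [0, 3, 5, 10]
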
@@ -326,20 +316,10 @@ def expert_sharded_gmm(
         mesh=mesh,
         in_specs=(P("model", None), P("model"), P("model"), P("model"), P()),
         out_specs=(P()),
-        check_vma=False,
+        check_rep=False,
     )(gmm_res, input_offsets, send_sizes, output_offsets, recv_sizes)
 
 
-@functools.partial(
-    jax.jit,
-    static_argnames=(
-        "topk",
-        "renormalize",
-        "mesh",
-        "use_ep",
-        "activation",
-    ),
-)
 def fused_moe_func(
     hidden_states: jax.Array,
     w1: jax.Array,
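
The jax.jit decorator removed here reappears below on the new fused_moe_func_padded wrapper. A minimal sketch (not from the package, hypothetical function name) of the functools.partial(jax.jit, static_argnames=...) pattern it uses: static arguments must be hashable and may drive Python-level control flow at trace time.

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnames=("scale", ))
def scaled_sum(x, scale: float):
    # `scale` is static: each distinct value triggers one compilation and can
    # be used in ordinary Python branching.
    return x.sum() * scale if scale != 0 else jnp.zeros(())

y = scaled_sum(jnp.arange(4.0), scale=2.0)   # 12.0
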
@@ -348,45 +328,37 @@ def fused_moe_func(
     w2_bias: jax.Array | None,
     gating_output: jax.Array,
     topk: int,
+    global_num_experts: int,
     renormalize: bool,
+    reduce_results: bool,
     mesh: Mesh,
     use_ep: bool,
     activation: str,
-) -> jax.Array:
+):
     """
-    Route tokens in hidden_states into each experts based on routing
-    information in gating_out and performs moe with w1 and w2 weights.
-
     Args:
-        hidden_states: [num_tokens, hidden_size]
-        w1: first moe weights [num_experts, intermediate_size * 2, hidden_size]
-        w2: second moe weights [num_experts, hidden_size, intermediate_size]
-        w1_bias: optional bias of w1 [num_experts, intermediate_size * 2]
-        w2_bias: optional bias of w2 [num_experts, hidden_size]
-        gating_output: routing information of tokens [num_tokens, num_experts]
-        topk: number of experts to choose per token.
-        renormalize: normalize gating_output.
-        mesh: mesh to perform moe.
-        use_ep: use expert parallelism.
-        activation: activation function to perform on the output of w1.
-
-    Returns:
-        Output of moe operation [num_tokens, hidden_size]
+        hidden_states: [*, hidden_size]
+        w1: [num_experts, intermediate_size * 2, hidden_size]
+        w2: [num_experts, hidden_size, intermediate_size]
+        gating_output: [*, num_experts]
     """
+    # adapted from https://github.com/vllm-project/vllm/blob/29fa5cac1cd731026f59084d93a822921507573c/vllm/model_executor/layers/fused_moe/moe_pallas.py#L26
     if use_ep and (w1_bias is not None or w2_bias is not None):
         raise NotImplementedError(
             "Bias is not supported when using expert parallelism.")
-
-    num_tokens = hidden_states.shape[0]
-    global_num_experts, hidden_size, intermediate_size = w2.shape
+    orig_shape = hidden_states.shape
+    hidden_size = hidden_states.shape[-1]
+    num_tokens = hidden_states.size // hidden_size
+    assert global_num_experts == w1.shape[0]
+    ep_size = mesh.shape["model"]  # only used if use_ep is True.
+    intermediate_size = w2.shape[-1]
     dtype = hidden_states.dtype
-
     assert (num_tokens * topk) % 16 == 0, (
         "The kernel requires num_tokens * topk to be a multiple of "
         f"16 but got {num_tokens}*{topk}={num_tokens*topk}")
-    assert hidden_states.shape == (num_tokens, hidden_size)
-    assert gating_output.shape == (num_tokens, global_num_experts)
-    assert w1.shape == (global_num_experts, intermediate_size * 2, hidden_size)
+
+    hidden_states = hidden_states.reshape(num_tokens, hidden_size)
+    gating_output = gating_output.reshape(num_tokens, global_num_experts)
 
     topk_weights = jax.nn.softmax(gating_output.astype(jnp.float32), axis=-1)
     topk_weights, topk_indices = jax.lax.top_k(topk_weights, k=topk)
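
For context, a small routing illustration (not from the package) of the softmax / top_k / bincount steps around this hunk, with 2 tokens, 4 experts and topk = 2. Note the real kernel additionally requires num_tokens * topk to be a multiple of 16, which is what the fused_moe_func_padded wrapper added below handles.

import jax
import jax.numpy as jnp

gating_output = jnp.array([[0.1, 2.0, 0.3, 1.5],
                           [1.0, 0.2, 2.5, 0.4]])
topk_weights = jax.nn.softmax(gating_output.astype(jnp.float32), axis=-1)
topk_weights, topk_indices = jax.lax.top_k(topk_weights, k=2)
# topk_indices -> [[1, 3], [2, 0]]
topk_indices_flat = topk_indices.flatten()               # [1, 3, 2, 0]
topk_argsort_indices = jnp.argsort(topk_indices_flat)    # [3, 0, 2, 1]
group_sizes = jnp.bincount(topk_indices_flat, length=4)  # [1, 1, 1, 1]
# Rows of hidden_states are then gathered in expert order so each expert's
# tokens form one contiguous group for the grouped matmul (gmm).
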
@@ -394,76 +366,142 @@ def fused_moe_func(
     topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)
     topk_weights = topk_weights.astype(dtype)
 
-    def _process_tokens_locally(hidden_states_local, topk_indices_local):
-        num_tokens_local = hidden_states_local.shape[0]
-        topk_indices_flat = topk_indices_local.flatten()
-        topk_argsort_indices = jnp.argsort(topk_indices_flat)
-        topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
-        token_indices = jnp.arange(num_tokens_local,
-                                   dtype=jnp.int32).repeat(topk)
-        token_indices_sorted = token_indices[topk_argsort_indices]
-        group_sizes_local = jnp.bincount(topk_indices_flat,
-                                         length=global_num_experts)
-
-        x = hidden_states_local[token_indices_sorted]
-        return x, group_sizes_local, topk_argsort_revert_indices
-
-    x, group_sizes, topk_argsort_revert_indices = shard_map(
-        _process_tokens_locally,
-        mesh=mesh,
-        in_specs=(P("data", None), P("data", None)),
-        out_specs=(P("data", None), P("data"), P("data")),
-    )(hidden_states, topk_indices)
+    topk_indices_flat = topk_indices.flatten()
+    topk_argsort_indices = jnp.argsort(topk_indices_flat)
+    topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
+    token_indices = jnp.arange(num_tokens, dtype=jnp.int32).repeat(topk)
+    token_indices_sorted = token_indices[topk_argsort_indices]
+    group_sizes = jnp.bincount(topk_indices_flat, length=global_num_experts)
+
+    x = hidden_states[token_indices_sorted]
 
     if use_ep:
         x = expert_sharded_gmm(
             x,
             w1,
             group_sizes,
+            transpose_rhs=True,
             mesh=mesh,
+            num_experts=global_num_experts,
+            ep_size=ep_size,
         )
-        x1, x2 = jnp.split(x, 2, -1)
-
-        x = activation_fn(activation, x1, x2)
-
-        x = expert_sharded_gmm(
-            x,
-            w2,
-            group_sizes,
-            mesh=mesh,
-        )
+        x1, x2 = x[..., :intermediate_size], x[..., intermediate_size:]
     else:
         x1, x2 = tensor_sharded_gmm_merged_column_parallel(
             x,
             w1,
             w1_bias,
             group_sizes,
+            transpose_rhs=True,
             mesh=mesh,
+            intermediate_size=intermediate_size,
         )
 
-        x = activation_fn(activation, x1, x2)
+    x = activation_fn(activation, x1, x2)
 
+    if use_ep:
+        x = expert_sharded_gmm(
+            x,
+            w2,
+            group_sizes,
+            transpose_rhs=True,
+            mesh=mesh,
+            num_experts=global_num_experts,
+            ep_size=ep_size,
+        )
+    else:
+        x = jax.lax.with_sharding_constraint(
+            x, NamedSharding(mesh, P(None, "model")))
         x = tensor_sharded_gmm_row_parallel(
             x,
             w2,
             w2_bias,
             group_sizes,
+            transpose_rhs=True,
             mesh=mesh,
         )
 
-    def _finalize_output(x_local, topk_argsort_revert_indices_local,
-                         topk_weights_local):
-        x_local = x_local[topk_argsort_revert_indices_local].reshape(
-            -1, topk, hidden_size)
-        x_local = x_local * jnp.expand_dims(topk_weights_local, axis=-1)
-        x_local = x_local.sum(axis=-2)
-        return x_local
+    x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
+    x = x * jnp.expand_dims(topk_weights, axis=-1)
+    x = x.sum(axis=-2)
+    x = x.reshape(orig_shape)
 
-    x = shard_map(
-        _finalize_output,
-        mesh=mesh,
-        in_specs=(P("data", None), P("data"), P("data", None)),
-        out_specs=(P("data", None)),
-    )(x, topk_argsort_revert_indices, topk_weights)
+    if reduce_results:
+        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
+    return x
 
-    return x[:num_tokens, :hidden_size]
+
+@functools.partial(
+    jax.jit,
+    static_argnames=(
+        "topk",
+        "global_num_experts",
+        "renormalize",
+        "reduce_results",
+        "mesh",
+        "use_ep",
+        "activation",
+    ),
+)
+def fused_moe_func_padded(
+    hidden_states: jax.Array,
+    w1: jax.Array,
+    w2: jax.Array,
+    w1_bias: jax.Array | None,
+    w2_bias: jax.Array | None,
+    gating_output: jax.Array,
+    topk: int,
+    global_num_experts: int,
+    renormalize: bool,
+    reduce_results: bool,
+    mesh: Mesh,
+    use_ep: bool,
+    activation: str,
+):
+    # TODO(fanhongmin@google.com): Once the jax runner pads the input, we no longer need this.
+    hidden_size = hidden_states.shape[-1]
+    num_tokens = hidden_states.size // hidden_size
+    if num_tokens * topk < 16:
+        assert 16 % (num_tokens *
+                     topk) == 0, f"Cannot pad to 16: {num_tokens=}, {topk=}"
+        n_repeats = 16 // (num_tokens * topk)
+
+        reps = (n_repeats, ) + (1, ) * (hidden_states.ndim - 1)
+        expanded_hidden_states = jnp.tile(hidden_states, reps)
+
+        reps = (n_repeats, ) + (1, ) * (gating_output.ndim - 1)
+        expanded_gating_output = jnp.tile(gating_output, reps)
+
+        expanded_x = fused_moe_func(
+            expanded_hidden_states,
+            w1,
+            w2,
+            w1_bias,
+            w2_bias,
+            expanded_gating_output,
+            topk,
+            global_num_experts,
+            renormalize,
+            reduce_results,
+            mesh,
+            use_ep,
+            activation,
+        )
+        x = expanded_x[:hidden_states.shape[0]]
+        return x
+    else:
+        return fused_moe_func(
+            hidden_states,
+            w1,
+            w2,
+            w1_bias,
+            w2_bias,
+            gating_output,
+            topk,
+            global_num_experts,
+            renormalize,
+            reduce_results,
+            mesh,
+            use_ep,
+            activation,
+        )
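
The padding arithmetic of fused_moe_func_padded in isolation (illustrative shapes only, not from the package): with num_tokens = 2 and topk = 4, num_tokens * topk = 8 < 16, so the inputs are tiled twice along the token axis and only the original rows are kept afterwards.

import jax.numpy as jnp

num_tokens, topk, hidden_size = 2, 4, 8
hidden_states = jnp.ones((num_tokens, hidden_size))
n_repeats = 16 // (num_tokens * topk)                     # 2
reps = (n_repeats, ) + (1, ) * (hidden_states.ndim - 1)   # (2, 1)
expanded = jnp.tile(hidden_states, reps)                  # shape (4, 8)
# fused_moe_func would run on the tiled batch; only the first num_tokens rows
# (expanded_x[:hidden_states.shape[0]]) are returned to the caller.
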
@@ -61,12 +61,7 @@ class JaxCommonLinearConfig:
                            " bad performance.", type(layer))
 
         self.bias_sharding = P(self.weight_sharding[0])
-        if isinstance(self.weight_sharding[0], tuple):
-            self.n_shards = 1
-            for axis in self.weight_sharding[0]:
-                self.n_shards *= self.mesh.shape.get(axis, 1)
-        else:
-            self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
+        self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
 
     def get_input_sharding(self, x: torchax.tensor.Tensor):
         if self.enable_sequence_parallelism:
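
The simplified n_shards lookup above relies on Mesh.shape behaving like a mapping from axis names to sizes, with a None (unsharded) axis falling back to 1. A minimal sketch (not from the package), independent of the local device count:

import numpy as np
import jax
from jax.sharding import Mesh

devices = np.array(jax.devices())
mesh = Mesh(devices.reshape(1, -1), ("data", "model"))
n_shards = mesh.shape.get("model", 1)   # size of the "model" axis
n_unsharded = mesh.shape.get(None, 1)   # 1 when the weight axis is unsharded
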