tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202512030818__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (54)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/lora/test_layers.py +0 -6
  3. tests/lora/utils.py +0 -8
  4. tests/test_envs.py +32 -11
  5. tests/test_utils.py +1 -2
  6. tpu_inference/__init__.py +22 -3
  7. tpu_inference/core/disagg_utils.py +6 -8
  8. tpu_inference/distributed/tpu_connector.py +3 -4
  9. tpu_inference/distributed/utils.py +3 -2
  10. tpu_inference/envs.py +61 -8
  11. tpu_inference/executors/ray_distributed_executor.py +31 -11
  12. tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
  13. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +213 -126
  15. tpu_inference/layers/common/attention_interface.py +7 -1
  16. tpu_inference/layers/common/sharding.py +5 -5
  17. tpu_inference/layers/vllm/fused_moe.py +74 -25
  18. tpu_inference/layers/vllm/quantization/common.py +6 -1
  19. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -62
  20. tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
  21. tpu_inference/layers/vllm/sharding.py +2 -2
  22. tpu_inference/lora/torch_punica_tpu.py +1 -2
  23. tpu_inference/models/common/model_loader.py +45 -11
  24. tpu_inference/models/jax/llama3.py +2 -1
  25. tpu_inference/models/jax/llama_eagle3.py +8 -5
  26. tpu_inference/models/jax/llama_guard_4.py +361 -0
  27. tpu_inference/models/jax/qwen2.py +2 -1
  28. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  29. tpu_inference/models/jax/qwen3.py +2 -1
  30. tpu_inference/models/jax/utils/quantization/quantization_utils.py +3 -6
  31. tpu_inference/models/jax/utils/weight_utils.py +198 -143
  32. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -7
  33. tpu_inference/platforms/tpu_platform.py +28 -22
  34. tpu_inference/runner/compilation_manager.py +144 -59
  35. tpu_inference/runner/kv_cache_manager.py +17 -18
  36. tpu_inference/runner/persistent_batch_manager.py +40 -2
  37. tpu_inference/runner/structured_decoding_manager.py +2 -3
  38. tpu_inference/runner/tpu_runner.py +271 -147
  39. tpu_inference/runner/utils.py +2 -2
  40. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  41. tpu_inference/tpu_info.py +4 -3
  42. tpu_inference/utils.py +36 -13
  43. tpu_inference/worker/tpu_worker.py +162 -25
  44. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA +3 -2
  45. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD +48 -53
  46. tpu_inference/mock/__init__.py +0 -0
  47. tpu_inference/mock/vllm_config_utils.py +0 -28
  48. tpu_inference/mock/vllm_envs.py +0 -1219
  49. tpu_inference/mock/vllm_logger.py +0 -212
  50. tpu_inference/mock/vllm_logging_utils.py +0 -15
  51. tpu_inference/models/jax/phi3.py +0 -376
  52. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/WHEEL +0 -0
  53. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/licenses/LICENSE +0 -0
  54. {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/top_level.txt +0 -0

tpu_inference/layers/common/sharding.py
@@ -1,6 +1,5 @@
 import json
 import math
-import os
 from dataclasses import asdict, dataclass
 from typing import TYPE_CHECKING, List, Optional
 
@@ -8,7 +7,7 @@ import jax.numpy as jnp
 import numpy as np
 from jax.sharding import Mesh
 
-from tpu_inference import utils
+from tpu_inference import envs, utils
 
 if TYPE_CHECKING:
     from vllm.v1.configs.vllm_config import VllmConfig
@@ -48,7 +47,7 @@ class ShardingAxisName2D:
 
 
 try:
-    _use_base_sharding = os.getenv("NEW_MODEL_DESIGN", False)
+    _use_base_sharding = envs.NEW_MODEL_DESIGN
     if _use_base_sharding:
         ShardingAxisName = ShardingAxisNameBase
     else:
@@ -166,9 +165,10 @@ class ShardingConfigManager:
                 f"LoRA is not supported with data parallelism "
                 f"(DP size: {total_dp_size}). Please disable LoRA or "
                 f"set data parallelism to 1.")
-        if not os.environ.get("NEW_MODEL_DESIGN", False):
+        if sharding_strategy.attention_data_parallelism > 1:
+            if not envs.NEW_MODEL_DESIGN:
                 raise ValueError(
-                    "Must run DP with NEW_MODEL_DESIGN enabled. Please set the "
+                    "Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set the "
                     "NEW_MODEL_DESIGN=True.")
 
     @property
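
Note on the os.getenv to envs change above: os.getenv("NEW_MODEL_DESIGN", False) returns the raw string whenever the variable is set, so values like "False" or "0" are still truthy in an if-test; routing the lookup through tpu_inference.envs centralizes the parsing. The exact parsing done by the envs module is not shown in this diff; the _env_flag helper below is a hypothetical stand-in used only to illustrate the pitfall.

    import os

    os.environ["NEW_MODEL_DESIGN"] = "False"
    print(bool(os.getenv("NEW_MODEL_DESIGN", False)))  # True -- any non-empty string is truthy

    def _env_flag(name: str, default: bool = False) -> bool:
        # hypothetical boolean parser, similar in spirit to a typed envs module
        raw = os.getenv(name)
        if raw is None:
            return default
        return raw.strip().lower() in ("1", "true", "yes", "on")

    print(_env_flag("NEW_MODEL_DESIGN"))  # False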
tpu_inference/layers/vllm/fused_moe.py
@@ -110,7 +110,8 @@ def tensor_sharded_gmm_merged_column_parallel(
     # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
     m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
     n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n,
+                                                 g)
 
     _gmm = functools.partial(
         gmm,
@@ -123,14 +124,26 @@
     gmm_result = shard_map(
         _gmm,
         mesh=mesh,
-        in_specs=(P(), P(None, "model", None), P()),
-        out_specs=(P(None, "model")),
+        in_specs=(P("data", None), P(None, "model", None), P("data")),
+        out_specs=(P("data", "model")),
         check_rep=False,
     )(lhs, rhs, group_sizes)
 
     if rhs_bias is not None:
-        rhs_bis = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
-        gmm_result = (gmm_result + rhs_bis).astype(gmm_result.dtype)
+
+        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
+            rhs_bis = jnp.repeat(rhs_bias_local,
+                                 group_sizes_global,
+                                 0,
+                                 total_repeat_length=m // mesh.shape["data"])
+            return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
+
+        gmm_result = shard_map(
+            _add_bias,
+            mesh=mesh,
+            in_specs=(P("data", "model"), P(None, "model"), P("data")),
+            out_specs=(P("data", "model")),
+        )(gmm_result, rhs_bias, group_sizes)
 
     n_shards = mesh.shape["model"]
     output_sizes = [intermediate_size, intermediate_size]
@@ -150,7 +163,8 @@ def tensor_sharded_gmm_row_parallel(
     # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
     m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
     n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n,
+                                                 g)
 
     _gmm = functools.partial(
         gmm,
@@ -167,14 +181,25 @@
     gmm_result = shard_map(
         _gmm_all_reduce,
         mesh=mesh,
-        in_specs=(P(None, "model"), P(None, None, "model"), P()),
-        out_specs=(P()),
+        in_specs=(P("data", "model"), P(None, None, "model"), P("data")),
+        out_specs=(P("data")),
        check_rep=False,
     )(lhs, rhs, group_sizes)
-
     if rhs_bias is not None:
-        rhs_bias = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
-        gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)
+
+        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
+            rhs_bis = jnp.repeat(rhs_bias_local,
+                                 group_sizes_global,
+                                 0,
+                                 total_repeat_length=m // mesh.shape["data"])
+            return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
+
+        gmm_result = shard_map(
+            _add_bias,
+            mesh=mesh,
+            in_specs=(P("data"), P(), P("data")),
+            out_specs=(P("data")),
+        )(gmm_result, rhs_bias, group_sizes)
 
     return gmm_result
 
@@ -366,15 +391,27 @@ def fused_moe_func(
     topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)
     topk_weights = topk_weights.astype(dtype)
 
-    topk_indices_flat = topk_indices.flatten()
-    topk_argsort_indices = jnp.argsort(topk_indices_flat)
-    topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
-    token_indices = jnp.arange(num_tokens, dtype=jnp.int32).repeat(topk)
-    token_indices_sorted = token_indices[topk_argsort_indices]
-    group_sizes = jnp.bincount(topk_indices_flat, length=global_num_experts)
-
-    x = hidden_states[token_indices_sorted]
-
+    def _process_tokens_locally(hidden_states_local, topk_indices_local):
+        num_tokens_local = hidden_states_local.shape[0]
+        topk_indices_flat = topk_indices_local.flatten()
+        topk_argsort_indices = jnp.argsort(topk_indices_flat)
+        topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
+        token_indices = jnp.arange(num_tokens_local,
+                                   dtype=jnp.int32).repeat(topk)
+        token_indices_sorted = token_indices[topk_argsort_indices]
+        group_sizes_local = jnp.bincount(topk_indices_flat,
+                                         length=global_num_experts)
+
+        x = hidden_states_local[token_indices_sorted]
+        return x, group_sizes_local, topk_argsort_revert_indices
+
+    x, group_sizes, topk_argsort_revert_indices = shard_map(
+        _process_tokens_locally,
+        mesh=mesh,
+        in_specs=(P("data", None), P("data", None)),
+        out_specs=(P("data", None), P("data"), P("data")),
+        check_rep=False,
+    )(hidden_states, topk_indices)
     if use_ep:
         x = expert_sharded_gmm(
             x,
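Note: the body of _process_tokens_locally is the usual MoE routing trick, now applied per "data" shard: flatten the per-token expert choices, argsort so rows destined for the same expert become contiguous, and bincount the group sizes for the grouped matmul. A small worked example (mine, not from the diff) with two tokens, top-2 routing, and three experts:

    import jax.numpy as jnp

    topk = 2
    num_experts = 3
    topk_indices = jnp.array([[2, 0], [1, 2]])  # token 0 -> experts {2, 0}, token 1 -> {1, 2}

    flat = topk_indices.flatten()                # [2, 0, 1, 2]
    order = jnp.argsort(flat)                    # sort slots by expert id
    revert = jnp.argsort(order)                  # permutation to undo the sort afterwards
    token_ids_sorted = jnp.arange(2, dtype=jnp.int32).repeat(topk)[order]
    group_sizes = jnp.bincount(flat, length=num_experts)

    print(flat[order])       # [0 1 2 2] -> expert-contiguous
    print(token_ids_sorted)  # [0 1 0 1] -> which token feeds each row
    print(group_sizes)       # [1 1 2]   -> rows per expert for the grouped matmul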
@@ -411,7 +448,7 @@
         )
     else:
         x = jax.lax.with_sharding_constraint(
-            x, NamedSharding(mesh, P(None, "model")))
+            x, NamedSharding(mesh, P("data", "model")))
         x = tensor_sharded_gmm_row_parallel(
             x,
             w2,
@@ -421,13 +458,25 @@
             mesh=mesh,
         )
 
-    x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
-    x = x * jnp.expand_dims(topk_weights, axis=-1)
-    x = x.sum(axis=-2)
+    def _finalize_output(x_local, topk_argsort_revert_indices_local,
+                         topk_weights_local):
+        x_local = x_local[topk_argsort_revert_indices_local].reshape(
+            -1, topk, hidden_size)
+        x_local = x_local * jnp.expand_dims(topk_weights_local, axis=-1)
+        x_local = x_local.sum(axis=-2)
+        return x_local
+
+    x = shard_map(
+        _finalize_output,
+        mesh=mesh,
+        in_specs=(P("data", None), P("data"), P("data", None)),
+        out_specs=(P("data", None)),
+        check_rep=False,
+    )(x, topk_argsort_revert_indices, topk_weights)
     x = x.reshape(orig_shape)
 
     if reduce_results:
-        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
+        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P("data")))
     return x
 
 
tpu_inference/layers/vllm/quantization/common.py
@@ -61,7 +61,12 @@ class JaxCommonLinearConfig:
                 " bad performance.", type(layer))
 
         self.bias_sharding = P(self.weight_sharding[0])
-        self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
+        if isinstance(self.weight_sharding[0], tuple):
+            self.n_shards = 1
+            for axis in self.weight_sharding[0]:
+                self.n_shards *= self.mesh.shape.get(axis, 1)
+        else:
+            self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
 
     def get_input_sharding(self, x: torchax.tensor.Tensor):
         if self.enable_sequence_parallelism:
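Note: the n_shards fix handles a PartitionSpec entry that names several mesh axes at once (e.g. ("data", "model")); the shard count for that dimension is then the product of the participating axis sizes. A minimal illustration, with made-up axis names and sizes:

    import math

    mesh_shape = {"data": 2, "model": 4}  # hypothetical mesh axis sizes
    weight_axis = ("data", "model")       # first entry of the weight PartitionSpec

    if isinstance(weight_axis, tuple):
        n_shards = math.prod(mesh_shape.get(axis, 1) for axis in weight_axis)
    else:
        n_shards = mesh_shape.get(weight_axis, 1)

    print(n_shards)  # 8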
tpu_inference/layers/vllm/quantization/mxfp4.py
@@ -24,6 +24,8 @@ from vllm.model_executor.layers.quantization.mxfp4 import (Mxfp4Backend,
 from vllm.model_executor.layers.quantization.utils.quant_utils import \
     is_layer_skipped
 
+from tpu_inference import envs
+from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
 from tpu_inference.layers.common.quant_methods import (MXFP4,
                                                        get_tpu_quant_method)
 from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
@@ -93,7 +95,8 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
                 "UnquantizedLinearMethod.")
             return VllmUnquantizedLinearMethod(linear_config)
         elif isinstance(layer, FusedMoE):
-            return VllmMxfp4MoEMethod(layer.moe_config, self.mesh)
+            moe_config = self.get_moe_config(layer)
+            return VllmMxfp4MoEMethod(moe_config, self.mesh)
         elif isinstance(layer, Attention):
             # TODO: Add support for MXFP4 Attention.
             logger.warning_once("MXFP4 attention layer is not implemented. "
@@ -103,13 +106,30 @@ class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
 
 class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
 
-    def __init__(self, moe: FusedMoEConfig, mesh: Mesh):
+    def __init__(self,
+                 moe: FusedMoEConfig,
+                 mesh: Mesh,
+                 ep_axis_name: str = 'model'):
         FusedMoEMethodBase.__init__(self, moe)
 
         # We piggyback on triton implementation as it applies minimal hardware
         # specific post processing to the weights.
         self.mxfp4_backend = Mxfp4Backend.TRITON
+
         self.mesh = mesh
+        self.use_kernel = envs.USE_MOE_EP_KERNEL
+        self.ep_axis_name = ep_axis_name
+        # TODO: Use autotune table once we have it.
+        self.block_size = {
+            "bt": 64,
+            "bf": 1024,
+            "bd1": 1536,
+            "bd2": 1536,
+            "btc": 64,
+            "bfc": 1024,
+            "bd1c": 1536,
+            "bd2c": 1536,
+        }
 
     def get_fused_moe_quant_config(
             self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
@@ -122,6 +142,7 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
 
     def process_weights_after_loading(self, layer: torch.nn.Module):
         assert isinstance(layer, FusedMoE)
+        assert layer.moe_config.has_bias, "mxfp4 quantization alwyas use bias."
 
         w13_weight = u8_unpack_e2m1(t2j(layer.w13_weight, use_dlpack=False))
         w13_weight_scale = e8m0_to_fp32(
@@ -157,57 +178,95 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
         w3_bias = w13_bias[:, 1::2]
         w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
 
-        # TODO(kyuyeunk): Add weight processing logic for the new kernel.
-        if layer.use_ep:
+        if self.use_kernel and layer.use_ep:
+            # Kernel expects:
+            # w13: (num_experts, 2, hidden_size, intermediate_size)
+            # w2: (num_experts, intermediate_size, hidden_size)
+            # Current format:
+            # w13_weight: (num_experts, 2*intermediate_size, hidden_size)
+            # w2_weight: (num_experts, hidden_size, intermediate_size)
+            num_experts = w13_weight.shape[0]
+            intermediate_size = w13_weight.shape[1] // 2
+            hidden_size = w13_weight.shape[2]
+
+            # Reshape and transpose w13_weight to (num_experts, 2, hidden_size, intermediate_size)
+            w13_reshaped = w13_weight.reshape(num_experts, 2,
+                                              intermediate_size, hidden_size)
+            w13_weight_transposed = jnp.transpose(w13_reshaped, (0, 1, 3, 2))
+
+            # Transpose w2_weight to (num_experts, intermediate_size, hidden_size)
+            w2_weight_transposed = jnp.transpose(w2_weight, (0, 2, 1))
+
+            # Apply EP sharding
             w13_weight = jax.device_put(
-                w13_weight,
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P("model", None, None))))
+                w13_weight_transposed,
+                Format(Layout((0, 1, 2, 3)),
+                       NamedSharding(self.mesh, P("model", None, None, None))))
             w2_weight = jax.device_put(
-                w2_weight,
+                w2_weight_transposed,
                 Format(Layout((0, 1, 2)),
                        NamedSharding(self.mesh, P("model", None, None))))
 
-            w13_bias = jax.device_put(
-                w13_bias,
-                Format(Layout((0, 1)),
-                       NamedSharding(self.mesh, P("model", None))))
-            w2_bias = jax.device_put(
-                w2_bias,
-                Format(Layout((0, 1)),
-                       NamedSharding(self.mesh, P("model", None))))
+            if self.moe.has_bias:
+                w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
+
+                # Apply EP sharding
+                w13_bias = jax.device_put(
+                    w13_bias,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
+                w2_bias = jax.device_put(
+                    w2_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P("model", None))))
 
         else:
-            intermediate_size = w13_weight.shape[1] // 2
-            assert intermediate_size == w2_weight.shape[-1]
-            output_sizes = [intermediate_size, intermediate_size]
-            n_shards = self.mesh.shape["model"]
-            assert intermediate_size % n_shards == 0
-            w13_weight = reorder_concatenated_tensor_for_sharding(w13_weight,
-                                                                  output_sizes,
-                                                                  n_shards,
-                                                                  dim=1)
-            w13_weight = jax.device_put(
-                w13_weight,
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P(None, "model", None))))
-            w2_weight = jax.device_put(
-                w2_weight,
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P(None, None, "model"))))
-
-            w13_bias = reorder_concatenated_tensor_for_sharding(w13_bias,
-                                                                output_sizes,
-                                                                n_shards,
-                                                                dim=1)
-            w13_bias = jax.device_put(
-                w13_bias,
-                Format(Layout((0, 1)),
-                       NamedSharding(self.mesh, P(None, "model"))))
-            w2_bias = jax.device_put(
-                w2_bias,
-                Format(Layout((0, 1)), NamedSharding(self.mesh, P(None,
-                                                                  None))))
+            if layer.use_ep:
+                w13_weight = jax.device_put(
+                    w13_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
+                w2_weight = jax.device_put(
+                    w2_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P("model", None, None))))
+
+                w13_bias = jax.device_put(
+                    w13_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P("model", None))))
+                w2_bias = jax.device_put(
+                    w2_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P("model", None))))
+
+            else:
+                intermediate_size = w13_weight.shape[1] // 2
+                assert intermediate_size == w2_weight.shape[-1]
+                output_sizes = [intermediate_size, intermediate_size]
+                n_shards = self.mesh.shape["model"]
+                assert intermediate_size % n_shards == 0
+                w13_weight = reorder_concatenated_tensor_for_sharding(
+                    w13_weight, output_sizes, n_shards, dim=1)
+                w13_weight = jax.device_put(
+                    w13_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P(None, "model", None))))
+                w2_weight = jax.device_put(
+                    w2_weight,
+                    Format(Layout((0, 1, 2)),
+                           NamedSharding(self.mesh, P(None, None, "model"))))
+
+                w13_bias = reorder_concatenated_tensor_for_sharding(
+                    w13_bias, output_sizes, n_shards, dim=1)
+                w13_bias = jax.device_put(
+                    w13_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P(None, "model"))))
+                w2_bias = jax.device_put(
+                    w2_bias,
+                    Format(Layout((0, 1)),
+                           NamedSharding(self.mesh, P(None, None))))
 
         layer.w13_weight = Parameter(torch_view(w13_weight),
                                      requires_grad=False)
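Note: the reshape/transpose in the kernel path converts the loaded w13 weight from (num_experts, 2*intermediate_size, hidden_size), where the two projections are concatenated along dim 1, into the (num_experts, 2, hidden_size, intermediate_size) layout described in the kernel comment. A toy-shape check of the same transform (shapes are made up):

    import jax.numpy as jnp

    num_experts, intermediate_size, hidden_size = 2, 3, 4  # toy sizes
    w13 = jnp.arange(num_experts * 2 * intermediate_size * hidden_size,
                     dtype=jnp.float32).reshape(num_experts,
                                                2 * intermediate_size,
                                                hidden_size)

    w13 = w13.reshape(num_experts, 2, intermediate_size, hidden_size)
    w13 = jnp.transpose(w13, (0, 1, 3, 2))
    print(w13.shape)  # (2, 2, 4, 3) == (num_experts, 2, hidden_size, intermediate_size)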
@@ -246,21 +305,37 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
             raise NotImplementedError(
                 "Only softmax is supported for scoring_func")
 
-        # Use the original implementation
-        output = fused_moe_func_padded(
-            jax_view(x),
-            jax_view(layer.w13_weight),
-            jax_view(layer.w2_weight),
-            jax_view(layer.w13_bias) if self.moe.has_bias else None,
-            jax_view(layer.w2_bias) if self.moe.has_bias else None,
-            jax_view(router_logits),
-            topk=top_k,
-            global_num_experts=global_num_experts,
-            renormalize=renormalize,
-            reduce_results=layer.reduce_results,
-            mesh=self.mesh,
-            use_ep=layer.use_ep,
-            activation=activation,
-        )
+        if self.use_kernel and layer.use_ep:
+            output = fused_ep_moe(
+                mesh=self.mesh,
+                tokens=jax_view(x),
+                w1=jax_view(layer.w13_weight),
+                w2=jax_view(layer.w2_weight),
+                b1=jax_view(layer.w13_bias),
+                b2=jax_view(layer.w2_bias),
+                gating_output=jax_view(router_logits),
+                top_k=top_k,
+                ep_axis_name=self.ep_axis_name,
+                renormalize_topk_logits=renormalize,
+                act_fn=activation,
+                **self.block_size,
+            )
+        else:
+            # Use the original implementation
+            output = fused_moe_func_padded(
+                jax_view(x),
+                jax_view(layer.w13_weight),
+                jax_view(layer.w2_weight),
+                jax_view(layer.w13_bias),
+                jax_view(layer.w2_bias),
+                jax_view(router_logits),
+                topk=top_k,
+                global_num_experts=global_num_experts,
+                renormalize=renormalize,
+                reduce_results=layer.reduce_results,
+                mesh=self.mesh,
+                use_ep=layer.use_ep,
+                activation=activation,
+            )
 
         return torch_view(output)