sglang 0.5.0rc2__py3-none-any.whl → 0.5.1.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -0
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +375 -51
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,31 @@
+ from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+ from sglang.srt.layers.moe.utils import (
+     DeepEPMode,
+     MoeA2ABackend,
+     MoeRunnerBackend,
+     get_deepep_config,
+     get_deepep_mode,
+     get_moe_a2a_backend,
+     get_moe_runner_backend,
+     get_tbo_token_distribution_threshold,
+     initialize_moe_config,
+     is_tbo_enabled,
+     should_use_flashinfer_cutlass_moe_fp4_allgather,
+     should_use_flashinfer_trtllm_moe,
+ )
+
+ __all__ = [
+     "DeepEPMode",
+     "MoeA2ABackend",
+     "MoeRunnerConfig",
+     "MoeRunnerBackend",
+     "initialize_moe_config",
+     "get_moe_a2a_backend",
+     "get_moe_runner_backend",
+     "get_deepep_mode",
+     "should_use_flashinfer_trtllm_moe",
+     "should_use_flashinfer_cutlass_moe_fp4_allgather",
+     "is_tbo_enabled",
+     "get_tbo_token_distribution_threshold",
+     "get_deepep_config",
+ ]
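
The hunk above is the new sglang/srt/layers/moe/__init__.py (listed above with +31), which turns the moe package into a single public entry point for MoE configuration. A minimal sketch of how a caller might use these accessors instead of reading a global args dict, assuming the returned backend objects expose the is_*() predicates used later in this diff:

from sglang.srt.layers.moe import (
    get_deepep_mode,
    get_moe_a2a_backend,
    get_moe_runner_backend,
    should_use_flashinfer_trtllm_moe,
)


def describe_moe_setup() -> dict:
    # Each accessor reads process-wide MoE settings; initialize_moe_config
    # (also exported above) is presumably what populates them at startup.
    return {
        "a2a_is_deepep": get_moe_a2a_backend().is_deepep(),
        "runner_is_flashinfer_trtllm": get_moe_runner_backend().is_flashinfer_trtllm(),
        "deepep_mode": get_deepep_mode(),
        "use_flashinfer_trtllm_moe": should_use_flashinfer_trtllm_moe(),
    }
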
@@ -1,11 +1,17 @@
  from __future__ import annotations
 
  import logging
- from typing import TYPE_CHECKING, Optional
+ from typing import TYPE_CHECKING, Optional, Union
 
  import torch
 
  from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+ from sglang.srt.layers.moe import (
+ get_deepep_mode,
+ get_moe_a2a_backend,
+ get_moe_runner_backend,
+ should_use_flashinfer_trtllm_moe,
+ )
  from sglang.srt.layers.moe.ep_moe.kernels import (
  ep_gather,
  ep_scatter,
@@ -16,14 +22,9 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
  )
  from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
  from sglang.srt.layers.moe.topk import TopKOutput
- from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe
  from sglang.srt.layers.quantization import deep_gemm_wrapper
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
- from sglang.srt.layers.quantization.fp8 import (
- Fp8Config,
- Fp8MoEMethod,
- get_tile_tokens_dim,
- )
+ from sglang.srt.layers.quantization.fp8 import Fp8Config
  from sglang.srt.layers.quantization.fp8_kernel import (
  is_fp8_fnuz,
  sglang_per_token_group_quant_fp8,
@@ -51,7 +52,6 @@ if not (_is_npu or _is_hip):
  if _use_aiter:
  from aiter import ActivationType, QuantType
  from aiter.fused_moe import fused_moe
- from aiter.ops.shuffle import shuffle_weight
 
  logger = logging.getLogger(__name__)
 
@@ -89,12 +89,11 @@ class EPMoE(FusedMoE):
  num_fused_shared_experts: int = 0,
  params_dtype: Optional[torch.dtype] = None,
  quant_config: Optional[QuantizationConfig] = None,
- tp_size: Optional[int] = None,
  prefix: str = "",
  activation: str = "silu",
  routed_scaling_factor: Optional[float] = None,
- activation_alpha: Optional[float] = None,
- swiglu_limit: Optional[float] = None,
+ gemm1_alpha: Optional[float] = None,
+ gemm1_clamp_limit: Optional[float] = None,
  with_bias: bool = False,
  ):
  super().__init__(
@@ -106,13 +105,12 @@ class EPMoE(FusedMoE):
  top_k=top_k,
  params_dtype=params_dtype,
  quant_config=quant_config,
- tp_size=tp_size,
  prefix=prefix,
  activation=activation,
  # apply_router_weight_on_input=apply_router_weight_on_input,
  routed_scaling_factor=routed_scaling_factor,
- activation_alpha=activation_alpha,
- swiglu_limit=swiglu_limit,
+ gemm1_alpha=gemm1_alpha,
+ gemm1_clamp_limit=gemm1_clamp_limit,
  with_bias=with_bias,
  )
 
@@ -163,7 +161,8 @@ class EPMoE(FusedMoE):
  )
 
  assert self.quant_method is not None
- assert self.activation == "silu"
+ assert self.moe_runner_config.activation == "silu"
+
  hidden_states_shape = hidden_states.shape
  hidden_states_dtype = hidden_states.dtype
  hidden_states_device = hidden_states.device
@@ -327,8 +326,8 @@ class EPMoE(FusedMoE):
  m_max * self.start_expert_id,
  BLOCK_SIZE=512,
  )
- if self.routed_scaling_factor is not None:
- output *= self.routed_scaling_factor
+ if self.moe_runner_config.routed_scaling_factor is not None:
+ output *= self.moe_runner_config.routed_scaling_factor
  return output
 
 
@@ -349,11 +348,9 @@ class DeepEPMoE(EPMoE):
  num_fused_shared_experts: int = 0,
  params_dtype: Optional[torch.dtype] = None,
  quant_config: Optional[QuantizationConfig] = None,
- tp_size: Optional[int] = None,
  prefix: str = "",
  activation: str = "silu",
  routed_scaling_factor: Optional[float] = None,
- deepep_mode: DeepEPMode = DeepEPMode.AUTO,
  ):
  super().__init__(
  num_experts=num_experts,
@@ -364,12 +361,11 @@ class DeepEPMoE(EPMoE):
  num_fused_shared_experts=num_fused_shared_experts,
  params_dtype=params_dtype,
  quant_config=quant_config,
- tp_size=tp_size,
  prefix=prefix,
  activation=activation,
  routed_scaling_factor=routed_scaling_factor,
  )
- self.deepep_mode = deepep_mode
+ self.deepep_mode = get_deepep_mode()
 
  # TODO: move to the beginning of the file
  from sglang.srt.distributed.parallel_state import get_tp_group
@@ -383,7 +379,7 @@ class DeepEPMoE(EPMoE):
  num_local_experts=self.num_local_experts,
  hidden_size=hidden_size,
  params_dtype=params_dtype,
- deepep_mode=deepep_mode,
+ deepep_mode=self.deepep_mode,
  async_finish=True, # TODO
  return_recv_hook=True,
  )
@@ -458,15 +454,19 @@ class DeepEPMoE(EPMoE):
  )
 
  def moe_impl(self, dispatch_output: DispatchOutput):
+ from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
+
  if _use_aiter:
+ assert DispatchOutputChecker.format_is_deepep(dispatch_output)
  # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
  return self.forward_aiter(dispatch_output)
  if _is_npu:
+ assert DispatchOutputChecker.format_is_ascent_ll(dispatch_output)
  return self.forward_npu(dispatch_output)
- if dispatch_output.format.is_deepep_normal():
+ if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
  assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
  return self.forward_deepgemm_contiguous(dispatch_output)
- elif dispatch_output.format.is_deepep_ll():
+ elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
  assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
  return self.forward_deepgemm_masked(dispatch_output)
  else:
@@ -490,7 +490,7 @@ class DeepEPMoE(EPMoE):
 
  def forward_aiter(
  self,
- dispatch_output: DeepEPNormalOutput,
+ dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
  ):
  hidden_states, topk_idx, topk_weights = (
  dispatch_output.hidden_states,
@@ -516,7 +516,7 @@ class DeepEPMoE(EPMoE):
  quant_type=QuantType.per_128x128,
  activation=(
  ActivationType.Silu
- if self.activation == "silu"
+ if self.moe_runner_config.activation == "silu"
  else ActivationType.Gelu
  ),
  expert_mask=self.expert_mask,
@@ -531,7 +531,7 @@ class DeepEPMoE(EPMoE):
  )
  hidden_states_fp8, hidden_states_scale = hidden_states_fp8
  assert self.quant_method is not None
- assert self.activation == "silu"
+ assert self.moe_runner_config.activation == "silu"
  if num_recv_tokens_per_expert is None:
  return hidden_states_fp8.bfloat16()
  all_tokens = sum(num_recv_tokens_per_expert)
@@ -652,7 +652,7 @@ class DeepEPMoE(EPMoE):
  ):
  hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
  assert self.quant_method is not None
- assert self.activation == "silu"
+ assert self.moe_runner_config.activation == "silu"
 
  # GroupGemm-0
  num_groups, m, k = hidden_states_fp8[0].size()
@@ -735,7 +735,7 @@ class DeepEPMoE(EPMoE):
  assert isinstance(dispatch_output, AscendDeepEPLLOutput)
  hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
  assert self.quant_method is not None
- assert self.activation == "silu"
+ assert self.moe_runner_config.activation == "silu"
 
  # NOTE: Ascend's Dispatch & Combine does not support FP16
  output_dtype = torch.bfloat16
@@ -782,13 +782,17 @@ class DeepEPMoE(EPMoE):
  return hidden_states
 
 
- def get_moe_impl_class():
- if global_server_args_dict["moe_a2a_backend"].is_deepep():
+ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
+ if get_moe_a2a_backend().is_deepep():
  return DeepEPMoE
 
  # NEW: Direct FP4 detection (bypasses EP requirements)
  # Check for FP4 quantization with TRTLLM flag, regardless of EP
- if global_server_args_dict.get("enable_flashinfer_trtllm_moe", False):
+ if get_moe_runner_backend().is_flashinfer_trtllm():
+ # FlashInferFP4MoE must be paired with ModelOptNvFp4FusedMoEMethod.
+ # If UnquantizedFusedMoEMethod is detected, fall back to FusedMoE instead.
+ if quant_config is None:
+ return FusedMoE
  try:
  # Check the quantization argument directly
  quantization = global_server_args_dict.get("quantization")
@@ -803,7 +807,7 @@ def get_moe_impl_class():
 
  if should_use_flashinfer_trtllm_moe():
  return FlashInferFusedMoE
- if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
+ if get_moe_runner_backend().is_flashinfer_cutlass():
  return FusedMoE
  if get_moe_expert_parallel_world_size() > 1:
  return EPMoE
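
The last two hunks above change get_moe_impl_class() (in sglang/srt/layers/moe/ep_moe/layer.py, per the file list) to take an optional QuantizationConfig and to consult get_moe_a2a_backend()/get_moe_runner_backend() instead of global_server_args_dict. A hedged sketch of the new call pattern at a model's MoE construction site; names other than those shown in the diff are assumptions:

from typing import Optional

from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.quantization.base_config import QuantizationConfig


def build_moe_layer(quant_config: Optional[QuantizationConfig], **moe_kwargs):
    # 0.5.0rc2: MoEClass = get_moe_impl_class()
    # 0.5.1.post1: pass quant_config so the FlashInfer TRTLLM path can fall
    # back to FusedMoE when no compatible quantization method is configured.
    MoEClass = get_moe_impl_class(quant_config)
    return MoEClass(quant_config=quant_config, **moe_kwargs)
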
@@ -3,28 +3,22 @@ Torch-native implementation for FusedMoE. This is used for torch.compile.
  It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
  """
 
- from typing import Callable, Optional
-
  import torch
  from torch.nn import functional as F
 
  from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
- from sglang.srt.layers.moe.topk import TopKOutput
+ from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+ from sglang.srt.layers.moe.topk import StandardTopKOutput
 
 
  def fused_moe_forward_native(
  layer: torch.nn.Module,
  x: torch.Tensor,
- topk_output: TopKOutput,
- *,
- activation: str = "silu",
- apply_router_weight_on_input: bool = False,
- inplace: bool = True,
- no_combine: bool = False,
- routed_scaling_factor: Optional[float] = None,
+ topk_output: StandardTopKOutput,
+ moe_runner_config: MoeRunnerConfig,
  ) -> torch.Tensor:
 
- if apply_router_weight_on_input:
+ if moe_runner_config.apply_router_weight_on_input:
  raise NotImplementedError()
 
  topk_weights, topk_ids, _ = topk_output
@@ -33,12 +27,12 @@ def fused_moe_forward_native(
  w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
  w2_weights = layer.w2_weight[topk_ids]
  x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
- if activation == "silu":
+ if moe_runner_config.activation == "silu":
  x1 = F.silu(x1)
- elif activation == "gelu":
+ elif moe_runner_config.activation == "gelu":
  x1 = F.gelu(x1)
  else:
- raise ValueError(f"Unsupported activation: {activation=}")
+ raise ValueError(f"Unsupported activation: {moe_runner_config.activation=}")
  x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
  expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
  return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
@@ -47,16 +41,11 @@ def fused_moe_forward_native(
  def moe_forward_native(
  layer: torch.nn.Module,
  x: torch.Tensor,
- topk_output: TopKOutput,
- *,
- activation: str = "silu",
- apply_router_weight_on_input: bool = False,
- inplace: bool = True,
- no_combine: bool = False,
- routed_scaling_factor: Optional[float] = None,
+ topk_output: StandardTopKOutput,
+ moe_runner_config: MoeRunnerConfig,
  ) -> torch.Tensor:
 
- if apply_router_weight_on_input:
+ if moe_runner_config.apply_router_weight_on_input:
  raise NotImplementedError()
 
  topk_weights, topk_ids, _ = topk_output
@@ -72,12 +61,12 @@ def moe_forward_native(
  sorted_tokens = x[idxs // topk_ids.shape[1]]
  tokens_per_expert = tokens_per_expert.cpu().numpy()
 
- if activation == "silu":
+ if moe_runner_config.activation == "silu":
  act = SiluAndMul()
- elif activation == "gelu":
+ elif moe_runner_config.activation == "gelu":
  act = GeluAndMul()
  else:
- raise ValueError(f"Unsupported activation: {activation=}")
+ raise ValueError(f"Unsupported activation: {moe_runner_config.activation=}")
 
  outputs = []
  start_idx = 0
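
The four hunks above (sglang/srt/layers/moe/fused_moe_native.py) replace the loose keyword arguments of fused_moe_forward_native() and moe_forward_native() with a single MoeRunnerConfig. A sketch of the corresponding call-site migration; the MoeRunnerConfig constructor fields shown are inferred from the attributes the diff reads and are otherwise an assumption:

import torch

from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import StandardTopKOutput


def run_native_moe(
    layer: torch.nn.Module, x: torch.Tensor, topk_output: StandardTopKOutput
) -> torch.Tensor:
    # 0.5.0rc2: moe_forward_native(layer, x, topk_output, activation="silu",
    #                              apply_router_weight_on_input=False, ...)
    # 0.5.1.post1: fold those keywords into one config object.
    config = MoeRunnerConfig(
        activation="silu",
        apply_router_weight_on_input=False,
    )
    return moe_forward_native(layer, x, topk_output, config)
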
@@ -0,0 +1,146 @@
+ {
+     "1": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "2": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "4": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "8": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "16": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "24": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "32": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "48": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "64": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "96": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "128": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "256": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "512": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "1024": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "1536": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "2048": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "3072": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "4096": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 4
+     }
+ }
@@ -0,0 +1,146 @@
+ {
+     "1": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "2": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "4": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "8": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "16": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 8,
+         "num_stages": 2
+     },
+     "24": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "32": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "48": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 8,
+         "num_stages": 3
+     },
+     "64": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 8,
+         "num_stages": 2
+     },
+     "96": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "128": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "256": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "512": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "1024": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "1536": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "2048": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "3072": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 8,
+         "num_stages": 3
+     },
+     "4096": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 8,
+         "num_stages": 3
+     }
+ }
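
The two @@ -0,0 +1,146 @@ hunks above are tuned Triton configurations for the fused MoE kernels, two of the new JSON files under fused_moe_triton/configs/triton_3_4_0/ in the file list. Each top-level key is a token-count bucket and each value is a set of Triton tile and launch parameters. An illustrative sketch of loading such a file and picking an entry by nearest bucket; this mirrors the intent of the config lookup in fused_moe.py but is not the exact implementation:

import json
from pathlib import Path
from typing import Dict


def load_moe_config(path: Path) -> Dict[int, dict]:
    # Keys in the file are strings ("1", "2", ..., "4096"); convert to ints.
    return {int(k): v for k, v in json.loads(path.read_text()).items()}


def pick_config(configs: Dict[int, dict], num_tokens: int) -> dict:
    # Use the tuned entry whose bucket is closest to the actual token count.
    best = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[best]  # e.g. {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, ...}
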