sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/kernels.py

@@ -581,6 +581,49 @@ def post_reorder_triton_kernel(
         )


+@triton.jit
+def post_reorder_triton_kernel_for_cutlass_moe(
+    down_output_ptr,
+    output_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    topk_weights_ptr,
+    num_experts,
+    topk,
+    hidden_size,
+    dst_start,
+    BLOCK_SIZE: tl.constexpr,
+):
+    InDtype = down_output_ptr.dtype.element_ty
+
+    src_idx_int32 = tl.program_id(0)
+    src_idx = src_idx_int32.to(tl.int64)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+    topk_weights_ptr = topk_weights_ptr + src_idx * topk
+
+    store_ptr = output_ptr + src_idx * hidden_size
+
+    vec = tl.arange(0, BLOCK_SIZE)
+
+    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+        offset = start_offset + vec
+        mask = offset < hidden_size
+
+        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
+        for idx in range(topk):
+            expert_id = tl.load(topk_ids_ptr + idx)
+            if expert_id != num_experts:
+                dst_idx_int32 = tl.load(src2dst_ptr + idx)
+                dst_idx = dst_idx_int32.to(tl.int64)
+                dst_idx = dst_idx - dst_start
+                weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
+                load_ptr = down_output_ptr + dst_idx * hidden_size
+                in_data = tl.load(load_ptr + offset, mask=mask)
+                sum_vec += in_data * weigh_scale
+        tl.store(store_ptr + offset, sum_vec, mask=mask)
+
+
 @triton.jit
 def compute_m_range(
     pid,
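
For orientation, here is a minimal, hypothetical launch of the new kernel; the tensor shapes, the flat src2dst mapping, and BLOCK_SIZE=512 are illustrative assumptions, not the actual sglang call site. One Triton program runs per source token and accumulates that token's top-k expert outputs from down_output, skipping slots whose expert id equals the sentinel num_experts:

import torch

# Illustrative shapes only: 8 tokens, top-2 routing, hidden size 1024.
num_tokens, topk, hidden_size, num_experts = 8, 2, 1024, 16
down_output = torch.randn(num_tokens * topk, hidden_size, device="cuda", dtype=torch.bfloat16)
output = torch.empty(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16)
src2dst = torch.arange(num_tokens * topk, device="cuda", dtype=torch.int32)
topk_ids = torch.randint(0, num_experts, (num_tokens, topk), device="cuda", dtype=torch.int32)
topk_weights = torch.rand(num_tokens, topk, device="cuda", dtype=torch.float32)

# Grid: one program per source token; BLOCK_SIZE must be a power of two.
post_reorder_triton_kernel_for_cutlass_moe[(num_tokens,)](
    down_output, output, src2dst, topk_ids, topk_weights,
    num_experts, topk, hidden_size,
    dst_start=0,  # offset subtracted from src2dst indices (0 in this sketch)
    BLOCK_SIZE=512,
)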
sglang/srt/layers/moe/ep_moe/layer.py

@@ -14,13 +14,9 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     silu_and_mul_masked_post_quant_fwd,
     tma_align_input_scale,
 )
-from sglang.srt.layers.moe.fused_moe_triton.layer import (
-    FlashInferFusedMoE,
-    FusedMoE,
-    should_use_flashinfer_trtllm_moe,
-)
+from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
 from sglang.srt.layers.moe.topk import TopKOutput
-from sglang.srt.layers.moe.utils import DeepEPMode
+from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import (
@@ -38,6 +34,7 @@ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip,

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
+        AscendDeepEPLLOutput,
         DeepEPLLOutput,
         DeepEPNormalOutput,
         DispatchOutput,
@@ -48,7 +45,6 @@ _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

-
 if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul

@@ -60,6 +56,22 @@ if _use_aiter:
 logger = logging.getLogger(__name__)


+# TODO(kaixih@nvidia): ideally we should merge this logic into
+# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
+@torch.compile
+def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
+    temp = x.to(torch.float32).view(torch.int32)
+    exp = torch.bitwise_right_shift(temp, 23)
+    mant = torch.bitwise_and(temp, 0x7FFFFF)
+    is_ru = torch.logical_and(
+        torch.logical_and((mant > 0), (exp != 0xFE)),
+        ~torch.logical_and((exp == 0), (mant <= 0x400000)),
+    )
+    exp = torch.where(is_ru, exp + 1, exp)
+    new_x = exp.to(torch.uint8).view(torch.int)
+    return new_x.transpose(1, 2).contiguous().transpose(1, 2)
+
+
 class EPMoE(FusedMoE):
     """
     MoE Expert Parallel Impl
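
The helper added above converts each FP32 scale to its UE8M0 representation: only the biased power-of-two exponent is kept, rounded up whenever the scale is not an exact power of two. A simplified scalar illustration of the same bit manipulation, ignoring the saturation and subnormal branches handled in the diff (hypothetical helper name, not part of the package):

import struct

def e8m0_round_up(x: float) -> int:
    # Biased exponent of the smallest power of two >= x, for positive finite x.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    exp, mant = bits >> 23, bits & 0x7FFFFF
    return exp + 1 if mant > 0 else exp  # any nonzero mantissa rounds the exponent up

assert e8m0_round_up(1.0) == 127  # 2**0 is exact
assert e8m0_round_up(1.5) == 128  # rounds up to 2**1
assert e8m0_round_up(2.0) == 128  # exact power of two is unchanged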
@@ -81,6 +93,9 @@ class EPMoE(FusedMoE):
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
+        activation_alpha: Optional[float] = None,
+        swiglu_limit: Optional[float] = None,
+        with_bias: bool = False,
     ):
         super().__init__(
             num_experts=num_experts,
@@ -96,6 +111,9 @@
             activation=activation,
             # apply_router_weight_on_input=apply_router_weight_on_input,
             routed_scaling_factor=routed_scaling_factor,
+            activation_alpha=activation_alpha,
+            swiglu_limit=swiglu_limit,
+            with_bias=with_bias,
         )

         self.start_expert_id = self.moe_ep_rank * self.num_local_experts
@@ -203,10 +221,22 @@

         dispose_tensor(hidden_states)

+        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            b, s_mn, s_k = gateup_input_scale.shape
+            assert (
+                s_mn % 4 == 0 and s_k % 4 == 0
+            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
+
         # GroupGemm-0
         gateup_input_fp8 = (
             gateup_input,
-            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(gateup_input_scale),
+            (
+                _cast_to_e8m0_with_rounding_up(gateup_input_scale)
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                    gateup_input_scale
+                )
+            ),
         )
         num_groups, m, k = gateup_input_fp8[0].size()
         n = self.w13_weight.size(1)
@@ -214,7 +244,12 @@
             (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
         )
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            gateup_input_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
+            gateup_input_fp8,
+            self.w13_weight_fp8,
+            gateup_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del gateup_input
         del gateup_input_fp8
@@ -245,6 +280,7 @@
             down_input_scale,
             scale_block_size,
             masked_m,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         del gateup_output

@@ -252,13 +288,24 @@
         n = self.w2_weight.size(1)
         down_input_fp8 = (
             down_input,
-            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
+            (
+                down_input_scale
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                    down_input_scale
+                )
+            ),
         )
         down_output = torch.empty(
             (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
         )
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
+            down_input_fp8,
+            self.w2_weight_fp8,
+            down_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del down_input
         del down_input_fp8
@@ -341,7 +388,8 @@ class DeepEPMoE(EPMoE):
             return_recv_hook=True,
         )

-        if self.deepep_mode.enable_low_latency():
+        if self.deepep_mode.enable_low_latency() and not _is_npu:
+            # NPU supports low_latency deepep without deepgemm
             assert (
                 deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
@@ -358,7 +406,7 @@
             )
             # the last one is invalid rank_id
             self.expert_mask[:-1] = 1
-        else:
+        elif not _is_npu:
             self.w13_weight_fp8 = (
                 self.w13_weight,
                 (
@@ -413,6 +461,8 @@
         if _use_aiter:
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(dispatch_output)
+        if _is_npu:
+            return self.forward_npu(dispatch_output)
         if dispatch_output.format.is_deepep_normal():
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_contiguous(dispatch_output)
@@ -677,72 +727,84 @@

         return down_output

+    def forward_npu(
+        self,
+        dispatch_output: DeepEPLLOutput,
+    ):
+        if TYPE_CHECKING:
+            assert isinstance(dispatch_output, AscendDeepEPLLOutput)
+        hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
+        assert self.quant_method is not None
+        assert self.activation == "silu"

-class FlashInferEPMoE(EPMoE):
-    def __init__(self, *args, **kwargs):
-        renormalize = kwargs.pop("renormalize", True)
-        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
-        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
-        num_expert_group = kwargs.pop("num_expert_group", None)
-        topk_group = kwargs.pop("topk_group", None)
-        correction_bias = kwargs.pop("correction_bias", None)
-        super().__init__(*args, **kwargs)
-        self.renormalize = renormalize
-        self.num_fused_shared_experts = num_fused_shared_experts
-        self.use_grouped_topk = use_grouped_topk
-        if self.use_grouped_topk:
-            assert num_expert_group is not None and topk_group is not None
-        self.num_expert_group = num_expert_group
-        self.topk_group = topk_group
-        self.correction_bias = correction_bias
-        self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
-
-    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
-        assert self.use_flashinfer_trtllm_moe
-        assert (
-            self.activation == "silu"
-        ), "Only silu is supported for flashinfer blockscale fp8 moe"
-        assert (
-            self.renormalize
-        ), "Renormalize is required for flashinfer blockscale fp8 moe"
-        assert (
-            self.num_fused_shared_experts == 0
-        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
-        a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
-        # NOTE: scales of hidden states have to be transposed!
-        a_sf_t = a_sf.t().contiguous()
-        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
-
-        return trtllm_fp8_block_scale_moe(
-            routing_logits=router_logits.to(torch.float32),
-            routing_bias=self.correction_bias.to(hidden_states.dtype),
-            hidden_states=a_q,
-            hidden_states_scale=a_sf_t,
-            gemm1_weights=self.w13_weight,
-            gemm1_weights_scale=self.w13_weight_scale_inv,
-            gemm2_weights=self.w2_weight,
-            gemm2_weights_scale=self.w2_weight_scale_inv,
-            num_experts=self.num_experts,
-            top_k=self.top_k,
-            n_group=self.num_expert_group,
-            topk_group=self.topk_group,
-            intermediate_size=self.w2_weight.shape[2],
-            local_expert_offset=self.start_expert_id,
-            local_num_experts=self.num_local_experts,
-            routed_scaling_factor=self.routed_scaling_factor,
-            tile_tokens_dim=get_tile_tokens_dim(
-                hidden_states.shape[0], self.top_k, self.num_experts
-            ),
-            routing_method_type=2,  # DeepSeek-styled routing method
-            use_shuffled_weight=False,
-        )
+        # NOTE: Ascend's Dispatch & Combine does not support FP16
+        output_dtype = torch.bfloat16
+
+        pertoken_scale = hidden_states[1]
+        hidden_states = hidden_states[0]
+
+        group_list_type = 1
+        seg_indptr = seg_indptr.to(torch.int64)
+
+        import torch_npu
+
+        # gmm1: gate_up_proj
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[self.w13_weight],
+            scale=[self.w13_weight_scale.to(output_dtype)],
+            per_token_scale=[pertoken_scale],
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=seg_indptr,
+            output_dtype=output_dtype,
+        )[0]
+
+        # act_fn: swiglu
+        hidden_states = torch_npu.npu_swiglu(hidden_states)
+
+        hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
+
+        # gmm2: down_proj
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[self.w2_weight],
+            scale=[self.w2_weight_scale.to(output_dtype)],
+            per_token_scale=[swiglu_out_scale],
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=seg_indptr,
+            output_dtype=output_dtype,
+        )[0]
+
+        return hidden_states


 def get_moe_impl_class():
     if global_server_args_dict["moe_a2a_backend"].is_deepep():
         return DeepEPMoE
+
+    # NEW: Direct FP4 detection (bypasses EP requirements)
+    # Check for FP4 quantization with TRTLLM flag, regardless of EP
+    if global_server_args_dict.get("enable_flashinfer_trtllm_moe", False):
+        try:
+            # Check the quantization argument directly
+            quantization = global_server_args_dict.get("quantization")
+            if quantization == "modelopt_fp4":
+                from sglang.srt.layers.moe.fused_moe_triton.layer import (
+                    FlashInferFP4MoE,
+                )
+
+                return FlashInferFP4MoE
+        except:
+            pass
+
+    if should_use_flashinfer_trtllm_moe():
+        return FlashInferFusedMoE
     if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
         return FusedMoE
     if get_moe_expert_parallel_world_size() > 1:
-        return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
-    return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE
+        return EPMoE
+    return FusedMoE
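
Restated, the selection order after this change is: DeepEP backend, then the new direct FP4 + FlashInfer TRT-LLM path, then FlashInfer TRT-LLM FP8, then the FlashInfer CUTLASS flag, then plain expert parallelism, and finally the default fused Triton MoE. A compact, illustrative restatement (the dict argument below stands in for global_server_args_dict; this is not the module code):

def select_moe_impl(args, ep_size, use_flashinfer_trtllm):
    if args.get("moe_a2a_backend_is_deepep"):
        return "DeepEPMoE"
    # New in 0.5.0rc1: FP4 + TRT-LLM MoE is detected directly, with or without EP.
    if args.get("enable_flashinfer_trtllm_moe") and args.get("quantization") == "modelopt_fp4":
        return "FlashInferFP4MoE"
    if use_flashinfer_trtllm:  # i.e. should_use_flashinfer_trtllm_moe()
        return "FlashInferFusedMoE"
    if args.get("enable_flashinfer_cutlass_moe"):
        return "FusedMoE"
    if ep_size > 1:
        return "EPMoE"  # FlashInferEPMoE was removed in this release
    return "FusedMoE"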
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
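
Both new JSON files map a token count M to a tuned Triton kernel configuration. A hedged sketch of how such a file might be consulted, assuming the usual fused-MoE convention of picking the tuned entry nearest to the actual batch size (the function name and path handling below are illustrative, not the package API):

import json

def load_moe_kernel_config(path: str, num_tokens: int) -> dict:
    # Keys are token counts; pick the tuned entry closest to the actual batch.
    with open(path) as f:
        configs = {int(m): cfg for m, cfg in json.load(f).items()}
    nearest = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[nearest]

# Example: num_tokens=200 selects the "256" entry in either file under this rule.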