sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
@@ -25,10 +25,11 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
-    is_sm100_supported,
+    is_blackwell_supported,
 )
 from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
@@ -49,8 +50,10 @@ if TYPE_CHECKING:
     )
     from sglang.srt.single_batch_overlap import DownGemmOverlapArgs

-if is_cuda():
-    from sgl_kernel import scaled_fp4_quant
+try:
+    from flashinfer import fp4_quantize
+except ImportError:
+    fp4_quantize = None

 try:
     from flashinfer import mm_fp4 as fp4_gemm
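Note: this change makes flashinfer a soft dependency of the FP4 path, mirroring the mm_fp4 guard just below it; fp4_quantize is simply None when flashinfer is missing. A minimal sketch of the same guard pattern, with a hypothetical require_fp4_quantize helper (not part of sglang) that fails at call time instead of import time:

    # Sketch only: the helper name is illustrative, not part of sglang.
    try:
        from flashinfer import fp4_quantize
    except ImportError:
        fp4_quantize = None

    def require_fp4_quantize():
        # Raise an actionable error only when the optional kernel is actually needed.
        if fp4_quantize is None:
            raise RuntimeError("flashinfer is required for FP4 quantization paths")
        return fp4_quantize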
@@ -466,8 +469,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         # Fp8 moe kernel needs single weight scale for w13 per expert.
         # We take the max of the w1 and w3 scales then dequant and requant each expert.
         if layer.w13_weight_scale.dim() == 2:  # Shape: (num_experts, 2)
-            from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-
             # Get the maximum scale across w1 and w3 for each expert
             max_w13_scales = layer.w13_weight_scale.max(dim=1).values

@@ -515,6 +516,84 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w2_input_scale.max(), requires_grad=False
         )

+        # Align FP8 weights to FlashInfer per-tensor kernel layout if enabled
+        if should_use_flashinfer_trtllm_moe():
+            from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a
+
+            # 1) Swap W13 halves: [Up, Gate] -> [Gate, Up] expected by FI
+            num_experts, two_n, hidden = layer.w13_weight.shape
+            inter = two_n // 2
+            w13_swapped = (
+                layer.w13_weight.reshape(num_experts, 2, inter, hidden)
+                .flip(dims=[1])
+                .reshape(num_experts, two_n, hidden)
+            )
+
+            # 2) Reorder rows for fused gated activation (W13)
+            w13_interleaved = [
+                reorder_rows_for_gated_act_gemm(w13_swapped[i])
+                for i in range(num_experts)
+            ]
+            w13_interleaved = torch.stack(w13_interleaved).reshape(
+                num_experts, two_n, hidden
+            )
+
+            # 3) Shuffle weights for transposed MMA output (both W13, W2)
+            epilogue_tile_m = 128
+            w13_shuffled = [
+                shuffle_matrix_a(w13_interleaved[i].view(torch.uint8), epilogue_tile_m)
+                for i in range(num_experts)
+            ]
+            w2_shuffled = [
+                shuffle_matrix_a(layer.w2_weight[i].view(torch.uint8), epilogue_tile_m)
+                for i in range(num_experts)
+            ]
+
+            layer.w13_weight = Parameter(
+                torch.stack(w13_shuffled).view(torch.float8_e4m3fn),
+                requires_grad=False,
+            )
+            layer.w2_weight = Parameter(
+                torch.stack(w2_shuffled).view(torch.float8_e4m3fn),
+                requires_grad=False,
+            )
+
+        # Precompute and register per-expert output scaling factors for FI MoE
+        if should_use_flashinfer_trtllm_moe():
+            # Note: w13_input_scale and w2_input_scale are scalar Parameters post-reduction
+            assert (
+                hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None
+            )
+            assert hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None
+            assert (
+                hasattr(layer, "w13_weight_scale")
+                and layer.w13_weight_scale is not None
+            )
+            assert (
+                hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None
+            )
+
+            input_scale = layer.w13_input_scale.to(torch.float32)
+            activation_scale = layer.w2_input_scale.to(torch.float32)
+            w13_weight_scale = layer.w13_weight_scale.to(torch.float32)
+            w2_weight_scale = layer.w2_weight_scale.to(torch.float32)
+
+            output1_scales_scalar = (
+                w13_weight_scale * input_scale * (1.0 / activation_scale)
+            )
+            output1_scales_gate_scalar = w13_weight_scale * input_scale
+            output2_scales_scalar = activation_scale * w2_weight_scale
+
+            layer.output1_scales_scalar = Parameter(
+                output1_scales_scalar, requires_grad=False
+            )
+            layer.output1_scales_gate_scalar = Parameter(
+                output1_scales_gate_scalar, requires_grad=False
+            )
+            layer.output2_scales_scalar = Parameter(
+                output2_scales_scalar, requires_grad=False
+            )
+
     def create_moe_runner(
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
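Note: the W13 half-swap in the hunk above relies on the gate and up projections being stored as two contiguous blocks along the output dimension (the diff's own comment: [Up, Gate] -> [Gate, Up]). A standalone sketch with toy shapes and no FlashInfer dependency, showing that the reshape/flip/reshape sequence is equivalent to concatenating the halves in the opposite order:

    import torch

    # Toy W13: 2 experts, intermediate size 3, hidden size 4, stored as
    # [Up; Gate] stacked along dim 1, so two_n = 2 * inter.
    num_experts, inter, hidden = 2, 3, 4
    w13 = torch.arange(num_experts * 2 * inter * hidden, dtype=torch.float32).reshape(
        num_experts, 2 * inter, hidden
    )

    # The reshape/flip/reshape used in the diff ...
    swapped = (
        w13.reshape(num_experts, 2, inter, hidden)
        .flip(dims=[1])
        .reshape(num_experts, 2 * inter, hidden)
    )

    # ... matches manually concatenating [Gate, Up].
    up, gate = w13[:, :inter], w13[:, inter:]
    assert torch.equal(swapped, torch.cat([gate, up], dim=1))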
@@ -526,6 +605,81 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
+        # Fast path: TRT-LLM FP8 per-tensor MoE using BYPASSED TopK routing
+        from sglang.srt.layers.moe.topk import TopKOutputChecker
+
+        if should_use_flashinfer_trtllm_moe() and TopKOutputChecker.format_is_bypassed(
+            topk_output
+        ):
+            router_logits = topk_output.router_logits
+            topk_config = topk_output.topk_config
+
+            # Constraints
+            assert (
+                self.moe_runner_config.activation == "silu"
+            ), "Only silu is supported for flashinfer fp8 moe"
+
+            from flashinfer import RoutingMethodType
+            from flashinfer.fused_moe import trtllm_fp8_per_tensor_scale_moe
+
+            correction_bias = (
+                None
+                if topk_config.correction_bias is None
+                else topk_config.correction_bias
+            )
+            # Pre-quantize activations to FP8 per-tensor using provided input scale
+            x_fp8, _ = scaled_fp8_quant(x, layer.w13_input_scale)
+
+            use_routing_scales_on_input = True
+            routed_scaling_factor = self.moe_runner_config.routed_scaling_factor
+
+            # Enforce Llama4 routing for ModelOpt FP8 MoE for now.
+            # TODO(brayden): support other routing methods
+            assert topk_config.top_k == 1, "ModelOpt FP8 MoE requires top_k==1"
+            assert (
+                not topk_config.num_expert_group
+            ), "ModelOpt FP8 MoE does not support expert grouping"
+            assert (
+                not topk_config.topk_group
+            ), "ModelOpt FP8 MoE does not support grouped top-k"
+            routing_method_type = RoutingMethodType.Llama4
+
+            # FlashInfer TRTLLM requires routing_logits (and bias) to be bfloat16
+            routing_logits_cast = router_logits.to(torch.bfloat16)
+            routing_bias_cast = (
+                None if correction_bias is None else correction_bias.to(torch.bfloat16)
+            )
+
+            output = trtllm_fp8_per_tensor_scale_moe(
+                routing_logits=routing_logits_cast,
+                routing_bias=routing_bias_cast,
+                hidden_states=x_fp8,
+                gemm1_weights=layer.w13_weight,
+                output1_scales_scalar=layer.output1_scales_scalar,
+                output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
+                gemm2_weights=layer.w2_weight,
+                output2_scales_scalar=layer.output2_scales_scalar,
+                num_experts=layer.num_experts,
+                top_k=topk_config.top_k,
+                n_group=0,
+                topk_group=0,
+                intermediate_size=layer.w2_weight.shape[2],
+                local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
+                local_num_experts=layer.num_local_experts,
+                routed_scaling_factor=(
+                    routed_scaling_factor if routed_scaling_factor is not None else 1.0
+                ),
+                use_routing_scales_on_input=use_routing_scales_on_input,
+                tile_tokens_dim=8,  # TODO(brayden): use the FI tile calculation
+                routing_method_type=routing_method_type,
+            )
+
+            from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+            return StandardCombineInput(hidden_states=output)

         quant_info = TritonMoeQuantInfo(
             w13_weight=layer.w13_weight,
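Note: one plausible reading of the three folded scales registered earlier and consumed by the kernel call above, assuming the usual symmetric per-tensor convention real_value ~= quantized_value * scale (toy values, rounding ignored, not authoritative for the FlashInfer kernel contract):

    import torch

    input_scale = torch.tensor(0.02)       # scale of the quantized MoE input (w13_input_scale)
    w13_weight_scale = torch.tensor(0.01)  # per-expert W13 weight scale
    activation_scale = torch.tensor(0.03)  # input scale expected by the second GEMM (w2_input_scale)

    x_real, w_real = torch.tensor(0.84), torch.tensor(0.37)
    x_q = x_real / input_scale             # quantized-domain activation
    w_q = w_real / w13_weight_scale        # quantized-domain weight

    output1_scales_scalar = w13_weight_scale * input_scale * (1.0 / activation_scale)

    # Applying the folded scale to the quantized-domain product lands the GEMM1
    # result directly in the quantized domain of the second GEMM's input, so no
    # full-precision intermediate needs to be materialized.
    gemm1_q = (x_q * w_q) * output1_scales_scalar
    assert torch.isclose(gemm1_q, (x_real * w_real) / activation_scale)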
@@ -867,10 +1021,9 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         output_shape = [x_m, w_n]

         # Quantize BF16 or FP16 to (FP4 and interleaved block scale)
-        x_fp4, x_scale_interleaved = scaled_fp4_quant(x, layer.input_scale_inv)
+        x_fp4, x_scale_interleaved = fp4_quantize(x, layer.input_scale_inv)

         assert x_fp4.dtype == torch.uint8
-        assert x_scale_interleaved.dtype == torch.float8_e4m3fn
         assert layer.weight.dtype == torch.uint8
         assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn
         assert layer.alpha.dtype == torch.float32
@@ -903,7 +1056,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):

     def __init__(self, quant_config: ModelOptFp4Config):
         self.quant_config = quant_config
-        if not is_sm100_supported():
+        if not is_blackwell_supported():
             raise ValueError(
                 "Current platform does not support NVFP4"
                 " quantization. Please use Blackwell and"
@@ -1383,8 +1536,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         alt_stream=None,
     ) -> CombineInput:

-        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
-
         x = dispatch_output.hidden_states
         topk_output = dispatch_output.topk_output

@@ -1397,6 +1548,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         # Check if this is a FlashInferFP4MoE layer that should handle its own forward
         if hasattr(layer, "gemm1_weights_fp4_shuffled"):
             # This layer was processed with flashinfer TRTLLM - delegate to its own forward
+            from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
             return StandardCombineInput(hidden_states=layer.forward(x, topk_output))

@@ -1410,7 +1563,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             output_dtype = x.dtype
             x_sf = None
             if should_use_flashinfer_cutlass_moe_fp4_allgather():
-                from flashinfer import fp4_quantize, nvfp4_block_scale_interleave
+                from flashinfer import nvfp4_block_scale_interleave

                 # Quantize before comm, swizzle after.
                 if x.shape[0] > 0:
@@ -1465,6 +1618,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             if forward_shared_experts is not None:
                 torch.cuda.current_stream().wait_stream(alt_stream)

+            from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
             return StandardCombineInput(hidden_states=output)

         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
@@ -1486,6 +1641,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
         # Scale by routed_scaling_factor is fused into select_experts.
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
         return StandardCombineInput(hidden_states=output)

     def apply_without_routing_weights(
@@ -261,26 +261,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):

         self.prefix = prefix
         self.topk_indices_dtype = None
-        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
+        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernels()
         self.with_bias = False
         self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
         self.flashinfer_mxfp4_moe_precision = (
             get_global_server_args().flashinfer_mxfp4_moe_precision
         )

-        self.triton_kernel_moe_forward = None
-        self.triton_kernel_moe_with_bias_forward = None
-        if torch.cuda.is_available() and has_triton_kernels:
-            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-                triton_kernel_moe_forward as _tk_forward,
-            )
-            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-                triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
-            )
-
-            self.triton_kernel_moe_forward = _tk_forward
-            self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward
-
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -600,7 +587,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
         self.moe_runner_config = moe_runner_config
-        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+        backend = (
+            MoeRunnerBackend.TRITON_KERNELS
+            if self.use_triton_kernels
+            else MoeRunnerBackend.TRITON
+        )
+        self.runner = MoeRunner(backend, moe_runner_config)

     def apply(
         self,
@@ -677,31 +669,31 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             )[0]
             return StandardCombineInput(hidden_states=trtllm_gen_output)

-        if self.use_triton_kernels:
+        backend = self.runner.runner_backend
+        if backend.is_triton_kernels():
+            from sglang.srt.layers.moe.moe_runner.triton_kernels import (
+                TritonKernelsQuantInfo,
+            )
+
             assert (
                 layer.moe_ep_size == 1
             ), "Expert parallel is not supported when using triton kernels"
-            if self.with_bias:
-                output = self.triton_kernel_moe_with_bias_forward(
-                    hidden_states=x,
-                    w1=self.w13_weight_triton_tensor,
-                    w1_pcg=self.w13_precision_config,
-                    w2=self.w2_weight_triton_tensor,
-                    w2_pcg=self.w2_precision_config,
-                    b1=layer.w13_weight_bias,
-                    b2=layer.w2_weight_bias,
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
-                )
-            else:
-                output = self.triton_kernel_moe_forward(
-                    hidden_states=x,
-                    w1=layer.w13_weight,
-                    w2=layer.w2_weight,
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
-                )
-            return StandardCombineInput(hidden_states=output)
+            quant_info = TritonKernelsQuantInfo(
+                w13_weight=(
+                    self.w13_weight_triton_tensor
+                    if self.w13_weight_triton_tensor is not None
+                    else layer.w13_weight
+                ),
+                w2_weight=(
+                    self.w2_weight_triton_tensor
+                    if self.w2_weight_triton_tensor is not None
+                    else layer.w2_weight
+                ),
+                w13_bias=getattr(layer, "w13_weight_bias", None),
+                w2_bias=getattr(layer, "w2_weight_bias", None),
+                w13_precision_config=getattr(self, "w13_precision_config", None),
+                w2_precision_config=getattr(self, "w2_precision_config", None),
+            )
         else:
             quant_info = TritonMoeQuantInfo(
                 w13_weight=layer.w13_weight,
@@ -709,7 +701,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 b13=getattr(layer, "w13_weight_bias", None),
                 b2=getattr(layer, "w2_weight_bias", None),
             )
-            return self.runner.run(dispatch_output, quant_info)
+        return self.runner.run(dispatch_output, quant_info)


 class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
@@ -115,13 +115,15 @@ class UnquantizedLinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         if use_intel_amx_backend(layer):
             x_shapes = x.shape
             if len(x_shapes) == 3:
                 x = x.view(-1, x.shape[-1])
             output = torch.ops.sgl_kernel.weight_packed_linear(
-                x, layer.weight, bias, True  # is_vnni
+                x,
+                layer.weight,
+                bias,
+                True,  # is_vnni
             )
             if len(x_shapes) == 3:
                 output = output.view(x_shapes[0], x_shapes[1], -1)
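Note: the Intel AMX branch above flattens a 3D activation to 2D for the packed linear kernel and restores the leading dimensions afterwards. A minimal sketch of that round trip, with torch.matmul standing in for torch.ops.sgl_kernel.weight_packed_linear (stand-in only, not the real kernel):

    import torch

    def flattened_linear(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        # Same shape handling as the AMX branch; plain matmul as a placeholder.
        x_shapes = x.shape
        if len(x_shapes) == 3:
            x = x.view(-1, x.shape[-1])
        output = x @ weight.t()
        if len(x_shapes) == 3:
            output = output.view(x_shapes[0], x_shapes[1], -1)
        return output

    x = torch.randn(2, 5, 8)   # (batch, seq, hidden)
    w = torch.randn(16, 8)     # (out_features, in_features)
    assert flattened_linear(x, w).shape == (2, 5, 16)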
@@ -138,19 +140,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         self.use_triton_kernels = use_triton_kernels
         self.with_bias = False

-        self.triton_kernel_moe_forward = None
-        self.triton_kernel_moe_with_bias_forward = None
-        if torch.cuda.is_available() and use_triton_kernels:
-            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-                triton_kernel_moe_forward as _tk_forward,
-            )
-            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-                triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
-            )
-
-            self.triton_kernel_moe_forward = _tk_forward
-            self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward
-
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -231,14 +220,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
         self.moe_runner_config = moe_runner_config
-        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+        backend = (
+            MoeRunnerBackend.TRITON_KERNELS
+            if self.use_triton_kernels
+            else MoeRunnerBackend.TRITON
+        )
+        self.runner = MoeRunner(backend, moe_runner_config)

     def apply(
         self,
         layer: torch.nn.Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
-
         return self.forward(
             layer=layer,
             dispatch_output=dispatch_output,
@@ -249,7 +242,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
-
         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

         x = dispatch_output.hidden_states
@@ -257,30 +249,19 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):

         moe_runner_config = self.moe_runner_config

-        if self.use_triton_kernels:
-            if self.with_bias:
-                assert self.triton_kernel_moe_with_bias_forward is not None
-                output = self.triton_kernel_moe_with_bias_forward(
-                    hidden_states=x,
-                    w1=layer.w13_weight,
-                    w2=layer.w2_weight,
-                    b1=layer.w13_weight_bias,
-                    b2=layer.w2_weight_bias,
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
-                    w1_pcg=None,
-                    w2_pcg=None,
-                )
-            else:
-                assert self.triton_kernel_moe_forward is not None
-                output = self.triton_kernel_moe_forward(
-                    hidden_states=x,
-                    w1=layer.w13_weight,
-                    w2=layer.w2_weight,
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
-                )
-            return StandardCombineInput(hidden_states=output)
+        backend = self.runner.runner_backend
+        if backend.is_triton_kernels():
+            from sglang.srt.layers.moe.moe_runner.triton_kernels import (
+                TritonKernelsQuantInfo,
+            )
+
+            quant_info = TritonKernelsQuantInfo(
+                w13_weight=layer.w13_weight,
+                w2_weight=layer.w2_weight,
+                w13_bias=getattr(layer, "w13_weight_bias", None),
+                w2_bias=getattr(layer, "w2_weight_bias", None),
+            )
+            return self.runner.run(dispatch_output, quant_info)
         else:
             if _use_aiter:
                 assert not moe_runner_config.no_combine, "unsupported"
@@ -311,7 +292,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 )
                 return StandardCombineInput(hidden_states=output)
             else:
-
                 quant_info = TritonMoeQuantInfo(
                     w13_weight=layer.w13_weight,
                     w2_weight=layer.w2_weight,
@@ -325,7 +305,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
-
         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

         x = dispatch_output.hidden_states
@@ -380,7 +359,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
-
         import torch_npu

         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
@@ -23,7 +23,8 @@ if TYPE_CHECKING:
     from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE
     from sglang.srt.layers.moe.token_dispatcher import (
         CombineInput,
-        DeepEPNormalOutput,
+        DeepEPLLDispatchOutput,
+        DeepEPNormalDispatchOutput,
         StandardDispatchOutput,
     )

@@ -328,10 +329,45 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             output *= self.moe_runner_config.routed_scaling_factor
         return StandardCombineInput(hidden_states=output)

+    def apply_deepep_ll(
+        self,
+        layer: DeepEPMoE,
+        dispatch_output: DeepEPLLDispatchOutput,
+    ) -> torch.Tensor:
+
+        from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe_deepep_ll
+
+        hidden_states, _, topk_ids, _, masked_m, _ = dispatch_output
+
+        output = cutlass_w4a8_moe_deepep_ll(
+            hidden_states,
+            layer.w13_weight,
+            layer.w2_weight,
+            layer.w13_weight_scale_inv,
+            layer.w2_weight_scale_inv,
+            topk_ids,
+            masked_m,
+            layer.quant_method.a_strides1,
+            layer.quant_method.b_strides1,
+            layer.quant_method.c_strides1,
+            layer.quant_method.a_strides2,
+            layer.quant_method.b_strides2,
+            layer.quant_method.c_strides2,
+            layer.quant_method.s_strides13,
+            layer.quant_method.s_strides2,
+            layer.quant_method.expert_offsets,
+            layer.quant_method.problem_sizes1,
+            layer.quant_method.problem_sizes2,
+            layer.w13_input_scale,
+            layer.w2_input_scale,
+        )
+
+        return output
+
     def apply_deepep_normal(
         self,
         layer: DeepEPMoE,
-        dispatch_output: DeepEPNormalOutput,
+        dispatch_output: DeepEPNormalDispatchOutput,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.cutlass_w4a8_moe import (
             cutlass_w4a8_moe_deepep_normal,
@@ -142,8 +142,11 @@ def unified_attention_with_output(
         ret = forward_batch.attn_backend.forward(
             query, key, value, attention_layer, forward_batch, save_kv_cache
         )
-        assert output.shape == ret.shape
-        output.copy_(ret)
+        assert (
+            output.numel() == ret.numel()
+        ), f"Output tensor element mismatch: {output.numel()} != {ret.numel()}"
+
+        output.view(ret.shape).copy_(ret)
         return

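Note: the relaxed check above lets the attention backend return a tensor whose shape differs from the preallocated output buffer as long as the element counts match, for example (tokens, heads, head_dim) versus (tokens, heads * head_dim). A minimal sketch of why the numel assertion plus view-based copy suffices for contiguous buffers:

    import torch

    tokens, heads, head_dim = 4, 8, 16
    output = torch.empty(tokens, heads * head_dim)   # flat preallocated buffer
    ret = torch.randn(tokens, heads, head_dim)       # backend returns a 3D result

    # The old `output.shape == ret.shape` check would fail here even though both
    # tensors hold the same elements in the same row-major layout.
    assert output.numel() == ret.numel()
    output.view(ret.shape).copy_(ret)                # reinterpret, then copy in place

    assert torch.equal(output, ret.reshape(tokens, heads * head_dim))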