sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -0
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +7 -7
  6. sglang/srt/disaggregation/decode.py +8 -3
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +4 -5
  14. sglang/srt/entrypoints/openai/protocol.py +0 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +59 -265
  16. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  17. sglang/srt/function_call/ebnf_composer.py +1 -0
  18. sglang/srt/function_call/function_call_parser.py +2 -0
  19. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  20. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  21. sglang/srt/function_call/kimik2_detector.py +3 -3
  22. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  23. sglang/srt/jinja_template_utils.py +6 -0
  24. sglang/srt/layers/attention/aiter_backend.py +370 -107
  25. sglang/srt/layers/attention/ascend_backend.py +3 -0
  26. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  27. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  28. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  29. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  30. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  31. sglang/srt/layers/attention/vision.py +9 -1
  32. sglang/srt/layers/attention/wave_backend.py +627 -0
  33. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  34. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  35. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  36. sglang/srt/layers/communicator.py +8 -10
  37. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  38. sglang/srt/layers/linear.py +1 -0
  39. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  41. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  42. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  43. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  46. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  47. sglang/srt/layers/moe/topk.py +4 -1
  48. sglang/srt/layers/quantization/__init__.py +5 -3
  49. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  50. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  51. sglang/srt/layers/quantization/modelopt_quant.py +6 -11
  52. sglang/srt/layers/quantization/mxfp4.py +4 -1
  53. sglang/srt/layers/quantization/w4afp8.py +20 -11
  54. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  55. sglang/srt/layers/rotary_embedding.py +281 -2
  56. sglang/srt/lora/backend/base_backend.py +3 -23
  57. sglang/srt/lora/layers.py +60 -114
  58. sglang/srt/lora/lora.py +17 -62
  59. sglang/srt/lora/lora_manager.py +12 -48
  60. sglang/srt/lora/lora_registry.py +20 -9
  61. sglang/srt/lora/mem_pool.py +20 -63
  62. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  63. sglang/srt/lora/utils.py +25 -58
  64. sglang/srt/managers/cache_controller.py +21 -29
  65. sglang/srt/managers/detokenizer_manager.py +1 -1
  66. sglang/srt/managers/io_struct.py +6 -6
  67. sglang/srt/managers/mm_utils.py +1 -2
  68. sglang/srt/managers/multimodal_processor.py +1 -1
  69. sglang/srt/managers/schedule_batch.py +35 -20
  70. sglang/srt/managers/schedule_policy.py +6 -6
  71. sglang/srt/managers/scheduler.py +15 -7
  72. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  73. sglang/srt/managers/tokenizer_manager.py +25 -26
  74. sglang/srt/mem_cache/allocator.py +61 -87
  75. sglang/srt/mem_cache/hicache_storage.py +1 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  77. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  78. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  79. sglang/srt/mem_cache/radix_cache.py +2 -5
  80. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  81. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  82. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  83. sglang/srt/model_executor/cuda_graph_runner.py +22 -3
  84. sglang/srt/model_executor/forward_batch_info.py +26 -5
  85. sglang/srt/model_executor/model_runner.py +129 -35
  86. sglang/srt/model_loader/loader.py +18 -6
  87. sglang/srt/models/deepseek_v2.py +74 -35
  88. sglang/srt/models/gemma2.py +0 -34
  89. sglang/srt/models/gemma3n_mm.py +8 -9
  90. sglang/srt/models/glm4.py +6 -0
  91. sglang/srt/models/glm4_moe.py +9 -9
  92. sglang/srt/models/glm4v.py +589 -0
  93. sglang/srt/models/glm4v_moe.py +400 -0
  94. sglang/srt/models/gpt_oss.py +136 -19
  95. sglang/srt/models/granite.py +0 -25
  96. sglang/srt/models/llama.py +0 -25
  97. sglang/srt/models/llama4.py +1 -1
  98. sglang/srt/models/qwen2_5_vl.py +7 -3
  99. sglang/srt/models/qwen2_audio.py +10 -9
  100. sglang/srt/models/qwen3.py +0 -24
  101. sglang/srt/models/registry.py +1 -1
  102. sglang/srt/models/torch_native_llama.py +0 -24
  103. sglang/srt/multimodal/processors/base_processor.py +23 -13
  104. sglang/srt/multimodal/processors/glm4v.py +132 -0
  105. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  106. sglang/srt/reasoning_parser.py +316 -0
  107. sglang/srt/server_args.py +115 -139
  108. sglang/srt/speculative/eagle_worker.py +16 -0
  109. sglang/srt/two_batch_overlap.py +12 -4
  110. sglang/srt/utils.py +3 -3
  111. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  112. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  113. sglang/test/doc_patch.py +59 -0
  114. sglang/test/few_shot_gsm8k.py +1 -1
  115. sglang/test/few_shot_gsm8k_engine.py +1 -1
  116. sglang/test/run_eval.py +4 -1
  117. sglang/test/simple_eval_common.py +6 -0
  118. sglang/test/simple_eval_gpqa.py +2 -0
  119. sglang/test/test_fp4_moe.py +118 -36
  120. sglang/utils.py +1 -1
  121. sglang/version.py +1 -1
  122. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +26 -30
  123. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +127 -115
  124. sglang/lang/backend/__init__.py +0 -0
  125. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  126. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  127. /sglang/{api.py → lang/api.py} +0 -0
  128. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  129. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/sglang/srt/layers/moe/ep_moe/layer.py
@@ -34,6 +34,7 @@ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip,
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
+        AscendDeepEPLLOutput,
         DeepEPLLOutput,
         DeepEPNormalOutput,
         DispatchOutput,
@@ -387,7 +388,8 @@ class DeepEPMoE(EPMoE):
                 return_recv_hook=True,
             )
 
-        if self.deepep_mode.enable_low_latency():
+        if self.deepep_mode.enable_low_latency() and not _is_npu:
+            # NPU supports low_latency deepep without deepgemm
             assert (
                 deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
@@ -404,7 +406,7 @@
             )
             # the last one is invalid rank_id
             self.expert_mask[:-1] = 1
-        else:
+        elif not _is_npu:
             self.w13_weight_fp8 = (
                 self.w13_weight,
                 (
@@ -459,6 +461,8 @@
         if _use_aiter:
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(dispatch_output)
+        if _is_npu:
+            return self.forward_npu(dispatch_output)
         if dispatch_output.format.is_deepep_normal():
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_contiguous(dispatch_output)
@@ -723,6 +727,60 @@
 
         return down_output
 
+    def forward_npu(
+        self,
+        dispatch_output: DeepEPLLOutput,
+    ):
+        if TYPE_CHECKING:
+            assert isinstance(dispatch_output, AscendDeepEPLLOutput)
+        hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
+        assert self.quant_method is not None
+        assert self.activation == "silu"
+
+        # NOTE: Ascend's Dispatch & Combine does not support FP16
+        output_dtype = torch.bfloat16
+
+        pertoken_scale = hidden_states[1]
+        hidden_states = hidden_states[0]
+
+        group_list_type = 1
+        seg_indptr = seg_indptr.to(torch.int64)
+
+        import torch_npu
+
+        # gmm1: gate_up_proj
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[self.w13_weight],
+            scale=[self.w13_weight_scale.to(output_dtype)],
+            per_token_scale=[pertoken_scale],
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=seg_indptr,
+            output_dtype=output_dtype,
+        )[0]
+
+        # act_fn: swiglu
+        hidden_states = torch_npu.npu_swiglu(hidden_states)
+
+        hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
+
+        # gmm2: down_proj
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[self.w2_weight],
+            scale=[self.w2_weight_scale.to(output_dtype)],
+            per_token_scale=[swiglu_out_scale],
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=seg_indptr,
+            output_dtype=output_dtype,
+        )[0]
+
+        return hidden_states
+
 
 def get_moe_impl_class():
     if global_server_args_dict["moe_a2a_backend"].is_deepep():
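Note on the new forward_npu path: because typing.TYPE_CHECKING is False at runtime, the isinstance assert inside the `if TYPE_CHECKING:` block never executes; the AscendDeepEPLLOutput layout is enforced only by the positional unpack that follows. The pair of npu_grouped_matmul calls computes, per expert segment, a gate_up projection and a down projection around a SwiGLU activation. Below is a minimal dense-PyTorch sketch of that math, for orientation only: it assumes plain bfloat16 weights, treats seg_indptr as cumulative per-expert token offsets (the actual group_list convention is selected by group_list_type on the NPU kernel), and omits the dynamic per-token quantization the NPU kernels perform.

import torch
import torch.nn.functional as F

def grouped_moe_reference(
    hidden_states: torch.Tensor,  # [num_tokens, hidden], tokens grouped by expert
    seg_indptr: torch.Tensor,     # cumulative token count per expert (assumed convention)
    w13: torch.Tensor,            # [num_experts, hidden, 2 * intermediate]
    w2: torch.Tensor,             # [num_experts, intermediate, hidden]
) -> torch.Tensor:
    out = torch.empty_like(hidden_states)
    start = 0
    for e in range(w13.shape[0]):
        end = int(seg_indptr[e])
        x = hidden_states[start:end]
        gate, up = (x @ w13[e]).chunk(2, dim=-1)  # gmm1: gate_up_proj (gate/up layout assumed)
        x = F.silu(gate) * up                     # act_fn: swiglu
        out[start:end] = x @ w2[e]                # gmm2: down_proj
        start = end
    return out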
--- /dev/null
+++ b/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
--- /dev/null
+++ b/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
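The two JSON files above are kernel-tuning tables for the fused MoE Triton kernel on NVIDIA B200 (E is the expert count, N the per-partition intermediate size): each top-level key is a token-batch size mapped to tile shapes (BLOCK_SIZE_*), scheduling (GROUP_SIZE_M), and pipelining (num_warps, num_stages) tuned for that shape. A hedged sketch of how such tables are typically consumed, via nearest-key lookup (illustrative helper, not the exact sglang loader):

import json

def pick_config(path: str, num_tokens: int) -> dict:
    """Return the tuned launch config whose batch-size key is closest to num_tokens."""
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    return configs[min(configs, key=lambda k: abs(k - num_tokens))]

# Example: num_tokens=300 against the E=128,N=768 table selects the "256" entry
# (BLOCK_SIZE_M=16, BLOCK_SIZE_N=128, BLOCK_SIZE_K=128, num_stages=3).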
--- a/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -147,6 +147,7 @@ class FusedMoE(torch.nn.Module):
 
         self.layer_id = layer_id
         self.top_k = top_k
+        self.hidden_size = hidden_size
         self.num_experts = num_experts
         self.num_fused_shared_experts = num_fused_shared_experts
         self.expert_map_cpu = None
@@ -209,13 +210,13 @@
         self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
             "enable_flashinfer_mxfp4_moe", False
         )
+        # TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic
         if (
             self.quant_config is not None
             and self.quant_config.get_name() == "mxfp4"
             and self.use_enable_flashinfer_mxfp4_moe
         ):
             hidden_size = round_up(hidden_size, 256)
-        self.hidden_size = hidden_size
         self.quant_method.create_weights(
             layer=self,
             num_experts=self.num_local_experts,
@@ -795,13 +796,6 @@
 
     def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
         origin_hidden_states_dim = hidden_states.shape[-1]
-        if self.hidden_size != origin_hidden_states_dim:
-            hidden_states = torch.nn.functional.pad(
-                hidden_states,
-                (0, self.hidden_size - origin_hidden_states_dim),
-                mode="constant",
-                value=0.0,
-            )
         assert self.quant_method is not None
 
         if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
@@ -846,10 +840,14 @@
             )
             sm.tag(final_hidden_states)
 
+        final_hidden_states = final_hidden_states[
+            ..., :origin_hidden_states_dim
+        ].contiguous()
+
         if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
-        return final_hidden_states[..., :origin_hidden_states_dim].contiguous()
+        return final_hidden_states
 
     @classmethod
     def make_expert_params_mapping(
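The FusedMoE.forward change removes the input-side zero padding (the mxfp4 round-up is now left to Mxfp4MoEMethod, per the new TODO) and slices the output back to origin_hidden_states_dim before the tensor-parallel all-reduce instead of after, so the collective no longer carries padding columns. A sketch of the resulting ordering, with dist.all_reduce standing in for tensor_model_parallel_all_reduce:

import torch
import torch.distributed as dist

def finalize_moe_output(
    final_hidden_states: torch.Tensor, origin_dim: int, reduce_results: bool
) -> torch.Tensor:
    # Drop padding columns first so the all-reduce moves only real data.
    final_hidden_states = final_hidden_states[..., :origin_dim].contiguous()
    if reduce_results and dist.is_initialized():
        dist.all_reduce(final_hidden_states)  # stand-in for the TP all-reduce
    return final_hidden_states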
--- a/sglang/srt/layers/moe/token_dispatcher/deepep.py
+++ b/sglang/srt/layers/moe/token_dispatcher/deepep.py
@@ -23,14 +23,23 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
 from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import get_bool_env_var, get_int_env_var, is_hip, load_json_config
+from sglang.srt.utils import (
+    get_bool_env_var,
+    get_int_env_var,
+    is_hip,
+    is_npu,
+    load_json_config,
+)
+
+_is_npu = is_npu()
 
 try:
     from deep_ep import Buffer, Config
 
-    from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
-    )
+    if not _is_npu:
+        from sglang.srt.layers.quantization.fp8_kernel import (
+            sglang_per_token_group_quant_fp8,
+        )
 
     use_deepep = True
 except ImportError:
@@ -80,8 +89,24 @@ class DeepEPLLOutput(NamedTuple):
         return DispatchOutputFormat.deepep_ll
 
 
+class AscendDeepEPLLOutput(NamedTuple):
+    """AscendDeepEP low latency dispatch output."""
+
+    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
+    topk_idx: torch.Tensor
+    topk_weights: torch.Tensor
+    masked_m: torch.Tensor
+    seg_indptr: torch.Tensor
+    expected_m: int
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.deepep_ll
+
+
 assert isinstance(DeepEPNormalOutput, DispatchOutput)
 assert isinstance(DeepEPLLOutput, DispatchOutput)
+assert isinstance(AscendDeepEPLLOutput, DispatchOutput)
 
 
 class DeepEPDispatchMode(IntEnum):
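AscendDeepEPLLOutput mirrors DeepEPLLOutput but inserts a seg_indptr field between masked_m and expected_m. The field order is load-bearing: DeepEPMoE.forward_npu consumes it by positional unpacking (`hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output`), so reordering fields would silently misroute tensors. A tiny self-contained demonstration of that coupling (field names copied from the diff; values are dummies):

from typing import NamedTuple, Tuple
import torch

class AscendDeepEPLLOutput(NamedTuple):
    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
    topk_idx: torch.Tensor
    topk_weights: torch.Tensor
    masked_m: torch.Tensor
    seg_indptr: torch.Tensor
    expected_m: int

t = torch.zeros(2)
out = AscendDeepEPLLOutput((t, t), t, t, t, torch.zeros(2, dtype=torch.int64), 8)
# The same positional unpack forward_npu uses; position 4 must be seg_indptr.
hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = out
assert seg_indptr.dtype == torch.int64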
@@ -150,19 +175,20 @@ class DeepEPBuffer:
         else:
             raise NotImplementedError
 
-        total_num_sms = torch.cuda.get_device_properties(
-            device="cuda"
-        ).multi_processor_count
-        if (
-            (deepep_mode != DeepEPMode.LOW_LATENCY)
-            and not global_server_args_dict["enable_two_batch_overlap"]
-            and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
-        ):
-            logger.warning(
-                f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
-                f"This may result in highly suboptimal performance. "
-                f"Consider using --deepep-config to change the behavior."
-            )
+        if not _is_npu:
+            total_num_sms = torch.cuda.get_device_properties(
+                device="cuda"
+            ).multi_processor_count
+            if (
+                (deepep_mode != DeepEPMode.LOW_LATENCY)
+                and not global_server_args_dict["enable_two_batch_overlap"]
+                and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
+            ):
+                logger.warning(
+                    f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
+                    f"This may result in highly suboptimal performance. "
+                    f"Consider using --deepep-config to change the behavior."
+                )
 
         cls._buffer = Buffer(
             group,
@@ -507,13 +533,24 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             masked_m
         )
 
-        return DeepEPLLOutput(
-            hidden_states,
-            topk_idx,
-            topk_weights,
-            masked_m,
-            expected_m,
-        )
+        if _is_npu:
+            deepep_output = AscendDeepEPLLOutput(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                masked_m,
+                self.handle[1],
+                expected_m,
+            )
+        else:
+            deepep_output = DeepEPLLOutput(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                masked_m,
+                expected_m,
+            )
+        return deepep_output
 
     def _dispatch_core(
         self,
--- a/sglang/srt/layers/moe/topk.py
+++ b/sglang/srt/layers/moe/topk.py
@@ -245,10 +245,11 @@ class TopK(CustomOp):
 
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
         if global_num_experts == 256:
+            router_logits = router_logits.to(torch.float32)
             return torch_npu.npu_moe_gating_top_k(
                 router_logits,
                 k=self.top_k,
-                bias=self.correction_bias,
+                bias=self.correction_bias.to(torch.float32),
                 k_group=self.topk_group,
                 group_count=self.num_expert_group,
                 group_select_mode=1,
@@ -440,7 +441,9 @@ def grouped_topk_cpu(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
+    assert not apply_routed_scaling_factor_on_output
     assert expert_location_dispatch_info is None
     return torch.ops.sgl_kernel.grouped_topk_cpu(
         hidden_states,
--- a/sglang/srt/layers/quantization/__init__.py
+++ b/sglang/srt/layers/quantization/__init__.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 
 import builtins
 import inspect
-from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union
+from typing import TYPE_CHECKING, Dict, Optional, Type
 
 import torch
 
@@ -26,8 +26,9 @@ try:
     from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
     VLLM_AVAILABLE = True
-except ImportError:
+except ImportError as e:
     VLLM_AVAILABLE = False
+    VLLM_IMPORT_ERROR = e
 
     # Define empty classes as placeholders when vllm is not available
     class DummyConfig:
@@ -137,7 +138,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
         raise ValueError(
             f"{quantization} quantization requires some operators from vllm. "
-            "Please install vllm by `pip install vllm==0.9.0.1`"
+            f"Please install vllm by `pip install vllm==0.9.0.1`\n"
+            f"Import error: {VLLM_IMPORT_ERROR}"
         )
 
     return QUANTIZATION_METHODS[quantization]
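The quantization __init__.py change captures the original ImportError when the optional vllm import fails and replays it in the eventual ValueError, distinguishing "vllm is not installed" from "vllm is installed but failed to import". A minimal standalone version of the pattern (placeholder module and names, not the sglang code):

_IMPORT_ERROR = None
try:
    import optional_backend  # placeholder for an optional dependency
    BACKEND_AVAILABLE = True
except ImportError as e:
    BACKEND_AVAILABLE = False
    _IMPORT_ERROR = e  # keep the root cause for later error messages

def require_backend() -> None:
    if not BACKEND_AVAILABLE:
        raise ValueError(
            "This feature requires optional_backend. "
            f"Original import error: {_IMPORT_ERROR}"
        )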