sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -0
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +7 -7
- sglang/srt/disaggregation/decode.py +8 -3
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +4 -5
- sglang/srt/entrypoints/openai/protocol.py +0 -9
- sglang/srt/entrypoints/openai/serving_chat.py +59 -265
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +8 -10
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/quantization/__init__.py +5 -3
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/modelopt_quant.py +6 -11
- sglang/srt/layers/quantization/mxfp4.py +4 -1
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +21 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +6 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +35 -20
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +15 -7
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +25 -26
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +22 -3
- sglang/srt/model_executor/forward_batch_info.py +26 -5
- sglang/srt/model_executor/model_runner.py +129 -35
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_v2.py +74 -35
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +9 -9
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +136 -19
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/server_args.py +115 -139
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +12 -4
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +26 -30
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +127 -115
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py
CHANGED
```diff
@@ -34,6 +34,7 @@ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip,
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
+        AscendDeepEPLLOutput,
         DeepEPLLOutput,
         DeepEPNormalOutput,
         DispatchOutput,
@@ -387,7 +388,8 @@ class DeepEPMoE(EPMoE):
             return_recv_hook=True,
         )
 
-        if self.deepep_mode.enable_low_latency():
+        if self.deepep_mode.enable_low_latency() and not _is_npu:
+            # NPU supports low_latency deepep without deepgemm
             assert (
                 deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
@@ -404,7 +406,7 @@ class DeepEPMoE(EPMoE):
             )
             # the last one is invalid rank_id
             self.expert_mask[:-1] = 1
-
+        elif not _is_npu:
             self.w13_weight_fp8 = (
                 self.w13_weight,
                 (
@@ -459,6 +461,8 @@ class DeepEPMoE(EPMoE):
         if _use_aiter:
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(dispatch_output)
+        if _is_npu:
+            return self.forward_npu(dispatch_output)
         if dispatch_output.format.is_deepep_normal():
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_contiguous(dispatch_output)
@@ -723,6 +727,60 @@ class DeepEPMoE(EPMoE):
 
         return down_output
 
+    def forward_npu(
+        self,
+        dispatch_output: DeepEPLLOutput,
+    ):
+        if TYPE_CHECKING:
+            assert isinstance(dispatch_output, AscendDeepEPLLOutput)
+        hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
+        assert self.quant_method is not None
+        assert self.activation == "silu"
+
+        # NOTE: Ascend's Dispatch & Combine does not support FP16
+        output_dtype = torch.bfloat16
+
+        pertoken_scale = hidden_states[1]
+        hidden_states = hidden_states[0]
+
+        group_list_type = 1
+        seg_indptr = seg_indptr.to(torch.int64)
+
+        import torch_npu
+
+        # gmm1: gate_up_proj
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[self.w13_weight],
+            scale=[self.w13_weight_scale.to(output_dtype)],
+            per_token_scale=[pertoken_scale],
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=seg_indptr,
+            output_dtype=output_dtype,
+        )[0]
+
+        # act_fn: swiglu
+        hidden_states = torch_npu.npu_swiglu(hidden_states)
+
+        hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
+
+        # gmm2: down_proj
+        hidden_states = torch_npu.npu_grouped_matmul(
+            x=[hidden_states],
+            weight=[self.w2_weight],
+            scale=[self.w2_weight_scale.to(output_dtype)],
+            per_token_scale=[swiglu_out_scale],
+            split_item=2,
+            group_list_type=group_list_type,
+            group_type=0,
+            group_list=seg_indptr,
+            output_dtype=output_dtype,
+        )[0]
+
+        return hidden_states
+
 
 def get_moe_impl_class():
     if global_server_args_dict["moe_a2a_backend"].is_deepep():
```
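For orientation, here is a minimal pure-PyTorch sketch of the computation the new `forward_npu` path performs on the dispatched tokens: a grouped gate_up projection, a SwiGLU activation, and a grouped down projection. The dense per-expert weights, the segment layout, and the function name are illustrative assumptions; the real path uses fused `torch_npu` kernels plus per-token quantization scales, which are omitted here.

```python
# Illustrative sketch only (not sglang's implementation): dense analogue of
# grouped gate_up -> SwiGLU -> down, with tokens pre-sorted by expert.
import torch
import torch.nn.functional as F


def grouped_expert_mlp(
    hidden_states: torch.Tensor,  # [num_tokens, hidden], sorted by expert
    seg_indptr: torch.Tensor,     # [num_experts + 1] token offsets per expert
    w13: torch.Tensor,            # [num_experts, hidden, 2 * intermediate]
    w2: torch.Tensor,             # [num_experts, intermediate, hidden]
) -> torch.Tensor:
    out = torch.empty_like(hidden_states)
    for e in range(w13.shape[0]):
        start, end = seg_indptr[e].item(), seg_indptr[e + 1].item()
        if start == end:
            continue
        gate_up = hidden_states[start:end] @ w13[e]  # "gmm1": gate_up_proj
        gate, up = gate_up.chunk(2, dim=-1)
        act = F.silu(gate) * up                      # SwiGLU activation
        out[start:end] = act @ w2[e]                 # "gmm2": down_proj
    return out
```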
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
ADDED
```diff
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
```
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json
ADDED
```diff
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
```
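The two JSON files above are Triton tuning tables for the fused MoE kernel on NVIDIA_B200 with fp8_w8a8 weights and a [128, 128] block shape, keyed by token count. As a hedged illustration of how such a table can be consumed, the hypothetical helper below loads a config file and picks the entry whose key is closest to the current batch size; the function name and the selection rule are assumptions, not sglang's actual loader.

```python
# Hypothetical config lookup for a tuning table like the JSON above.
import json


def pick_fused_moe_config(config_path: str, num_tokens: int) -> dict:
    with open(config_path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    # Choose the tuned entry whose token-count key is numerically closest.
    best_key = min(configs, key=lambda k: abs(k - num_tokens))
    return configs[best_key]


# Example: for 300 tokens this would return the "256" entry of the table.
```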
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
```diff
@@ -147,6 +147,7 @@ class FusedMoE(torch.nn.Module):
 
         self.layer_id = layer_id
         self.top_k = top_k
+        self.hidden_size = hidden_size
         self.num_experts = num_experts
         self.num_fused_shared_experts = num_fused_shared_experts
         self.expert_map_cpu = None
@@ -209,13 +210,13 @@ class FusedMoE(torch.nn.Module):
         self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
             "enable_flashinfer_mxfp4_moe", False
         )
+        # TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic
         if (
             self.quant_config is not None
             and self.quant_config.get_name() == "mxfp4"
             and self.use_enable_flashinfer_mxfp4_moe
         ):
             hidden_size = round_up(hidden_size, 256)
-        self.hidden_size = hidden_size
         self.quant_method.create_weights(
             layer=self,
             num_experts=self.num_local_experts,
@@ -795,13 +796,6 @@ class FusedMoE(torch.nn.Module):
 
     def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
         origin_hidden_states_dim = hidden_states.shape[-1]
-        if self.hidden_size != origin_hidden_states_dim:
-            hidden_states = torch.nn.functional.pad(
-                hidden_states,
-                (0, self.hidden_size - origin_hidden_states_dim),
-                mode="constant",
-                value=0.0,
-            )
         assert self.quant_method is not None
 
         if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
@@ -846,10 +840,14 @@ class FusedMoE(torch.nn.Module):
             )
             sm.tag(final_hidden_states)
 
+        final_hidden_states = final_hidden_states[
+            ..., :origin_hidden_states_dim
+        ].contiguous()
+
         if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
-        return final_hidden_states
+        return final_hidden_states
 
     @classmethod
     def make_expert_params_mapping(
```
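The FusedMoE change above drops the input padding in `forward` and instead slices the final output back to the caller's hidden dimension, with the layer's (possibly rounded-up) `hidden_size` now recorded once in `__init__`. A small sketch of that output-side contract, with illustrative names and shapes:

```python
# Sketch of the slice-back contract: whatever internal width the MoE kernels
# use, the caller gets tensors back in its original hidden dimension.
import torch


def slice_to_original_dim(final_hidden_states: torch.Tensor, origin_dim: int) -> torch.Tensor:
    # No-op in effect when the internal width already matches the input.
    return final_hidden_states[..., :origin_dim].contiguous()


padded = torch.randn(4, 800)                       # e.g. hidden rounded up from 768
assert slice_to_original_dim(padded, 768).shape == (4, 768)
```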
sglang/srt/layers/moe/token_dispatcher/deepep.py
CHANGED
```diff
@@ -23,14 +23,23 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
 from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    get_bool_env_var,
+    get_int_env_var,
+    is_hip,
+    is_npu,
+    load_json_config,
+)
+
+_is_npu = is_npu()
 
 try:
     from deep_ep import Buffer, Config
 
-
-
-
+    if not _is_npu:
+        from sglang.srt.layers.quantization.fp8_kernel import (
+            sglang_per_token_group_quant_fp8,
+        )
 
     use_deepep = True
 except ImportError:
@@ -80,8 +89,24 @@ class DeepEPLLOutput(NamedTuple):
         return DispatchOutputFormat.deepep_ll
 
 
+class AscendDeepEPLLOutput(NamedTuple):
+    """AscendDeepEP low latency dispatch output."""
+
+    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
+    topk_idx: torch.Tensor
+    topk_weights: torch.Tensor
+    masked_m: torch.Tensor
+    seg_indptr: torch.Tensor
+    expected_m: int
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.deepep_ll
+
+
 assert isinstance(DeepEPNormalOutput, DispatchOutput)
 assert isinstance(DeepEPLLOutput, DispatchOutput)
+assert isinstance(AscendDeepEPLLOutput, DispatchOutput)
 
 
 class DeepEPDispatchMode(IntEnum):
@@ -150,19 +175,20 @@ class DeepEPBuffer:
         else:
             raise NotImplementedError
 
-
-
-
-
-            (
-
-
-
-
-
-
-
-
+        if not _is_npu:
+            total_num_sms = torch.cuda.get_device_properties(
+                device="cuda"
+            ).multi_processor_count
+            if (
+                (deepep_mode != DeepEPMode.LOW_LATENCY)
+                and not global_server_args_dict["enable_two_batch_overlap"]
+                and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
+            ):
+                logger.warning(
+                    f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
+                    f"This may result in highly suboptimal performance. "
+                    f"Consider using --deepep-config to change the behavior."
+                )
 
         cls._buffer = Buffer(
             group,
@@ -507,13 +533,24 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             masked_m
         )
 
-
-
-
-
-
-
-
+        if _is_npu:
+            deepep_output = AscendDeepEPLLOutput(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                masked_m,
+                self.handle[1],
+                expected_m,
+            )
+        else:
+            deepep_output = DeepEPLLOutput(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                masked_m,
+                expected_m,
+            )
+        return deepep_output
 
     def _dispatch_core(
         self,
```
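Both low-latency dispatch outputs expose the same `format` property, so code downstream of the dispatcher can stay agnostic about which tuple it received. A self-contained sketch of that NamedTuple-with-property pattern, using placeholder names rather than the real classes:

```python
# Placeholder demo of a dispatch-output tuple carrying its own format tag.
from enum import Enum, auto
from typing import NamedTuple


class _DemoFormat(Enum):
    deepep_ll = auto()


class _DemoLLOutput(NamedTuple):
    hidden_states: object
    masked_m: object

    @property
    def format(self) -> "_DemoFormat":
        return _DemoFormat.deepep_ll


out = _DemoLLOutput(hidden_states=None, masked_m=None)
assert out.format is _DemoFormat.deepep_ll
```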
sglang/srt/layers/moe/topk.py
CHANGED
```diff
@@ -245,10 +245,11 @@ class TopK(CustomOp):
 
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
         if global_num_experts == 256:
+            router_logits = router_logits.to(torch.float32)
             return torch_npu.npu_moe_gating_top_k(
                 router_logits,
                 k=self.top_k,
-                bias=self.correction_bias,
+                bias=self.correction_bias.to(torch.float32),
                 k_group=self.topk_group,
                 group_count=self.num_expert_group,
                 group_select_mode=1,
@@ -440,7 +441,9 @@ def grouped_topk_cpu(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
+    assert not apply_routed_scaling_factor_on_output
     assert expert_location_dispatch_info is None
     return torch.ops.sgl_kernel.grouped_topk_cpu(
         hidden_states,
```
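The topk.py change casts the router logits and the correction bias to float32 before calling the NPU gating kernel. For readers unfamiliar with the group-limited routing that `npu_moe_gating_top_k` implements, here is a hedged plain-PyTorch reference of the same idea; bias correction and weight renormalization are omitted, and shapes and defaults are assumptions.

```python
# Reference sketch of group-limited top-k gating: keep the best `topk_group`
# expert groups per token, mask the rest, then take `top_k` experts overall.
import torch


def grouped_topk_reference(router_logits, num_expert_group, topk_group, top_k):
    scores = torch.softmax(router_logits.float(), dim=-1)            # [tokens, experts]
    tokens, experts = scores.shape
    group_scores = scores.view(tokens, num_expert_group, -1).max(dim=-1).values
    top_groups = group_scores.topk(topk_group, dim=-1).indices        # [tokens, topk_group]
    group_mask = torch.zeros_like(group_scores).scatter_(1, top_groups, 1.0)
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(tokens, num_expert_group, experts // num_expert_group)
        .reshape(tokens, experts)
    )
    masked = scores.masked_fill(score_mask == 0, 0.0)
    topk_weights, topk_ids = masked.topk(top_k, dim=-1)
    return topk_weights, topk_ids


weights, ids = grouped_topk_reference(
    torch.randn(2, 256), num_expert_group=8, topk_group=4, top_k=8
)
```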
sglang/srt/layers/quantization/__init__.py
CHANGED
```diff
@@ -3,7 +3,7 @@ from __future__ import annotations
 
 import builtins
 import inspect
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, Dict, Optional, Type
 
 import torch
 
@@ -26,8 +26,9 @@ try:
     from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
     VLLM_AVAILABLE = True
-except ImportError:
+except ImportError as e:
     VLLM_AVAILABLE = False
+    VLLM_IMPORT_ERROR = e
 
     # Define empty classes as placeholders when vllm is not available
     class DummyConfig:
@@ -137,7 +138,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
         raise ValueError(
             f"{quantization} quantization requires some operators from vllm. "
-            "Please install vllm by `pip install vllm==0.9.0.1
+            f"Please install vllm by `pip install vllm==0.9.0.1`\n"
+            f"Import error: {VLLM_IMPORT_ERROR}"
         )
 
     return QUANTIZATION_METHODS[quantization]
```
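The quantization `__init__.py` change keeps the original `ImportError` raised by the optional vllm import and appends it to the eventual error message, which makes missing-dependency failures easier to diagnose. A minimal sketch of the same pattern with a hypothetical optional backend:

```python
# Remember why an optional dependency failed so the error surfaced later can
# include the root cause. `some_optional_backend` is a placeholder module.
try:
    import some_optional_backend  # noqa: F401

    BACKEND_AVAILABLE = True
except ImportError as e:
    BACKEND_AVAILABLE = False
    BACKEND_IMPORT_ERROR = e


def require_backend() -> None:
    if not BACKEND_AVAILABLE:
        raise ValueError(
            "This feature requires the optional backend.\n"
            f"Import error: {BACKEND_IMPORT_ERROR}"
        )
```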