sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -1
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +8 -7
- sglang/srt/disaggregation/decode.py +8 -4
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +68 -5
- sglang/srt/entrypoints/openai/protocol.py +2 -9
- sglang/srt/entrypoints/openai/serving_chat.py +60 -265
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +55 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +24 -27
- sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +11 -13
- sglang/srt/layers/dp_attention.py +118 -27
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +12 -18
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +10 -35
- sglang/srt/layers/quantization/awq.py +15 -16
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/gptq.py +12 -17
- sglang/srt/layers/quantization/marlin_utils.py +15 -5
- sglang/srt/layers/quantization/modelopt_quant.py +58 -41
- sglang/srt/layers/quantization/mxfp4.py +20 -3
- sglang/srt/layers/quantization/utils.py +52 -2
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +66 -116
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +24 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +43 -49
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +18 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +53 -44
- sglang/srt/mem_cache/allocator.py +39 -214
- sglang/srt/mem_cache/allocator_ascend.py +158 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -23
- sglang/srt/model_executor/forward_batch_info.py +33 -14
- sglang/srt/model_executor/model_runner.py +179 -81
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_nextn.py +2 -1
- sglang/srt/models/deepseek_v2.py +79 -38
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +11 -11
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +142 -20
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +10 -27
- sglang/srt/models/llama4.py +19 -6
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +20 -5
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_classification.py +78 -0
- sglang/srt/models/qwen3_moe.py +18 -5
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +6 -2
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/operations.py +17 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +142 -140
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +16 -12
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
```diff
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
```
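For context: each of these new JSON files is a fused-MoE Triton tuning table. The top-level keys are token counts (M) and the values are kernel launch parameters tuned for that size; at runtime the entry whose key is closest to the actual M is typically picked, mirroring the vllm-style lookup. A minimal sketch of that lookup, assuming the dict layout above (`pick_config` is a hypothetical helper, not sglang's API):

```python
import json

def pick_config(path: str, m: int) -> dict:
    # Keys are batch sizes serialized as strings; choose the numerically
    # closest one, as fused-MoE tuning tables are typically consumed.
    with open(path) as f:
        table = json.load(f)
    best_key = min(table, key=lambda k: abs(int(k) - m))
    return table[best_key]

# For the table shown above, pick_config(path, 1000) returns the "1024"
# entry: {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, ...}.
```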
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
```diff
@@ -147,6 +147,7 @@ class FusedMoE(torch.nn.Module):
 
         self.layer_id = layer_id
         self.top_k = top_k
+        self.hidden_size = hidden_size
         self.num_experts = num_experts
         self.num_fused_shared_experts = num_fused_shared_experts
         self.expert_map_cpu = None
@@ -209,13 +210,13 @@ class FusedMoE(torch.nn.Module):
         self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
             "enable_flashinfer_mxfp4_moe", False
         )
+        # TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic
         if (
             self.quant_config is not None
             and self.quant_config.get_name() == "mxfp4"
             and self.use_enable_flashinfer_mxfp4_moe
         ):
             hidden_size = round_up(hidden_size, 256)
-        self.hidden_size = hidden_size
         self.quant_method.create_weights(
             layer=self,
             num_experts=self.num_local_experts,
@@ -795,13 +796,6 @@ class FusedMoE(torch.nn.Module):
 
     def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
         origin_hidden_states_dim = hidden_states.shape[-1]
-        if self.hidden_size != origin_hidden_states_dim:
-            hidden_states = torch.nn.functional.pad(
-                hidden_states,
-                (0, self.hidden_size - origin_hidden_states_dim),
-                mode="constant",
-                value=0.0,
-            )
         assert self.quant_method is not None
 
         if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
@@ -846,10 +840,14 @@ class FusedMoE(torch.nn.Module):
         )
         sm.tag(final_hidden_states)
 
+        final_hidden_states = final_hidden_states[
+            ..., :origin_hidden_states_dim
+        ].contiguous()
+
         if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
-        return final_hidden_states
+        return final_hidden_states
 
     @classmethod
     def make_expert_params_mapping(
```
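Net effect of the `layer.py` change: `self.hidden_size` now records the model's true width before any mxfp4 round-up, the input is no longer zero-padded inside `forward`, and the output is sliced back to the original dimension before the all-reduce. A standalone sketch of that pad-to-alignment / slice-back pattern (illustrative names, not the `FusedMoE` internals):

```python
import torch

def round_up(x: int, align: int) -> int:
    return (x + align - 1) // align * align

def aligned_expert_apply(x: torch.Tensor, align: int = 256) -> torch.Tensor:
    orig_dim = x.shape[-1]
    padded_dim = round_up(orig_dim, align)
    if padded_dim != orig_dim:
        # Zero-pad the last dim so the kernel sees an aligned width.
        x = torch.nn.functional.pad(x, (0, padded_dim - orig_dim))
    y = x * 2.0  # stand-in for the quantized expert computation
    # Slice back so downstream consumers see the model's true width.
    return y[..., :orig_dim].contiguous()

print(aligned_expert_apply(torch.randn(4, 300)).shape)  # torch.Size([4, 300])
```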
sglang/srt/layers/moe/token_dispatcher/deepep.py
CHANGED
```diff
@@ -23,14 +23,23 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
 from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import ...
+from sglang.srt.utils import (
+    get_bool_env_var,
+    get_int_env_var,
+    is_hip,
+    is_npu,
+    load_json_config,
+)
+
+_is_npu = is_npu()
 
 try:
     from deep_ep import Buffer, Config
 
-    from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
-    )
+    if not _is_npu:
+        from sglang.srt.layers.quantization.fp8_kernel import (
+            sglang_per_token_group_quant_fp8,
+        )
 
     use_deepep = True
 except ImportError:
@@ -80,8 +89,24 @@ class DeepEPLLOutput(NamedTuple):
         return DispatchOutputFormat.deepep_ll
 
 
+class AscendDeepEPLLOutput(NamedTuple):
+    """AscendDeepEP low latency dispatch output."""
+
+    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
+    topk_idx: torch.Tensor
+    topk_weights: torch.Tensor
+    masked_m: torch.Tensor
+    seg_indptr: torch.Tensor
+    expected_m: int
+
+    @property
+    def format(self) -> DispatchOutputFormat:
+        return DispatchOutputFormat.deepep_ll
+
+
 assert isinstance(DeepEPNormalOutput, DispatchOutput)
 assert isinstance(DeepEPLLOutput, DispatchOutput)
+assert isinstance(AscendDeepEPLLOutput, DispatchOutput)
 
 
 class DeepEPDispatchMode(IntEnum):
@@ -150,19 +175,20 @@ class DeepEPBuffer:
         else:
             raise NotImplementedError
 
-        total_num_sms = torch.cuda.get_device_properties(
-            device="cuda"
-        ).multi_processor_count
-        if (
-            (deepep_mode != DeepEPMode.LOW_LATENCY)
-            and not global_server_args_dict["enable_two_batch_overlap"]
-            and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
-        ):
-            logger.warning(
-                f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
-                f"This may result in highly suboptimal performance. "
-                f"Consider using --deepep-config to change the behavior."
-            )
+        if not _is_npu:
+            total_num_sms = torch.cuda.get_device_properties(
+                device="cuda"
+            ).multi_processor_count
+            if (
+                (deepep_mode != DeepEPMode.LOW_LATENCY)
+                and not global_server_args_dict["enable_two_batch_overlap"]
+                and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
+            ):
+                logger.warning(
+                    f"Only use {DeepEPConfig.get_instance().num_sms} SMs for DeepEP communication. "
+                    f"This may result in highly suboptimal performance. "
+                    f"Consider using --deepep-config to change the behavior."
+                )
 
         cls._buffer = Buffer(
             group,
@@ -507,13 +533,24 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             masked_m
         )
 
-        return DeepEPLLOutput(
-            hidden_states,
-            topk_idx,
-            topk_weights,
-            masked_m,
-            expected_m,
-        )
+        if _is_npu:
+            deepep_output = AscendDeepEPLLOutput(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                masked_m,
+                self.handle[1],
+                expected_m,
+            )
+        else:
+            deepep_output = DeepEPLLOutput(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                masked_m,
+                expected_m,
+            )
+        return deepep_output
 
     def _dispatch_core(
         self,
```
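The deepep.py changes gate the CUDA-only pieces (the `torch.cuda` SM query and the fp8 quant kernel import) behind `_is_npu`, and add `AscendDeepEPLLOutput`, which carries an extra `seg_indptr` for the Ascend combine path while reporting the same `format` tag as `DeepEPLLOutput`. A reduced sketch of that NamedTuple-with-format pattern (enum and field set trimmed for illustration):

```python
from enum import Enum, auto
from typing import NamedTuple

import torch

class DispatchOutputFormat(Enum):
    deepep_ll = auto()

class AscendLLOutput(NamedTuple):
    hidden_states: torch.Tensor
    topk_idx: torch.Tensor
    seg_indptr: torch.Tensor  # extra field only the Ascend path needs

    # typing.NamedTuple permits properties alongside fields, so outputs
    # can advertise their dispatch format without a common base class.
    @property
    def format(self) -> DispatchOutputFormat:
        return DispatchOutputFormat.deepep_ll
```

Consumers that only branch on `format` keep working unchanged; Ascend-aware code can `isinstance`-check for the variant with the extra field.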
sglang/srt/layers/moe/topk.py
CHANGED
```diff
@@ -245,10 +245,11 @@ class TopK(CustomOp):
 
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
         if global_num_experts == 256:
+            router_logits = router_logits.to(torch.float32)
             return torch_npu.npu_moe_gating_top_k(
                 router_logits,
                 k=self.top_k,
-                bias=self.correction_bias,
+                bias=self.correction_bias.to(torch.float32),
                 k_group=self.topk_group,
                 group_count=self.num_expert_group,
                 group_select_mode=1,
@@ -440,7 +441,9 @@ def grouped_topk_cpu(
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    apply_routed_scaling_factor_on_output: Optional[bool] = False,
 ):
+    assert not apply_routed_scaling_factor_on_output
     assert expert_location_dispatch_info is None
     return torch.ops.sgl_kernel.grouped_topk_cpu(
         hidden_states,
```
sglang/srt/layers/multimodal.py
CHANGED
```diff
@@ -17,57 +17,173 @@ import torch
 import triton
 import triton.language as tl
 
+FMIX32_C1 = 0x85EBCA6B
+FMIX32_C2 = 0xC2B2AE35
+POS_C1 = 0x27D4EB2D
+POS_C2 = 0x165667B1
+
+
+@triton.jit
+def _rotl32(x, r: tl.constexpr):
+    return (x << r) | (x >> (32 - r))
+
+
+@triton.jit
+def _fmix32(x, C1: tl.constexpr, C2: tl.constexpr):
+    c1 = tl.full((), C1, tl.uint32)
+    c2 = tl.full((), C2, tl.uint32)
+    x ^= x >> 16
+    x = x * c1
+    x ^= x >> 13
+    x = x * c2
+    x ^= x >> 16
+    return x
+
 
 @triton.jit
-def ...
+def hash_tiles32_kernel_blocked(
+    in_ptr,
+    out_ptr,
+    n_u32,
+    seed1,
+    seed2,
+    FM_C1: tl.constexpr,
+    FM_C2: tl.constexpr,
+    POS_A: tl.constexpr,
+    POS_B: tl.constexpr,
+    TILE: tl.constexpr,
+    BLOCK: tl.constexpr,
+    USE_CG: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)
-    ...
+    base = pid * TILE
+
+    s1 = tl.full((), seed1, tl.uint32)
+    s2 = tl.full((), seed2, tl.uint32)
+    posA = tl.full((), POS_A, tl.uint32)
+    posB = tl.full((), POS_B, tl.uint32)
+
+    h1 = tl.zeros((), dtype=tl.uint32)
+    h2 = tl.zeros((), dtype=tl.uint32)
+
+    for off in tl.static_range(0, TILE, BLOCK):
+        idx = base + off + tl.arange(0, BLOCK)
+        m = idx < n_u32
 
-    ...
+        if USE_CG:
+            v = tl.load(in_ptr + idx, mask=m, other=0, cache_modifier=".cg")
+        else:
+            v = tl.load(in_ptr + idx, mask=m, other=0)
+        v = v.to(tl.uint32)
+
+        iu = idx.to(tl.uint32)
+        p1 = (iu * posA + s1) ^ _rotl32(iu, 15)
+        p2 = (iu * posB + s2) ^ _rotl32(iu, 13)
+
+        k1 = _fmix32(v ^ p1, C1=FM_C1, C2=FM_C2)
+        k2 = _fmix32(v ^ p2, C1=FM_C1, C2=FM_C2)
+
+        zero32 = tl.zeros_like(k1)
+        k1 = tl.where(m, k1, zero32)
+        k2 = tl.where(m, k2, zero32)
+
+        h1 += tl.sum(k1, axis=0).to(tl.uint32)
+        h2 += tl.sum(k2, axis=0).to(tl.uint32)
+
+    nbytes = tl.full((), n_u32 * 4, tl.uint32)
+    h1 ^= nbytes
+    h2 ^= nbytes
+    h1 = _fmix32(h1, C1=FM_C1, C2=FM_C2)
+    h2 = (
+        _fmix32(h2, C1=FMIX32_C1, C2=FMIX32_C2)
+        if False
+        else _fmix32(h2, C1=FM_C1, C2=FM_C2)
+    )
+
+    out = (h1.to(tl.uint64) << 32) | h2.to(tl.uint64)
+    tl.store(out_ptr + pid, out)
+
+
+@triton.jit
+def add_tree_reduce_u64_kernel(in_ptr, out_ptr, n_elems, CHUNK: tl.constexpr):
+    pid = tl.program_id(axis=0)
+    start = pid * CHUNK
+    h = tl.zeros((), dtype=tl.uint64)
+    for i in tl.static_range(0, CHUNK):
+        idx = start + i
+        m = idx < n_elems
+        v = tl.load(in_ptr + idx, mask=m, other=0).to(tl.uint64)
+        h += v
+    tl.store(out_ptr + pid, h)
 
-    tl.store(output_ptr + offsets, hash_val, mask=mask)
 
+def _as_uint32_words(t: torch.Tensor) -> torch.Tensor:
+    assert t.is_cuda, "Use .cuda() first"
+    tb = t.contiguous().view(torch.uint8)
+    nbytes = tb.numel()
+    pad = (4 - (nbytes & 3)) & 3
+    if pad:
+        tb_p = torch.empty(nbytes + pad, dtype=torch.uint8, device=tb.device)
+        tb_p[:nbytes].copy_(tb)
+        tb_p[nbytes:].zero_()
+        tb = tb_p
+    return tb.view(torch.uint32)
 
-PRIME_1 = -(11400714785074694791 ^ 0xFFFFFFFFFFFFFFFF) - 1
-PRIME_2 = -(14029467366897019727 ^ 0xFFFFFFFFFFFFFFFF) - 1
 
+def _final_splitmix64(x: int) -> int:
+    mask = (1 << 64) - 1
+    x &= mask
+    x ^= x >> 30
+    x = (x * 0xBF58476D1CE4E5B9) & mask
+    x ^= x >> 27
+    x = (x * 0x94D049BB133111EB) & mask
+    x ^= x >> 31
+    return x
 
-def gpu_tensor_hash(tensor: torch.Tensor) -> int:
-    assert tensor.is_cuda
-    tensor = tensor.contiguous().view(torch.int32)
-    n = tensor.numel()
-    BLOCK_SIZE = 1024
-    grid = (triton.cdiv(n, BLOCK_SIZE),)
 
-    ...
+@torch.inference_mode()
+def gpu_tensor_hash(
+    tensor: torch.Tensor,
+    *,
+    seed: int = 0x243F6A88,
+    tile_words: int = 8192,
+    block_words: int = 256,
+    reduce_chunk: int = 1024,
+    num_warps: int = 4,
+    num_stages: int = 4,
+    use_cg: bool = True,
+) -> int:
+    assert tensor.is_cuda, "Use .cuda() first"
+    u32 = _as_uint32_words(tensor)
+    n = u32.numel()
+    if n == 0:
+        return 0
 
-    ...
+    grid1 = (triton.cdiv(n, tile_words),)
+    partials = torch.empty(grid1[0], dtype=torch.uint64, device=u32.device)
+    hash_tiles32_kernel_blocked[grid1](
+        u32,
+        partials,
+        n,
+        seed1=seed & 0xFFFFFFFF,
+        seed2=((seed * 0x9E3779B1) ^ 0xDEADBEEF) & 0xFFFFFFFF,
+        FM_C1=FMIX32_C1,
+        FM_C2=FMIX32_C2,
+        POS_A=POS_C1,
+        POS_B=POS_C2,
+        TILE=tile_words,
+        BLOCK=block_words,
+        USE_CG=use_cg,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
 
-    ...
+    cur = partials
+    while cur.numel() > 1:
+        n_elems = cur.numel()
+        grid2 = (triton.cdiv(n_elems, reduce_chunk),)
+        nxt = torch.empty(grid2[0], dtype=torch.uint64, device=cur.device)
+        add_tree_reduce_u64_kernel[grid2](cur, nxt, n_elems, CHUNK=reduce_chunk)
+        cur = nxt
 
-    return ...
+    return _final_splitmix64(int(cur.item()))
```
sglang/srt/layers/quantization/__init__.py
CHANGED
```diff
@@ -3,7 +3,7 @@ from __future__ import annotations
 
 import builtins
 import inspect
-from typing import TYPE_CHECKING, ...
+from typing import TYPE_CHECKING, Dict, Optional, Type
 
 import torch
 
@@ -26,8 +26,9 @@ try:
     from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
     VLLM_AVAILABLE = True
-except ImportError:
+except ImportError as e:
     VLLM_AVAILABLE = False
+    VLLM_IMPORT_ERROR = e
 
     # Define empty classes as placeholders when vllm is not available
     class DummyConfig:
@@ -54,13 +55,7 @@ if is_mxfp_supported:
     from sglang.srt.layers.quantization.fp4 import MxFp4Config
 
 from sglang.srt.layers.quantization.fp8 import Fp8Config
-from sglang.srt.layers.quantization.gptq import (
-    GPTQConfig,
-    GPTQLinearMethod,
-    GPTQMarlinConfig,
-    GPTQMarlinLinearMethod,
-    GPTQMarlinMoEMethod,
-)
+from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp4Config,
     ModelOptFp8Config,
@@ -69,7 +64,6 @@ from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config
 from sglang.srt.layers.quantization.petit import PetitNvFp4Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
-from sglang.srt.layers.quantization.utils import get_linear_quant_method
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
@@ -85,6 +79,10 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "modelopt_fp4": ModelOptFp4Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
+    "awq": AWQConfig,
+    "awq_marlin": AWQMarlinConfig,
+    "gptq": GPTQConfig,
+    "gptq_marlin": GPTQMarlinConfig,
     "moe_wna16": MoeWNA16Config,
     "compressed-tensors": CompressedTensorsConfig,
     "qoq": QoQConfig,
@@ -110,19 +108,15 @@ elif is_mxfp_supported and is_hip():
 # VLLM-dependent quantization methods
 VLLM_QUANTIZATION_METHODS = {
     "aqlm": AQLMConfig,
-    "awq": AWQConfig,
     "deepspeedfp": DeepSpeedFPConfig,
     "tpu_int8": Int8TpuConfig,
     "fbgemm_fp8": FBGEMMFp8Config,
     "marlin": MarlinConfig,
     "gguf": GGUFConfig,
     "gptq_marlin_24": GPTQMarlin24Config,
-    "awq_marlin": AWQMarlinConfig,
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
-    "gptq_marlin": GPTQMarlinConfig,
-    "gptq": GPTQConfig,
 }
 
 QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}
@@ -137,29 +131,13 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
         raise ValueError(
             f"{quantization} quantization requires some operators from vllm. "
-            "Please install vllm by `pip install vllm==0.9.0.1`"
+            f"Please install vllm by `pip install vllm==0.9.0.1`\n"
+            f"Import error: {VLLM_IMPORT_ERROR}"
         )
 
     return QUANTIZATION_METHODS[quantization]
 
 
-def gptq_get_quant_method(self, layer, prefix):
-    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-
-    if isinstance(layer, FusedMoE):
-        return GPTQMarlinMoEMethod(self)
-
-    if isinstance(self, GPTQConfig):
-        return get_linear_quant_method(
-            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
-        )
-    elif isinstance(self, GPTQMarlinConfig):
-        return get_linear_quant_method(
-            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
-        )
-    return None
-
-
 original_isinstance = builtins.isinstance
 
 
@@ -237,10 +215,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
 
 def monkey_patch_quant_configs():
     """Apply all monkey patches in one place."""
-    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
-    setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
 
-    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
     monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
     monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
```
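The practical effect of the registry change: AWQ and GPTQ (plus their Marlin variants) now resolve from `BASE_QUANTIZATION_METHODS`, so those checkpoints no longer require vllm at import time, and the `gptq_get_quant_method` monkey patch is removed along with the vllm-backed method classes it dispatched to. A quick check, assuming an environment with this sglang build:

```python
from sglang.srt.layers.quantization import (
    BASE_QUANTIZATION_METHODS,
    VLLM_QUANTIZATION_METHODS,
)

# awq/gptq families now live in the base registry, not the vllm one.
for name in ("awq", "awq_marlin", "gptq", "gptq_marlin"):
    assert name in BASE_QUANTIZATION_METHODS
    assert name not in VLLM_QUANTIZATION_METHODS
```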