sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +14 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +301 -64
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +20 -15
- sglang/srt/disaggregation/utils.py +47 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +27 -31
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +897 -0
- sglang/srt/entrypoints/openai/serving_completions.py +425 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +28 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +43 -23
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +44 -2
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +14 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +286 -13
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
- sglang/srt/layers/moe/topk.py +117 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +144 -12
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +19 -14
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +49 -32
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +189 -68
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +77 -46
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +27 -8
- sglang/srt/model_loader/loader.py +50 -8
- sglang/srt/model_loader/weight_utils.py +100 -2
- sglang/srt/models/deepseek_nextn.py +35 -30
- sglang/srt/models/deepseek_v2.py +255 -30
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +51 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +248 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json
ADDED
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
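The file above is a new Triton autotuning table for the fp8 W8A8 fused-MoE kernel on H100 (E=128 experts, N=384, 128x128 block quantization). Each top-level key is a token-count bucket and each value is the launch configuration tuned for that bucket. As a rough illustration of how such a table is typically consumed, the sketch below loads the JSON and picks the entry whose bucket is closest to the current token count; the helper name and the nearest-key rule are assumptions for illustration, not sglang's actual lookup (which lives in fused_moe_triton/fused_moe.py).

```python
# Illustrative sketch only; `load_nearest_config` is a hypothetical helper,
# not part of sglang.
import json


def load_nearest_config(path: str, num_tokens: int) -> dict:
    """Return the kernel config whose token-bucket key is closest to num_tokens."""
    with open(path) as f:
        table = {int(k): v for k, v in json.load(f).items()}
    best_key = min(table, key=lambda k: abs(k - num_tokens))
    return table[best_key]


# Example: a batch of 100 tokens would resolve to the "96" bucket above,
# i.e. BLOCK_SIZE_M=16, BLOCK_SIZE_N=128, BLOCK_SIZE_K=128, GROUP_SIZE_M=16.
```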
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
CHANGED
@@ -25,9 +25,11 @@ from sglang.srt.layers.quantization.int8_kernel import (
     sglang_per_token_group_quant_int8,
 )
 from sglang.srt.utils import (
+    cpu_has_amx_support,
     direct_register_custom_op,
     get_bool_env_var,
     get_device_name,
+    is_cpu,
     is_cuda,
     is_hip,
     log_info_on_rank0,
@@ -36,9 +38,13 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
+elif _is_cpu and _is_cpu_amx_available:
+    pass
 else:
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant
@@ -744,9 +750,11 @@ def moe_align_block_size(
         by block_size for proper block matrix operations.
     """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids
-        max_num_tokens_padded,
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
+    sorted_ids.fill_(topk_ids.numel())
+
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
     expert_ids = torch.empty(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
@@ -762,6 +770,9 @@ def moe_align_block_size(
             num_tokens_post_pad,
         )
     else:
+        cumsum_buffer = torch.empty(
+            (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
+        )
         token_cnts_buffer = torch.empty(
             (num_experts + 1) * num_experts,
             dtype=torch.int32,
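In the moe_align_block_size hunks above, `sorted_ids` is now allocated with `torch.empty` and then filled with `topk_ids.numel()`, which acts as a sentinel marking padding slots, and the non-CUDA branch gains an explicit `cumsum_buffer`. The worst-case buffer size comes from rounding every expert's token list up to a multiple of `block_size`, so each of the `num_experts` lists can waste at most `block_size - 1` slots. A small worked example of that arithmetic (a sketch, not the library code):

```python
import torch

# 4 tokens routed to top-2 of 8 experts -> topk_ids has 8 entries.
topk_ids = torch.randint(0, 8, (4, 2))
num_experts, block_size = 8, 16

# Worst case: every expert's token list is padded up to a multiple of block_size.
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)  # 8 + 8 * 15 = 128

sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32)
sorted_ids.fill_(topk_ids.numel())  # slots still equal to 8 after sorting are padding
```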
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -18,7 +18,14 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    _process_weight_after_loading,
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_hip,
+    set_weight_attrs,
+)
 
 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -28,10 +35,13 @@ else:
     import logging
 
 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _use_aiter:
     from aiter import ActivationType
+    from aiter.fused_moe import fused_moe
     from aiter.fused_moe_bf16_asm import ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
 
@@ -116,6 +126,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             requires_grad=False,
         )
         torch.cuda.empty_cache()
+
+        # Pack weight for get better performance on CPU
+        if _is_cpu and _is_cpu_amx_available:
+            _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+
         return
 
     def apply(
@@ -204,7 +219,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             topk_weights, dtype=torch.float32
         )  # topk_weights must be FP32 (float32)
 
-        return
+        return fused_moe(
             x,
             layer.w13_weight,
             layer.w2_weight,
@@ -241,26 +256,75 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_fused_shared_experts: int = 0,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-
-
-
-
-
-
-
-
-
-
-
-
-
+        assert activation == "silu", f"activation = {activation} is not supported."
+
+        if (
+            getattr(layer, "use_intel_amx_backend", False)
+            and not apply_router_weight_on_input
+        ):
+            topk_weights, topk_ids = select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                num_fused_shared_experts=num_fused_shared_experts,
+                custom_routing_function=custom_routing_function,
+                correction_bias=correction_bias,
+                routed_scaling_factor=routed_scaling_factor,
+            )
+
+            # TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
+            return torch.ops.sgl_kernel.fused_experts_cpu(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights.to(
+                    torch.float
+                ),  # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
+                topk_ids,
+                True,  # inplace
+                False,  # use_int8_w8a8
+                False,  # use_fp8_w8a16
+                None,  # w1_scale
+                None,  # w2_scale
+                None,  # block_size
+                None,  # a1_scale
+                None,  # a2_scale
+                True,  # is_vnni
+            )
+        else:
+            return moe_forward_native(
+                layer,
+                x,
+                use_grouped_topk,
+                top_k,
+                router_logits,
+                renormalize,
+                topk_group,
+                num_expert_group,
+                num_fused_shared_experts,
+                custom_routing_function,
+                correction_bias,
+                activation,
+                apply_router_weight_on_input,
+                inplace,
+                no_combine,
+                routed_scaling_factor,
+            )
 
     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")
 
-    forward_native =
+    forward_native = forward_cpu
 
 
 class FusedMoE(torch.nn.Module):
@@ -310,6 +374,8 @@ class FusedMoE(torch.nn.Module):
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
+        enable_flashinfer_moe: Optional[bool] = False,
+        enable_ep_moe: Optional[bool] = False,
     ):
         super().__init__()
 
@@ -320,9 +386,40 @@ class FusedMoE(torch.nn.Module):
         self.tp_size = (
             tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
         )
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.num_experts = num_experts
+        self.expert_map = None
+
+        if enable_flashinfer_moe and quant_config is None:
+            logger.warning("Disable flashinfer MoE when quantization config is None.")
+            enable_flashinfer_moe = False
+            enable_ep_moe = False
+
+        self.enable_flashinfer_moe = enable_flashinfer_moe
+        if enable_ep_moe:
+            assert (
+                self.enable_flashinfer_moe
+            ), "FusedMoE only supports EP with --enable-flashinfer-moe"
+            self.ep_size = self.tp_size
+            self.ep_rank = self.tp_rank
+            self.tp_size = 1
+            self.tp_rank = 0
+            # Create a tensor of size num_experts filled with -1
+            self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
+            # Create a expert map for the local experts
+            assert num_experts % self.ep_size == 0
+            self.local_num_experts = num_experts // self.ep_size
+            self.expert_map[
+                self.ep_rank
+                * self.local_num_experts : (self.ep_rank + 1)
+                * self.local_num_experts
+            ] = torch.arange(0, self.local_num_experts, dtype=torch.int32, device="cpu")
+        else:
+            self.ep_size = 1
+            self.ep_rank = 0
+            self.local_num_experts = num_experts
         self.routed_scaling_factor = routed_scaling_factor
         self.top_k = top_k
-        self.num_experts = num_experts
         assert intermediate_size % self.tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
@@ -340,7 +437,6 @@ class FusedMoE(torch.nn.Module):
         self.use_presharded_weights = use_presharded_weights
         self.inplace = inplace
         self.no_combine = no_combine
-        self.local_num_experts = num_experts
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -348,11 +444,13 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
+            if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
+                self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
         assert self.quant_method is not None
 
         self.quant_method.create_weights(
             layer=self,
-            num_experts=
+            num_experts=self.local_num_experts,
             hidden_size=hidden_size,
             # FIXME: figure out which intermediate_size to use
             intermediate_size=self.intermediate_size_per_partition,
@@ -446,12 +544,15 @@ class FusedMoE(torch.nn.Module):
 
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
-        if shard_id == "w1":
-            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
         # w3, up_proj: Load into second logical weight of w13.
+        # trtllm cutlass kernel assumes differently
+        assert shard_id in ("w1", "w3")
+        switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
+        if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
+            start = shard_size
         else:
-
-
+            start = 0
+        expert_data = expert_data.narrow(shard_dim, start, shard_size)
         expert_data.copy_(loaded_weight)
 
     def _load_w2(
@@ -505,6 +606,11 @@ class FusedMoE(torch.nn.Module):
         assert shard_id in ("w1", "w3")
         expert_data.copy_(loaded_weight)
 
+    def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
+        if self.expert_map is None:
+            return expert_id
+        return self.expert_map[expert_id].item()
+
     def weight_loader(
         self,
         param: torch.nn.Parameter,
@@ -513,6 +619,13 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
+        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
+        if expert_id == -1:
+            return
+
+        # TP rank is set to 0 if EP is enabled
+        tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
+
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
@@ -537,7 +650,6 @@ class FusedMoE(torch.nn.Module):
         SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
 
         expert_data = param.data[expert_id]
-        tp_rank = get_tensor_model_parallel_rank()
 
         # is_transposed: if the dim to shard the weight
         # should be flipped. Required by GPTQ, compressed-tensors
@@ -545,7 +657,7 @@ class FusedMoE(torch.nn.Module):
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
         if is_transposed:
-            shard_dim =
+            shard_dim = int(not shard_dim)
 
         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
@@ -686,9 +798,19 @@ class FusedMoE(torch.nn.Module):
             activation=self.activation,
             apply_router_weight_on_input=self.apply_router_weight_on_input,
             routed_scaling_factor=self.routed_scaling_factor,
+            **(
+                dict(
+                    tp_rank=self.tp_rank,
+                    tp_size=self.tp_size,
+                    ep_rank=self.ep_rank,
+                    ep_size=self.ep_size,
+                )
+                if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
+                else {}
+            ),
         )
 
-        if self.reduce_results and self.tp_size > 1:
+        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states
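The layer.py changes above add an optional expert-parallel (EP) mode to FusedMoE: when `enable_ep_moe` is set, each rank keeps `num_experts // ep_size` experts and builds an `expert_map` whose entries are either a local expert index or -1, and `weight_loader` silently drops weights whose global expert id maps to -1. A minimal sketch of that mapping, with hypothetical sizes:

```python
import torch

num_experts, ep_size, ep_rank = 16, 4, 1          # hypothetical EP layout
local_num_experts = num_experts // ep_size        # 4 experts per rank

expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
expert_map[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts] = torch.arange(
    local_num_experts, dtype=torch.int32
)


def map_global_to_local(expert_id: int) -> int:
    # -1 means "this expert lives on another EP rank": skip loading its weights.
    return int(expert_map[expert_id])


# Experts 4..7 map to local slots 0..3 on rank 1; everything else returns -1.
assert map_global_to_local(5) == 1 and map_global_to_local(0) == -1
```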
sglang/srt/layers/moe/topk.py
CHANGED
@@ -28,19 +28,34 @@ from sglang.srt.managers.expert_location_dispatch import (
     topk_ids_logical_to_physical,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    get_compiler_backend,
+    is_cpu,
+    is_cuda,
+    is_hip,
+)
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import moe_fused_gate
 
 if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax
+if _use_aiter:
+    try:
+        from aiter import biased_grouped_topk as aiter_biased_grouped_topk
+    except ImportError:
+        raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
 
 
-def
+def fused_topk_torch_native(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
@@ -61,6 +76,20 @@ def fused_topk_native(
     return topk_weights, topk_ids
 
 
+def fused_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    return torch.ops.sgl_kernel.topk_softmax_cpu(
+        hidden_states=hidden_states,
+        gating_output=gating_output,
+        topk=topk,
+        renormalize=renormalize,
+    )
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -115,7 +144,7 @@ def _fused_topk_postprocess(
 
 # This is used by the Deepseek V2/V3/R1 series models
 @torch.compile(dynamic=True, backend=get_compiler_backend())
-def
+def grouped_topk_gpu(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
@@ -171,6 +200,32 @@ def grouped_topk(
     return topk_weights, topk_ids
 
 
+def grouped_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    num_fused_shared_experts: int = 0,
+    routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+):
+    assert expert_location_dispatch_info is None
+    return torch.ops.sgl_kernel.grouped_topk_cpu(
+        hidden_states,
+        gating_output,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    )
+
+
 def biased_grouped_topk_impl(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -258,7 +313,7 @@ def _biased_grouped_topk_postprocess(
     return topk_ids
 
 
-def
+def biased_grouped_topk_gpu(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     correction_bias: torch.Tensor,
@@ -299,6 +354,25 @@ def biased_grouped_topk(
             topk_ids, expert_location_dispatch_info, num_token_non_padded
         )
         return topk_weights, topk_ids
+    elif _use_aiter:
+        token = gating_output.shape[0]
+        device = gating_output.device
+        assert (
+            hidden_states.shape[0] == gating_output.shape[0]
+        ), f"Number of tokens mismatch: hidden_states.shape[0] = {hidden_states.shape[0]}, gating_output.shape[0] = {gating_output.shape[0]}"
+        topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
+        topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
+        aiter_biased_grouped_topk(
+            gating_output,
+            correction_bias,
+            topk_weights,
+            topk_ids,
+            num_expert_group,
+            topk_group,
+            renormalize,
+            routed_scaling_factor,
+        )
+        return topk_weights, topk_ids
     else:
         biased_grouped_topk_fn = (
             torch.compile(
@@ -322,6 +396,45 @@ def biased_grouped_topk(
     )
 
 
+def biased_grouped_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    compiled: bool = True,
+    num_fused_shared_experts: int = 0,
+    routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+):
+    assert expert_location_dispatch_info is None
+    return torch.ops.sgl_kernel.biased_grouped_topk_cpu(
+        hidden_states,
+        gating_output,
+        correction_bias,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    )
+
+
+if _is_cpu and _is_cpu_amx_available:
+    biased_grouped_topk = biased_grouped_topk_cpu
+    grouped_topk = grouped_topk_cpu
+    fused_topk_native = fused_topk_cpu
+else:
+    biased_grouped_topk = biased_grouped_topk_gpu
+    grouped_topk = grouped_topk_gpu
+    fused_topk_native = fused_topk_torch_native
+
+
 def select_experts(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
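The topk.py changes above add CPU variants of the routing functions (dispatched through `torch.ops.sgl_kernel.*_cpu`), an aiter path for ROCm, and pick the implementation once at import time via module-level aliases. For context, the sketch below is a plain-PyTorch rendering of the grouped top-k routing idea these kernels implement for DeepSeek-style models; it is a simplified reference, not the sglang kernel, and the softmax placement and unconditional renormalization are simplifying assumptions.

```python
import torch


def grouped_topk_sketch(gating_output, topk, num_expert_group, topk_group):
    scores = torch.softmax(gating_output, dim=-1)                 # [tokens, experts]
    tokens, experts = scores.shape
    # Score each group by its best expert and keep the best `topk_group` groups.
    group_scores = scores.view(tokens, num_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(tokens, num_expert_group, experts // num_expert_group)
        .reshape(tokens, experts)
    )
    # Only experts inside the kept groups compete in the final top-k.
    masked_scores = scores.masked_fill(score_mask == 0, 0.0)
    topk_weights, topk_ids = torch.topk(masked_scores, k=topk, dim=-1)
    return topk_weights / topk_weights.sum(dim=-1, keepdim=True), topk_ids


weights, ids = grouped_topk_sketch(torch.randn(4, 128), topk=8, num_expert_group=8, topk_group=4)
```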
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
CHANGED
@@ -14,14 +14,18 @@ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
+    cpu_has_amx_support,
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import is_cuda, set_weight_attrs
+from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs
 
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
-if not _is_cuda:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant
 