sglang 0.4.4.post4__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +21 -0
- sglang/bench_serving.py +10 -4
- sglang/lang/chat_template.py +24 -0
- sglang/srt/configs/model_config.py +40 -4
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +29 -4
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +18 -5
- sglang/srt/disaggregation/mini_lb.py +53 -122
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +615 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
- sglang/srt/disaggregation/prefill.py +43 -19
- sglang/srt/disaggregation/utils.py +31 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +37 -10
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/attention/flashattention_backend.py +609 -202
- sglang/srt/layers/attention/flashinfer_backend.py +13 -7
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +37 -16
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
- sglang/srt/layers/quantization/fp8.py +28 -14
- sglang/srt/layers/quantization/fp8_kernel.py +130 -4
- sglang/srt/layers/quantization/fp8_utils.py +34 -6
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
- sglang/srt/layers/quantization/w8a8_int8.py +3 -0
- sglang/srt/layers/radix_attention.py +14 -0
- sglang/srt/layers/rotary_embedding.py +75 -1
- sglang/srt/managers/io_struct.py +254 -97
- sglang/srt/managers/mm_utils.py +3 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
- sglang/srt/managers/schedule_batch.py +62 -21
- sglang/srt/managers/scheduler.py +71 -14
- sglang/srt/managers/tokenizer_manager.py +17 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +14 -1
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +49 -9
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +1 -0
- sglang/srt/models/deepseek_v2.py +248 -61
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +13 -4
- sglang/srt/models/llama4.py +487 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +2 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +227 -0
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +1 -0
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +1 -0
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/server_args.py +34 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +6 -2
- sglang/srt/utils.py +120 -9
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/test_block_fp8.py +57 -0
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
- sglang/srt/disaggregation/conn.py +0 -81
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json
ADDED
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    }
+}
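For orientation, a minimal sketch of how a tuned file like this is consumed (assuming the stock fused-MoE config lookup, where keys are token counts M and the closest key is used when M itself is absent; the helper name below is made up):

    import json

    # Hypothetical helper mirroring the usual lookup: parse the per-M configs and
    # pick the entry whose key is numerically closest to the current token count.
    def pick_moe_config(config_path: str, num_tokens: int) -> dict:
        with open(config_path) as f:
            configs = {int(k): v for k, v in json.load(f).items()}
        return configs[min(configs, key=lambda k: abs(k - num_tokens))]

    # e.g. num_tokens=200 against the file above falls back to the "256" entry.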
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
CHANGED
@@ -342,6 +342,7 @@ def fused_moe_kernel(
     use_fp8_w8a8: tl.constexpr,
     use_int8_w8a8: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
+    per_channel_quant: tl.constexpr,
     even_Ks: tl.constexpr,
 ):
     """
@@ -416,20 +417,7 @@ def fused_moe_kernel(
         )
         b_scale = tl.load(b_scale_ptrs)
 
-    if use_fp8_w8a8:
-        # block-wise
-        if group_k > 0 and group_n > 0:
-            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
-            offs_bsn = offs_bn // group_n
-            b_scale_ptrs = (
-                b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
-            )
-        # tensor-wise
-        else:
-            a_scale = tl.load(a_scale_ptr)
-            b_scale = tl.load(b_scale_ptr + off_experts)
-
-    if use_int8_w8a8:
+    if use_fp8_w8a8 or use_int8_w8a8:
         # block-wise
         if group_k > 0 and group_n > 0:
             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
@@ -438,8 +426,7 @@ def fused_moe_kernel(
                 b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
             )
         # channel-wise
-        else:
-            # Load per-column scale for weights
+        elif per_channel_quant:
             b_scale_ptrs = (
                 b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
             )
@@ -447,6 +434,10 @@
             # Load per-token scale for activations
             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
             a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
+        # tensor-wise
+        else:
+            a_scale = tl.load(a_scale_ptr)
+            b_scale = tl.load(b_scale_ptr + off_experts)
 
     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
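The four hunks above fold the fp8 and int8 W8A8 paths into a single branch and make the fallback explicit. A plain-Python paraphrase of the resulting scale-selection order (not the Triton kernel itself; the function name is illustrative):

    def scale_granularity(use_fp8_w8a8, use_int8_w8a8, group_k, group_n, per_channel_quant):
        # Mirrors the reworked branch structure in fused_moe_kernel.
        if not (use_fp8_w8a8 or use_int8_w8a8):
            return None
        if group_k > 0 and group_n > 0:
            return "block-wise"      # scales per [group_k x group_n] weight block
        elif per_channel_quant:
            return "channel-wise"    # per-column weight scale, per-token activation scale
        else:
            return "tensor-wise"     # one scale per tensor / per expert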
@@ -711,12 +702,12 @@ def moe_align_block_size(
             num_tokens_post_pad,
         )
     else:
-        token_cnts_buffer = torch.zeros(
+        token_cnts_buffer = torch.empty(
             (num_experts + 1) * num_experts,
             dtype=torch.int32,
             device=topk_ids.device,
         )
-        cumsum_buffer = torch.zeros(
+        cumsum_buffer = torch.empty(
             num_experts + 1, dtype=torch.int32, device=topk_ids.device
         )
 
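For context, the two scratch buffers above are now created with torch.empty rather than pre-zeroed storage, presumably because the alignment kernel initializes them itself. A small illustration with made-up sizes:

    import torch

    num_experts = 288   # e.g. the E=288 config added above; illustrative value
    device = "cuda"     # illustrative
    token_cnts_buffer = torch.empty((num_experts + 1) * num_experts, dtype=torch.int32, device=device)
    cumsum_buffer = torch.empty(num_experts + 1, dtype=torch.int32, device=device)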
@@ -753,6 +744,7 @@ def invoke_fused_moe_kernel(
     use_int8_w8a8: bool,
     use_int8_w8a16: bool,
     use_int4_w4a16: bool,
+    per_channel_quant: bool,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
 ) -> None:
@@ -765,6 +757,8 @@
         from sglang.srt.layers.quantization.fp8_kernel import (
             sglang_per_token_group_quant_fp8,
         )
+    else:
+        from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -775,10 +769,15 @@
     if block_shape is None:
         # activation tensor-wise fp8 quantization, dynamic or static
         padded_size = padding_size
+        # activations apply per-token quantization when weights apply per-channel quantization by default
         if _is_cuda:
-            A, A_scale = sgl_scaled_fp8_quant(A, A_scale)
+            A, A_scale = sgl_scaled_fp8_quant(
+                A, A_scale, use_per_token_if_dynamic=per_channel_quant
+            )
         else:
-            A, A_scale = vllm_ops.scaled_fp8_quant(A, A_scale)
+            A, A_scale = vllm_ops.scaled_fp8_quant(
+                A, A_scale, use_per_token_if_dynamic=per_channel_quant
+            )
     else:
         # activation block-wise fp8 quantization
         assert len(block_shape) == 2
@@ -794,6 +793,9 @@
         assert B_scale is not None
         if block_shape is None:
             # activation channel-wise int8 quantization
+            assert (
+                per_channel_quant
+            ), "int8 quantization only supports channel-wise quantization except for block-wise quantization"
             A, A_scale = per_token_quant_int8(A)
         else:
             # activation block-wise int8 quantization
@@ -902,6 +904,7 @@
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
+        per_channel_quant=per_channel_quant,
         even_Ks=even_Ks,
         **config,
     )
@@ -953,7 +956,7 @@ def get_moe_configs(
         logger.warning(
             (
                 "Using default MoE config. Performance might be sub-optimal! "
-                "Config file not found at %s"
+                "Config file not found at %s, you can tune the config with https://github.com/sgl-project/sglang/blob/main/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py."
             ),
             config_file_path,
         )
@@ -1079,10 +1082,12 @@ def inplace_fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1099,10 +1104,12 @@ def inplace_fused_experts(
         topk_ids,
         True,
         activation,
+        apply_router_weight_on_input,
         use_fp8_w8a8,
         use_int8_w8a8,
         use_int8_w8a16,
         use_int4_w4a16,
+        per_channel_quant,
         w1_scale,
         w2_scale,
         w1_zp,
@@ -1120,10 +1127,12 @@ def inplace_fused_experts_fake(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1150,10 +1159,12 @@ def outplace_fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1171,10 +1182,12 @@ def outplace_fused_experts(
         topk_ids,
         False,
         activation,
+        apply_router_weight_on_input,
         use_fp8_w8a8,
         use_int8_w8a8,
         use_int8_w8a16,
         use_int4_w4a16,
+        per_channel_quant,
         w1_scale,
         w2_scale,
         w1_zp,
@@ -1193,10 +1206,12 @@ def outplace_fused_experts_fake(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1225,10 +1240,12 @@ def fused_experts(
     topk_ids: torch.Tensor,
     inplace: bool = False,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1247,10 +1264,12 @@ def fused_experts(
         topk_weights,
         topk_ids,
         activation,
+        apply_router_weight_on_input,
         use_fp8_w8a8,
         use_int8_w8a8,
         use_int8_w8a16,
         use_int4_w4a16,
+        per_channel_quant,
         w1_scale,
         w2_scale,
         w1_zp,
@@ -1268,10 +1287,12 @@ def fused_experts(
         topk_weights,
         topk_ids,
         activation,
+        apply_router_weight_on_input,
         use_fp8_w8a8,
         use_int8_w8a8,
         use_int8_w8a16,
         use_int4_w4a16,
+        per_channel_quant,
         w1_scale,
         w2_scale,
         w1_zp,
@@ -1291,10 +1312,12 @@ def fused_experts_impl(
     topk_ids: torch.Tensor,
     inplace: bool = False,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1423,7 +1446,7 @@ def fused_experts_impl(
         sorted_token_ids,
         expert_ids,
         num_tokens_post_padded,
-        False,
+        apply_router_weight_on_input,
         topk_ids.shape[1],
         config,
         compute_type=compute_type,
@@ -1431,6 +1454,7 @@ def fused_experts_impl(
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         use_int4_w4a16=use_int4_w4a16,
+        per_channel_quant=per_channel_quant,
         block_shape=block_shape,
     )
     if activation == "silu":
@@ -1456,7 +1480,7 @@ def fused_experts_impl(
         (
             intermediate_cache3
             if not no_combine and topk_ids.shape[1] != 1
-            else out_hidden_states[begin_chunk_idx:end_chunk_idx]
+            else out_hidden_states[begin_chunk_idx:end_chunk_idx].unsqueeze(0)
         ),
         a2_scale,
         w2_scale,
@@ -1466,7 +1490,7 @@ def fused_experts_impl(
         sorted_token_ids,
         expert_ids,
         num_tokens_post_padded,
-        True,
+        not apply_router_weight_on_input,
         1,
         config,
         compute_type=compute_type,
@@ -1474,6 +1498,7 @@ def fused_experts_impl(
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         use_int4_w4a16=use_int4_w4a16,
+        per_channel_quant=per_channel_quant,
         block_shape=block_shape,
     )
 
@@ -1520,6 +1545,7 @@ def fused_moe(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1596,6 +1622,7 @@ def fused_moe(
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         use_int4_w4a16=use_int4_w4a16,
+        per_channel_quant=per_channel_quant,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
         w1_zp=w1_zp,
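Taken together, fused_moe() gains a per_channel_quant knob, and the fused_experts* helpers additionally gain apply_router_weight_on_input. A hedged usage sketch; the tensor shapes and surrounding arguments are illustrative, only the new keyword is the point:

    import torch
    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

    # Hypothetical shapes: 4 tokens, hidden 128, intermediate 256, 8 experts, top-2 routing.
    x = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")
    w1 = torch.randn(8, 2 * 256, 128, dtype=torch.bfloat16, device="cuda")
    w2 = torch.randn(8, 128, 256, dtype=torch.bfloat16, device="cuda")
    gating = torch.randn(4, 8, dtype=torch.float32, device="cuda")

    out = fused_moe(
        x, w1, w2, gating, topk=2, renormalize=True,
        per_channel_quant=False,  # new in 0.4.5: per-channel weight / per-token activation scaling
    )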
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -128,6 +128,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
     ) -> torch.Tensor:
@@ -143,6 +144,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
             activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             inplace=inplace,
             no_combine=no_combine,
         )
@@ -160,6 +162,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
     ) -> torch.Tensor:
@@ -200,6 +203,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             topk_ids=topk_ids,
             inplace=inplace and not no_combine,
             activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             no_combine=no_combine,
         )
 
@@ -276,6 +280,7 @@ class FusedMoE(torch.nn.Module):
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         use_presharded_weights: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
@@ -302,6 +307,7 @@ class FusedMoE(torch.nn.Module):
         self.custom_routing_function = custom_routing_function
         self.correction_bias = correction_bias
         self.activation = activation
+        self.apply_router_weight_on_input = apply_router_weight_on_input
         self.use_presharded_weights = use_presharded_weights
         self.inplace = inplace
         self.no_combine = no_combine
@@ -630,6 +636,7 @@ class FusedMoE(torch.nn.Module):
             custom_routing_function=self.custom_routing_function,
             correction_bias=self.correction_bias,
             activation=self.activation,
+            apply_router_weight_on_input=self.apply_router_weight_on_input,
         )
 
         if self.reduce_results and self.tp_size > 1:
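On the layer side, the same flag is now a constructor argument that FusedMoE stores and forwards into the quant method's apply(). A hedged construction sketch; apart from apply_router_weight_on_input, the arguments shown are just the usual ones and their values are made up:

    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

    experts = FusedMoE(
        num_experts=16,
        top_k=1,
        hidden_size=2048,
        intermediate_size=8192,
        apply_router_weight_on_input=True,  # scale the input by the router weight (top-1 style routing)
    )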
sglang/srt/layers/moe/router.py
CHANGED
@@ -5,6 +5,9 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.moe.topk import fused_topk
+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()
 
 
 @triton.jit
@@ -116,10 +119,13 @@ def fused_moe_router_impl(
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
 
     grid = lambda meta: (bs,)
+
+    min_num_warps = 16 if _is_hip else 32
+
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
         ),
     }
 
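The router change caps num_warps at 16 on HIP (ROCm) instead of 32. The launch-config arithmetic can be checked in isolation (function name and values below are illustrative):

    import math

    def router_num_warps(hidden_dim: int, hip: bool) -> int:
        # next_power_of_2(ceil(hidden_dim / 256)), capped at 16 (HIP) or 32 (CUDA), floor of 4.
        min_num_warps = 16 if hip else 32
        np2 = 1 << math.ceil(math.log2(max(1, math.ceil(hidden_dim / 256))))
        return max(min(np2, min_num_warps), 4)

    print(router_num_warps(4096, hip=False))  # 16
    print(router_num_warps(8192, hip=True))   # 16 (capped); 32 on CUDA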
sglang/srt/layers/moe/topk.py
CHANGED
@@ -12,6 +12,7 @@
 # limitations under the License.
 # ==============================================================================
 
+import math
 import os
 from typing import Callable, Optional
 
@@ -25,6 +26,8 @@ from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 
+if _is_cuda:
+    from sgl_kernel import moe_fused_gate
 
 expert_distribution_recorder = ExpertDistributionRecorder()
 
@@ -209,6 +212,10 @@ def biased_grouped_topk_impl(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
 
+def is_power_of_two(n):
+    return n > 0 and math.log2(n).is_integer()
+
+
 def biased_grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -220,23 +227,37 @@ def biased_grouped_topk(
     compiled: bool = True,
     n_share_experts_fusion: int = 0,
 ):
-    biased_grouped_topk_fn = (
-        torch.compile(
-            biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
+    # TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now.
+    if (
+        _is_cuda
+        and n_share_experts_fusion == 0
+        and is_power_of_two(correction_bias.shape[0])
+    ):
+        return moe_fused_gate(
+            gating_output,
+            correction_bias,
+            num_expert_group,
+            topk_group,
+            topk,
+        )
+    else:
+        biased_grouped_topk_fn = (
+            torch.compile(
+                biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
+            )
+            if compiled
+            else biased_grouped_topk_impl
+        )
+        return biased_grouped_topk_fn(
+            hidden_states,
+            gating_output,
+            correction_bias,
+            topk,
+            renormalize,
+            num_expert_group,
+            topk_group,
+            n_share_experts_fusion=n_share_experts_fusion,
         )
-        if compiled
-        else biased_grouped_topk_impl
-    )
-    return biased_grouped_topk_fn(
-        hidden_states,
-        gating_output,
-        correction_bias,
-        topk,
-        renormalize,
-        num_expert_group,
-        topk_group,
-        n_share_experts_fusion=n_share_experts_fusion,
-    )
 
 
 def select_experts(
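The biased_grouped_topk dispatch above now takes the sgl_kernel moe_fused_gate fast path only on CUDA, with no shared-expert fusion, and when the expert count (correction_bias.shape[0]) is a power of two; everything else falls back to the optionally torch.compile'd reference implementation. A small stand-alone check of that predicate, with illustrative expert counts:

    import math

    def is_power_of_two(n):
        return n > 0 and math.log2(n).is_integer()

    def uses_fused_gate(is_cuda: bool, n_share_experts_fusion: int, num_experts: int) -> bool:
        return is_cuda and n_share_experts_fusion == 0 and is_power_of_two(num_experts)

    print(uses_fused_gate(True, 0, 256))  # True  -> moe_fused_gate kernel (e.g. 256 routed experts)
    print(uses_fused_gate(True, 0, 288))  # False -> compiled biased_grouped_topk_impl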
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -59,20 +59,20 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
-from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
+from sglang.srt.layers.quantization.modelopt_quant import (
+    ModelOptFp4Config,
+    ModelOptFp8Config,
+)
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
-from sglang.srt.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    UnquantizedEmbeddingMethod,
-)
 
 # Base quantization methods that don't depend on vllm
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
     "blockwise_int8": BlockInt8Config,
     "modelopt": ModelOptFp8Config,
+    "modelopt_fp4": ModelOptFp4Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
     "moe_wna16": MoeWNA16Config,
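The registry change adds a "modelopt_fp4" entry alongside the existing "modelopt" (FP8) one, so ModelOpt FP4 checkpoints resolve to ModelOptFp4Config by name. A hedged lookup sketch using the dictionary defined in this module:

    from sglang.srt.layers.quantization import BASE_QUANTIZATION_METHODS

    quant_config_cls = BASE_QUANTIZATION_METHODS["modelopt_fp4"]
    print(quant_config_cls.__name__)  # ModelOptFp4Config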
@@ -176,6 +176,13 @@ def get_linear_quant_method(
     prefix: str,
     linear_method_cls: type,
 ):
+    # Move import here to avoid circular import. This is only used in monkey patching
+    # of vllm's QuantizationConfig.
+    from sglang.srt.layers.vocab_parallel_embedding import (
+        ParallelLMHead,
+        UnquantizedEmbeddingMethod,
+    )
+
     cloned_config = deepcopy(config)
     parallel_lm_head_quantized = (
         isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
@@ -280,6 +287,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
     ):
sglang/srt/layers/quantization/blockwise_int8.py
CHANGED
@@ -370,6 +370,7 @@ class BlockInt8MoEMethod:
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
     ) -> torch.Tensor:
@@ -398,6 +399,7 @@
             topk_ids=topk_ids,
             inplace=inplace,
             activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
             use_int8_w8a8=True,
             w1_scale=(layer.w13_weight_scale_inv),
             w2_scale=(layer.w2_weight_scale_inv),