sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +23 -2
- sglang/bench_serving.py +6 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +80 -11
- sglang/srt/disaggregation/mini_lb.py +58 -123
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +585 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
- sglang/srt/disaggregation/prefill.py +82 -22
- sglang/srt/disaggregation/utils.py +46 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +42 -13
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +430 -257
- sglang/srt/layers/attention/flashinfer_backend.py +18 -9
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +18 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +63 -45
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +131 -136
- sglang/srt/layers/quantization/fp8_kernel.py +328 -46
- sglang/srt/layers/quantization/fp8_utils.py +206 -253
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
- sglang/srt/layers/quantization/w8a8_int8.py +8 -7
- sglang/srt/layers/radix_attention.py +28 -1
- sglang/srt/layers/rotary_embedding.py +15 -3
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +255 -97
- sglang/srt/managers/mm_utils.py +7 -5
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +64 -25
- sglang/srt/managers/scheduler.py +80 -82
- sglang/srt/managers/tokenizer_manager.py +18 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +21 -3
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +9 -6
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +67 -35
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +494 -366
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +6 -5
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +30 -200
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +5 -1
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +15 -13
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +55 -19
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +10 -9
- sglang/srt/utils.py +136 -10
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +224 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/disaggregation/conn.py +0 -81
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json
ADDED
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    }
+}
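For context, each of the added fused-MoE config files maps a token count M (the JSON keys above: 1, 2, 4, ..., 4096) to a tuned Triton tile shape (BLOCK_SIZE_M/N/K, GROUP_SIZE_M, num_warps, num_stages) for the named GPU and dtype. A minimal sketch of how such a file could be consumed, assuming a nearest-key lookup like the one get_moe_configs performs in fused_moe.py (the load_moe_config helper below is hypothetical, for illustration only):

import json

def load_moe_config(path: str, num_tokens: int) -> dict:
    # Keys are the batch sizes the kernel was tuned for; pick the closest one.
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    nearest = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[nearest]

# A batch of 100 tokens would pick up the "96" entry from the file above.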
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
CHANGED
@@ -13,6 +13,7 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
@@ -22,28 +23,25 @@ from sglang.srt.utils import (
 )
 
 _is_hip = is_hip()
-
-
-logger = logging.getLogger(__name__)
-padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
-
-enable_moe_align_block_size_triton = bool(
-    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
-)
-
 _is_cuda = is_cuda()
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
-
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
 else:
     from vllm import _custom_ops as vllm_ops
+    from vllm._custom_ops import scaled_fp8_quant
 
 if _is_cuda or _is_hip:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 
 
+logger = logging.getLogger(__name__)
+padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
+enable_moe_align_block_size_triton = bool(
+    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
+)
+
+
 @triton.jit
 def write_zeros_to_output(
     c_ptr,
@@ -342,6 +340,7 @@ def fused_moe_kernel(
     use_fp8_w8a8: tl.constexpr,
     use_int8_w8a8: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
+    per_channel_quant: tl.constexpr,
     even_Ks: tl.constexpr,
 ):
     """
@@ -416,20 +415,7 @@ def fused_moe_kernel(
         )
         b_scale = tl.load(b_scale_ptrs)
 
-    if use_fp8_w8a8:
-        # block-wise
-        if group_k > 0 and group_n > 0:
-            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
-            offs_bsn = offs_bn // group_n
-            b_scale_ptrs = (
-                b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
-            )
-        # tensor-wise
-        else:
-            a_scale = tl.load(a_scale_ptr)
-            b_scale = tl.load(b_scale_ptr + off_experts)
-
-    if use_int8_w8a8:
+    if use_fp8_w8a8 or use_int8_w8a8:
         # block-wise
         if group_k > 0 and group_n > 0:
             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
@@ -438,8 +424,7 @@ def fused_moe_kernel(
                 b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
             )
         # channel-wise
-
-        # Load per-column scale for weights
+        elif per_channel_quant:
             b_scale_ptrs = (
                 b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
             )
@@ -447,6 +432,10 @@ def fused_moe_kernel(
             # Load per-token scale for activations
             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
             a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
+        # tensor-wise
+        else:
+            a_scale = tl.load(a_scale_ptr)
+            b_scale = tl.load(b_scale_ptr + off_experts)
 
     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
@@ -711,12 +700,12 @@ def moe_align_block_size(
             num_tokens_post_pad,
         )
     else:
-        token_cnts_buffer = torch.
+        token_cnts_buffer = torch.empty(
             (num_experts + 1) * num_experts,
             dtype=torch.int32,
             device=topk_ids.device,
         )
-        cumsum_buffer = torch.
+        cumsum_buffer = torch.empty(
             num_experts + 1, dtype=torch.int32, device=topk_ids.device
         )
 
@@ -753,6 +742,7 @@ def invoke_fused_moe_kernel(
     use_int8_w8a8: bool,
     use_int8_w8a16: bool,
     use_int4_w4a16: bool,
+    per_channel_quant: bool,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
 ) -> None:
@@ -765,6 +755,8 @@ def invoke_fused_moe_kernel(
         from sglang.srt.layers.quantization.fp8_kernel import (
             sglang_per_token_group_quant_fp8,
         )
+    else:
+        from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -775,10 +767,10 @@ def invoke_fused_moe_kernel(
         if block_shape is None:
             # activation tensor-wise fp8 quantization, dynamic or static
             padded_size = padding_size
-
-
-
-
+            # activations apply per-token quantization when weights apply per-channel quantization by default
+            A, A_scale = scaled_fp8_quant(
+                A, A_scale, use_per_token_if_dynamic=per_channel_quant
+            )
         else:
             # activation block-wise fp8 quantization
             assert len(block_shape) == 2
@@ -794,6 +786,9 @@ def invoke_fused_moe_kernel(
         assert B_scale is not None
         if block_shape is None:
             # activation channel-wise int8 quantization
+            assert (
+                per_channel_quant
+            ), "int8 quantization only supports channel-wise quantization except for block-wise quantization"
             A, A_scale = per_token_quant_int8(A)
         else:
             # activation block-wise int8 quantization
@@ -902,6 +897,7 @@ def invoke_fused_moe_kernel(
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
+        per_channel_quant=per_channel_quant,
         even_Ks=even_Ks,
         **config,
     )
@@ -953,7 +949,7 @@ def get_moe_configs(
     logger.warning(
         (
             "Using default MoE config. Performance might be sub-optimal! "
-            "Config file not found at %s"
+            "Config file not found at %s, you can tune the config with https://github.com/sgl-project/sglang/blob/main/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py."
         ),
         config_file_path,
     )
@@ -1084,6 +1080,7 @@ def inplace_fused_experts(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1105,6 +1102,7 @@ def inplace_fused_experts(
         use_int8_w8a8,
         use_int8_w8a16,
         use_int4_w4a16,
+        per_channel_quant,
         w1_scale,
         w2_scale,
         w1_zp,
@@ -1127,6 +1125,7 @@ def inplace_fused_experts_fake(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1158,6 +1157,7 @@ def outplace_fused_experts(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1180,6 +1180,7 @@ def outplace_fused_experts(
         use_int8_w8a8,
         use_int8_w8a16,
         use_int4_w4a16,
+        per_channel_quant,
        w1_scale,
        w2_scale,
        w1_zp,
@@ -1203,6 +1204,7 @@ def outplace_fused_experts_fake(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1236,6 +1238,7 @@ def fused_experts(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1259,6 +1262,7 @@ def fused_experts(
             use_int8_w8a8,
             use_int8_w8a16,
             use_int4_w4a16,
+            per_channel_quant,
             w1_scale,
             w2_scale,
             w1_zp,
@@ -1281,6 +1285,7 @@ def fused_experts(
             use_int8_w8a8,
             use_int8_w8a16,
             use_int4_w4a16,
+            per_channel_quant,
             w1_scale,
             w2_scale,
             w1_zp,
@@ -1305,6 +1310,7 @@ def fused_experts_impl(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1441,6 +1447,7 @@ def fused_experts_impl(
             use_int8_w8a8=use_int8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
             use_int4_w4a16=use_int4_w4a16,
+            per_channel_quant=per_channel_quant,
             block_shape=block_shape,
         )
         if activation == "silu":
@@ -1484,6 +1491,7 @@ def fused_experts_impl(
             use_int8_w8a8=use_int8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
             use_int4_w4a16=use_int4_w4a16,
+            per_channel_quant=per_channel_quant,
             block_shape=block_shape,
         )
 
@@ -1530,6 +1538,7 @@ def fused_moe(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1538,6 +1547,7 @@ def fused_moe(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1592,6 +1602,7 @@ def fused_moe(
         topk_group=topk_group,
         num_expert_group=num_expert_group,
         custom_routing_function=custom_routing_function,
+        routed_scaling_factor=routed_scaling_factor,
     )
 
     return fused_experts(
@@ -1606,6 +1617,7 @@ def fused_moe(
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         use_int4_w4a16=use_int4_w4a16,
+        per_channel_quant=per_channel_quant,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
         w1_zp=w1_zp,
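The per_channel_quant flag threaded through the kernels above distinguishes channel-wise weight scales (one scale per output column, paired with per-token activation scales) from the tensor-wise fallback. A rough illustration of the two scale shapes for fp8 e4m3 (max representable value 448); the quant_scales helper is hypothetical and only stands in for what scaled_fp8_quant computes internally:

import torch

FP8_E4M3_MAX = 448.0

def quant_scales(w: torch.Tensor, per_channel: bool) -> torch.Tensor:
    # Tensor-wise: a single scale for the whole tensor.
    # Channel-wise: one scale per output column (reduce over dim 0).
    if per_channel:
        return w.abs().amax(dim=0, keepdim=True) / FP8_E4M3_MAX
    return w.abs().amax() / FP8_E4M3_MAX

w = torch.randn(256, 512)
print(quant_scales(w, per_channel=False).shape)  # torch.Size([])
print(quant_scales(w, per_channel=True).shape)   # torch.Size([1, 512])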
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -131,6 +131,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         return self.forward(
             x=x,
@@ -147,6 +148,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             apply_router_weight_on_input=apply_router_weight_on_input,
             inplace=inplace,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
     def forward_cuda(
@@ -165,6 +167,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -176,6 +179,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         if _is_hip and get_bool_env_var("CK_MOE"):
@@ -284,6 +288,7 @@ class FusedMoE(torch.nn.Module):
         use_presharded_weights: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ):
         super().__init__()
 
@@ -293,6 +298,7 @@ class FusedMoE(torch.nn.Module):
         self.tp_size = (
             tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
         )
+        self.routed_scaling_factor = routed_scaling_factor
         self.top_k = top_k
         self.num_experts = num_experts
         assert intermediate_size % self.tp_size == 0
@@ -637,6 +643,7 @@ class FusedMoE(torch.nn.Module):
             correction_bias=self.correction_bias,
             activation=self.activation,
             apply_router_weight_on_input=self.apply_router_weight_on_input,
+            routed_scaling_factor=self.routed_scaling_factor,
         )
 
         if self.reduce_results and self.tp_size > 1:
sglang/srt/layers/moe/router.py
CHANGED
@@ -5,6 +5,9 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.moe.topk import fused_topk
+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()
 
 
 @triton.jit
@@ -116,10 +119,13 @@ def fused_moe_router_impl(
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
 
     grid = lambda meta: (bs,)
+
+    min_num_warps = 16 if _is_hip else 32
+
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)),
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
         ),
     }
 
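The router change above makes the cap on the Triton launch's num_warps hardware-dependent: 16 on ROCm and 32 on CUDA, with a floor of 4. Restated as a small standalone sketch of the heuristic from the hunk:

import triton

def router_num_warps(hidden_dim: int, is_hip: bool) -> int:
    # Roughly one warp per 256 hidden features, capped at 16 (ROCm) or 32 (CUDA),
    # never fewer than 4 warps.
    min_num_warps = 16 if is_hip else 32
    return max(
        min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
    )

# e.g. hidden_dim=4096 -> 16 warps on both; hidden_dim=16384 -> 32 on CUDA, 16 on ROCm.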