sglang 0.5.0rc2__py3-none-any.whl → 0.5.1.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -0
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +375 -51
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 2
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 2
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "256": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 3
+    }
+}
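The two hunks above are new Triton tuning tables for the fused MoE kernel (fused_moe_triton/configs/triton_3_4_0/...): each top-level key is a benchmarked token count M, and each entry records the tile sizes and pipelining settings (BLOCK_SIZE_M/N/K, GROUP_SIZE_M, num_warps, num_stages) chosen for that M. As a rough illustration of how such a table is typically consumed, here is a minimal sketch; load_moe_config and pick_config are hypothetical helpers, and nearest-M selection is an assumption about the lookup strategy, not code taken from this diff:

import json

def load_moe_config(path: str) -> dict:
    # Keys are token counts stored as strings; convert to int for lookup.
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}

def pick_config(configs: dict, num_tokens: int) -> dict:
    # Use the entry benchmarked at the M closest to the current batch size.
    best_m = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[best_m]

# e.g. a 20-token batch would map to the "16" or "24" entry above,
# whichever is numerically closer.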
@@ -2,17 +2,20 @@
 
 """Fused MoE kernel."""
 
+from __future__ import annotations
+
 import functools
 import json
 import logging
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.layers.moe.
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import StandardTopKOutput
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     scaled_fp8_quant,
@@ -46,13 +49,15 @@ if _is_cuda:
 elif _is_cpu and _is_cpu_amx_available:
     pass
 elif _is_hip:
-    from
+    from sgl_kernel import gelu_and_mul, silu_and_mul
 
     if _use_aiter:
         try:
             from aiter import moe_sum
         except ImportError:
             raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
+    else:
+        from vllm import _custom_ops as vllm_ops
 
 
 if _is_cuda or _is_hip:
@@ -1025,8 +1030,8 @@ def inplace_fused_experts(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     routed_scaling_factor: Optional[float] = None,
-
-
+    gemm1_alpha: Optional[float] = None,
+    gemm1_limit: Optional[float] = None,
 ) -> None:
     fused_experts_impl(
         hidden_states,
@@ -1053,8 +1058,8 @@ def inplace_fused_experts(
         block_shape,
         False,
         routed_scaling_factor,
-
-
+        gemm1_alpha,
+        gemm1_limit,
     )
 
 
@@ -1081,8 +1086,8 @@ def inplace_fused_experts_fake(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     routed_scaling_factor: Optional[float] = None,
-
-
+    gemm1_alpha: Optional[float] = None,
+    gemm1_limit: Optional[float] = None,
 ) -> None:
     pass
 
@@ -1119,8 +1124,8 @@ def outplace_fused_experts(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
-
-
+    gemm1_alpha: Optional[float] = None,
+    gemm1_limit: Optional[float] = None,
 ) -> torch.Tensor:
     return fused_experts_impl(
         hidden_states,
@@ -1147,8 +1152,8 @@ def outplace_fused_experts(
         block_shape,
         no_combine=no_combine,
         routed_scaling_factor=routed_scaling_factor,
-
-
+        gemm1_alpha=gemm1_alpha,
+        gemm1_limit=gemm1_limit,
     )
 
 
@@ -1176,8 +1181,8 @@ def outplace_fused_experts_fake(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
-
-
+    gemm1_alpha: Optional[float] = None,
+    gemm1_limit: Optional[float] = None,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)
 
@@ -1194,12 +1199,10 @@ def fused_experts(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
-    topk_output:
+    topk_output: StandardTopKOutput,
+    moe_runner_config: MoeRunnerConfig,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    inplace: bool = False,
-    activation: str = "silu",
-    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
@@ -1212,14 +1215,10 @@ def fused_experts(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
-    no_combine: bool = False,
-    routed_scaling_factor: Optional[float] = None,
-    activation_alpha: Optional[float] = None,
-    swiglu_limit: Optional[float] = None,
 ):
     topk_weights, topk_ids, _ = topk_output
-    if inplace:
-        assert not no_combine, "no combine + inplace makes no sense"
+    if moe_runner_config.inplace:
+        assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
         torch.ops.sglang.inplace_fused_experts(
             hidden_states,
             w1,
@@ -1228,8 +1227,8 @@ def fused_experts(
             topk_ids,
             b1,
             b2,
-            activation,
-            apply_router_weight_on_input,
+            moe_runner_config.activation,
+            moe_runner_config.apply_router_weight_on_input,
             use_fp8_w8a8,
             use_int8_w8a8,
             use_int8_w8a16,
@@ -1242,9 +1241,9 @@ def fused_experts(
             a1_scale,
             a2_scale,
             block_shape,
-            routed_scaling_factor,
-
-
+            moe_runner_config.routed_scaling_factor,
+            moe_runner_config.gemm1_alpha,
+            moe_runner_config.gemm1_clamp_limit,
         )
         return hidden_states
     else:
@@ -1256,8 +1255,8 @@ def fused_experts(
             topk_ids,
             b1,
             b2,
-            activation,
-            apply_router_weight_on_input,
+            moe_runner_config.activation,
+            moe_runner_config.apply_router_weight_on_input,
             use_fp8_w8a8,
             use_int8_w8a8,
             use_int8_w8a16,
@@ -1270,10 +1269,10 @@ def fused_experts(
             a1_scale,
             a2_scale,
             block_shape,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
-
-
+            no_combine=moe_runner_config.no_combine,
+            routed_scaling_factor=moe_runner_config.routed_scaling_factor,
+            gemm1_alpha=moe_runner_config.gemm1_alpha,
+            gemm1_limit=moe_runner_config.gemm1_clamp_limit,
         )
 
 
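The fused_experts hunks above drop the loose inplace / activation / apply_router_weight_on_input / no_combine / routed_scaling_factor / activation_alpha / swiglu_limit arguments and instead read everything from a single moe_runner_config object. A minimal sketch of a config carrying just the fields these hunks reference follows; the real MoeRunnerConfig lives under sglang/srt/layers/moe/moe_runner/ and may hold more, and the defaults below simply mirror the old keyword defaults:

from dataclasses import dataclass
from typing import Optional

@dataclass
class MoeRunnerConfigSketch:
    # Fields that fused_experts reads off moe_runner_config in the hunks above.
    activation: str = "silu"
    apply_router_weight_on_input: bool = False
    inplace: bool = False
    no_combine: bool = False
    routed_scaling_factor: Optional[float] = None
    gemm1_alpha: Optional[float] = None
    gemm1_clamp_limit: Optional[float] = None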
@@ -1370,11 +1369,11 @@ def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
 
 
 @torch.compile
-def swiglu_with_alpha_and_limit(x,
+def swiglu_with_alpha_and_limit(x, gemm1_alpha, gemm1_limit):
     gate, up = x[..., ::2], x[..., 1::2]
-    gate = gate.clamp(min=None, max=
-    up = up.clamp(min=-
-    return gate * torch.sigmoid(gate *
+    gate = gate.clamp(min=None, max=gemm1_limit)
+    up = up.clamp(min=-gemm1_limit, max=gemm1_limit)
+    return gate * torch.sigmoid(gate * gemm1_alpha) * (up + 1)
 
 
 def fused_experts_impl(
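The swiglu_with_alpha_and_limit helper above is the clamped SwiGLU used when gemm1_alpha / gemm1_limit are set: gate and up channels are de-interleaved from the last dimension, clamped to the limit, and combined as gate * sigmoid(gate * alpha) * (up + 1). A standalone copy, runnable without the rest of the module, showing the shape contract (last dim 2*d in, d out); the alpha/limit values passed here are illustrative only:

import torch

def swiglu_with_alpha_and_limit(x, gemm1_alpha, gemm1_limit):
    # Same body as the patched helper, minus the @torch.compile decorator.
    gate, up = x[..., ::2], x[..., 1::2]
    gate = gate.clamp(min=None, max=gemm1_limit)
    up = up.clamp(min=-gemm1_limit, max=gemm1_limit)
    return gate * torch.sigmoid(gate * gemm1_alpha) * (up + 1)

x = torch.randn(4, 8)  # last dim = 2 * hidden
y = swiglu_with_alpha_and_limit(x, gemm1_alpha=1.702, gemm1_limit=7.0)
print(y.shape)  # torch.Size([4, 4])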
@@ -1402,8 +1401,8 @@ def fused_experts_impl(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
-
-
+    gemm1_alpha: Optional[float] = None,
+    gemm1_limit: Optional[float] = None,
 ):
     padded_size = padding_size
     if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
@@ -1533,25 +1532,23 @@ def fused_experts_impl(
             block_shape=block_shape,
         )
         if activation == "silu":
-            if
-                assert
+            if gemm1_alpha is not None:
+                assert gemm1_limit is not None
                 intermediate_cache2 = swiglu_with_alpha_and_limit(
                     intermediate_cache1.view(-1, N),
-
-
+                    gemm1_alpha,
+                    gemm1_limit,
                 )
-            elif _is_cuda:
+            elif _is_cuda or _is_hip:
                 silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
                 vllm_ops.silu_and_mul(
                     intermediate_cache2, intermediate_cache1.view(-1, N)
                 )
         elif activation == "gelu":
-            assert
-
-
-            assert swiglu_limit is None, "swiglu_limit is not supported for gelu"
-            if _is_cuda:
+            assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu"
+            assert gemm1_limit is None, "gemm1_limit is not supported for gelu"
+            if _is_cuda or _is_hip:
                 gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
                 vllm_ops.gelu_and_mul(
@@ -1624,10 +1621,19 @@ def fused_experts_impl(
                     out_hidden_states[begin_chunk_idx:end_chunk_idx],
                 )
             else:
-
-
-
-
+                # According to micro benchmark results, torch.compile can get better performance for small token.
+                if tokens_in_chunk <= 32:
+                    moe_sum_reduce_torch_compile(
+                        intermediate_cache3.view(*intermediate_cache3.shape),
+                        out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                        routed_scaling_factor,
+                    )
+                else:
+                    moe_sum_reduce_triton(
+                        intermediate_cache3.view(*intermediate_cache3.shape),
+                        out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                        routed_scaling_factor,
+                    )
         else:
             vllm_ops.moe_sum(
                 intermediate_cache3.view(*intermediate_cache3.shape),
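The hunk above replaces the removed combine code with a size-based dispatch: chunks of at most 32 tokens use the torch.compile reduction, larger chunks use the Triton kernel, per the micro-benchmark comment in the diff. The body of moe_sum_reduce_torch_compile is not shown in this diff; the sketch below is a guess at what such a helper does, matching only the signature visible in the hunk header above:

import torch

@torch.compile(dynamic=True)
def moe_sum_reduce_torch_compile_sketch(x, out, routed_scaling_factor):
    # Sum the per-expert partial outputs for each token (dim 1) into `out`,
    # then apply the routed scaling factor. Assumed body; only the call
    # sites and the signature appear in this diff.
    torch.sum(x, dim=1, out=out)
    out.mul_(routed_scaling_factor)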
@@ -1641,12 +1647,10 @@ def fused_moe(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
-    topk_output:
+    topk_output: StandardTopKOutput,
+    moe_runner_config: MoeRunnerConfig = MoeRunnerConfig(),
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    inplace: bool = False,
-    activation: str = "silu",
-    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
@@ -1659,10 +1663,6 @@ def fused_moe(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
-    no_combine: bool = False,
-    routed_scaling_factor: Optional[float] = None,
-    activation_alpha: Optional[float] = None,
-    swiglu_limit: Optional[float] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1672,11 +1672,10 @@ def fused_moe(
     - hidden_states (torch.Tensor): The input tensor to the MoE layer.
     - w1 (torch.Tensor): The first set of expert weights.
     - w2 (torch.Tensor): The second set of expert weights.
-    - topk_output (
+    - topk_output (StandardTopKOutput): The top-k output of the experts.
+    - moe_runner_config (MoeRunnerConfig): The configuration for the MoE runner.
     - b1 (Optional[torch.Tensor]): Optional bias for w1.
     - b2 (Optional[torch.Tensor]): Optional bias for w2.
-    - inplace (bool): If True, perform the operation in-place.
-        Defaults to False.
     - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
         products for w1 and w2. Defaults to False.
     - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
@@ -1696,9 +1695,9 @@ def fused_moe(
         a2.
     - block_shape: (Optional[List[int]]): Optional block size for block-wise
         quantization.
-    -
+    - gemm1_alpha (Optional[float]): Optional gemm1_alpha for the activation
         function.
-    -
+    - gemm1_limit (Optional[float]): Optional gemm1_limit for the swiglu activation
         function.
 
     Returns:
@@ -1710,11 +1709,9 @@ def fused_moe(
         w1,
         w2,
         topk_output,
+        moe_runner_config=moe_runner_config,
         b1=b1,
         b2=b2,
-        inplace=inplace,
-        activation=activation,
-        apply_router_weight_on_input=apply_router_weight_on_input,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
@@ -1727,8 +1724,4 @@ def fused_moe(
         a1_scale=a1_scale,
         a2_scale=a2_scale,
         block_shape=block_shape,
-        no_combine=no_combine,
-        routed_scaling_factor=routed_scaling_factor,
-        activation_alpha=activation_alpha,
-        swiglu_limit=swiglu_limit,
     )
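For callers, the net effect of the fused_moe hunks is that per-call knobs move out of the signature and into the config object. The following is a hypothetical migration sketch: the field names come from the attribute reads earlier in this diff, and passing them as constructor keywords is an assumption about MoeRunnerConfig, not something this diff shows:

from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig

# 0.5.0rc2 style (arguments now removed from fused_moe):
#   fused_moe(hidden_states, w1, w2, topk_output,
#             inplace=True, activation="silu", routed_scaling_factor=1.0)
#
# 0.5.1 style: bundle the same knobs into the config and pass it through.
cfg = MoeRunnerConfig(
    inplace=True,
    activation="silu",
    apply_router_weight_on_input=False,
    routed_scaling_factor=1.0,
)
# fused_moe(hidden_states, w1, w2, topk_output, moe_runner_config=cfg, ...)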