sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +5 -1
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +17 -2
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +65 -20
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +70 -15
- sglang/srt/entrypoints/engine.py +5 -9
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +148 -72
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +105 -66
- sglang/srt/function_call/function_call_parser.py +6 -4
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
- sglang/srt/layers/activation.py +11 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +46 -25
- sglang/srt/layers/moe/ep_moe/layer.py +172 -206
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/topk.py +88 -34
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +33 -14
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/utils.py +0 -9
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/lora/lora_manager.py +133 -169
- sglang/srt/lora/lora_registry.py +188 -0
- sglang/srt/lora/mem_pool.py +2 -2
- sglang/srt/managers/cache_controller.py +62 -13
- sglang/srt/managers/io_struct.py +19 -1
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +27 -11
- sglang/srt/managers/scheduler.py +48 -26
- sglang/srt/managers/tokenizer_manager.py +62 -28
- sglang/srt/managers/tp_worker.py +5 -4
- sglang/srt/mem_cache/allocator.py +67 -7
- sglang/srt/mem_cache/hicache_storage.py +17 -1
- sglang/srt/mem_cache/hiradix_cache.py +35 -18
- sglang/srt/mem_cache/memory_pool_host.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +61 -25
- sglang/srt/model_executor/forward_batch_info.py +201 -29
- sglang/srt/model_executor/model_runner.py +109 -37
- sglang/srt/models/deepseek_v2.py +63 -30
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/mllama4.py +10 -3
- sglang/srt/models/qwen2_moe.py +2 -6
- sglang/srt/models/qwen3_moe.py +6 -8
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/reasoning_parser.py +48 -5
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/server_args.py +132 -60
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils.py +113 -69
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_activation.py +50 -1
- sglang/test/test_utils.py +65 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
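The hunks below reproduce the notable diffs. To verify or extend this comparison locally, a small script along the following lines works. This is a sketch, not part of sglang: it assumes pip is on PATH with network access to PyPI, and the helper names are ours.

import difflib
import pathlib
import subprocess
import zipfile

def fetch_wheel(version: str, dest: str) -> pathlib.Path:
    # Download a single wheel, without dependencies.
    subprocess.run(
        ["pip", "download", f"sglang=={version}", "--no-deps", "-d", dest],
        check=True,
    )
    return next(pathlib.Path(dest).glob("*.whl"))

def diff_member(old_whl: pathlib.Path, new_whl: pathlib.Path, member: str) -> str:
    # Unified diff of one file across the two wheels (wheels are zip archives).
    with zipfile.ZipFile(old_whl) as zo, zipfile.ZipFile(new_whl) as zn:
        old = zo.read(member).decode().splitlines(keepends=True)
        new = zn.read(member).decode().splitlines(keepends=True)
    return "".join(difflib.unified_diff(old, new, fromfile=member, tofile=member))

if __name__ == "__main__":
    old = fetch_wheel("0.4.9.post3", "old_whl")
    new = fetch_wheel("0.4.9.post5", "new_whl")
    print(diff_member(old, new, "sglang/version.py"))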
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json (new file)
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "256": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json (new file)
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    }
+}
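Both tables above follow the layout sglang uses for tuned fused-MoE Triton configs: each top-level key is a benchmarked token count (the GEMM M dimension) and each value is a set of kernel launch parameters. A minimal sketch of consuming such a table follows; the helper names and the nearest-bucket rule are illustrative assumptions, not sglang's exact lookup code.

import json

def load_moe_config(path: str) -> dict[int, dict]:
    # JSON keys are strings ("1", "2", ..., "4096"); index by integer token count.
    with open(path) as f:
        return {int(k): v for k, v in json.load(f).items()}

def pick_kernel_config(configs: dict[int, dict], num_tokens: int) -> dict:
    # Assumed nearest-bucket rule: use the benchmarked M closest to the
    # actual number of tokens in the batch.
    nearest = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[nearest]

# Example: for 100 tokens against the H20-3e table, the nearest bucket is "96",
# yielding BLOCK_SIZE_M=16, BLOCK_SIZE_N=128, BLOCK_SIZE_K=64, GROUP_SIZE_M=32.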
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -53,9 +53,7 @@ elif _is_hip:
         from aiter import moe_sum
     except ImportError:
         raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
-
-    from vllm import _custom_ops as vllm_ops
-    from vllm._custom_ops import scaled_fp8_quant
+
 
 if _is_cuda or _is_hip:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
@@ -63,9 +61,6 @@ if _is_cuda or _is_hip:
 
 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
-enable_moe_align_block_size_triton = bool(
-    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
-)
 
 
 @triton.jit
@@ -533,190 +528,6 @@ def fused_moe_kernel(
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
 
-@triton.jit
-def moe_align_block_size_stage1(
-    topk_ids_ptr,
-    tokens_cnts_ptr,
-    num_experts: tl.constexpr,
-    numel: tl.constexpr,
-    tokens_per_thread: tl.constexpr,
-):
-    pid = tl.program_id(0)
-
-    start_idx = pid * tokens_per_thread
-
-    off_c = (pid + 1) * num_experts
-
-    for i in range(tokens_per_thread):
-        if start_idx + i < numel:
-            idx = tl.load(topk_ids_ptr + start_idx + i)
-            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
-            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
-
-
-@triton.jit
-def moe_align_block_size_stage2(
-    tokens_cnts_ptr,
-    num_experts: tl.constexpr,
-):
-    pid = tl.program_id(0)
-
-    last_cnt = 0
-    for i in range(1, num_experts + 1):
-        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
-        last_cnt = last_cnt + token_cnt
-        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
-
-
-@triton.jit
-def moe_align_block_size_stage3(
-    total_tokens_post_pad_ptr,
-    tokens_cnts_ptr,
-    cumsum_ptr,
-    num_experts: tl.constexpr,
-    block_size: tl.constexpr,
-):
-    last_cumsum = 0
-    off_cnt = num_experts * num_experts
-    for i in range(1, num_experts + 1):
-        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
-        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
-        tl.store(cumsum_ptr + i, last_cumsum)
-    tl.store(total_tokens_post_pad_ptr, last_cumsum)
-
-
-@triton.jit
-def moe_align_block_size_stage4(
-    topk_ids_ptr,
-    sorted_token_ids_ptr,
-    expert_ids_ptr,
-    tokens_cnts_ptr,
-    cumsum_ptr,
-    num_experts: tl.constexpr,
-    block_size: tl.constexpr,
-    numel: tl.constexpr,
-    tokens_per_thread: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    start_idx = tl.load(cumsum_ptr + pid)
-    end_idx = tl.load(cumsum_ptr + pid + 1)
-
-    for i in range(start_idx, end_idx, block_size):
-        tl.store(expert_ids_ptr + i // block_size, pid)
-
-    start_idx = pid * tokens_per_thread
-    off_t = pid * num_experts
-
-    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
-        expert_id = tl.load(topk_ids_ptr + i)
-        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
-        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
-        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
-        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
-
-
-def moe_align_block_size_triton(
-    topk_ids: torch.Tensor,
-    num_experts: int,
-    block_size: int,
-    sorted_token_ids: torch.Tensor,
-    expert_ids: torch.Tensor,
-    num_tokens_post_pad: torch.Tensor,
-) -> None:
-    numel = topk_ids.numel()
-    grid = (num_experts,)
-    tokens_cnts = torch.zeros(
-        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
-    )
-    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
-    tokens_per_thread = ceil_div(numel, num_experts)
-
-    moe_align_block_size_stage1[grid](
-        topk_ids,
-        tokens_cnts,
-        num_experts,
-        numel,
-        tokens_per_thread,
-    )
-    moe_align_block_size_stage2[grid](
-        tokens_cnts,
-        num_experts,
-    )
-    moe_align_block_size_stage3[(1,)](
-        num_tokens_post_pad,
-        tokens_cnts,
-        cumsum,
-        num_experts,
-        block_size,
-    )
-    moe_align_block_size_stage4[grid](
-        topk_ids,
-        sorted_token_ids,
-        expert_ids,
-        tokens_cnts,
-        cumsum,
-        num_experts,
-        block_size,
-        numel,
-        tokens_per_thread,
-    )
-
-
-@triton.jit
-def init_sorted_ids_and_cumsum_buffer_kernel(
-    sorted_ids_ptr,
-    cumsum_buffer_ptr,
-    max_num_tokens_padded,
-    topk_ids_numel,
-    num_experts: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
-    ALIGNED_NUM_EXPERTS_P1: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-
-    sorted_ids_blocks = tl.cdiv(max_num_tokens_padded, BLOCK_SIZE)
-
-    if pid < sorted_ids_blocks:
-        mask = offsets < max_num_tokens_padded
-        tl.store(
-            sorted_ids_ptr + offsets,
-            tl.full((BLOCK_SIZE,), topk_ids_numel, dtype=tl.int32),
-            mask=mask,
-        )
-    elif pid == sorted_ids_blocks:
-        offset_e = tl.arange(0, ALIGNED_NUM_EXPERTS_P1)
-        mask_e = offset_e < num_experts + 1
-        tl.store(
-            cumsum_buffer_ptr + offset_e,
-            tl.zeros((ALIGNED_NUM_EXPERTS_P1,), dtype=tl.int32),
-            mask=mask_e,
-        )
-
-
-def init_sorted_ids_and_cumsum_buffer(
-    max_num_tokens_padded: int, topk_ids_numel: int, num_experts: int, device="cuda"
-):
-    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=device)
-    cumsum_buffer = torch.empty((num_experts + 1,), dtype=torch.int32, device=device)
-
-    BLOCK_SIZE = 1024
-    sorted_ids_blocks = triton.cdiv(max_num_tokens_padded, BLOCK_SIZE)
-    grid = (sorted_ids_blocks + 1,)
-
-    init_sorted_ids_and_cumsum_buffer_kernel[grid](
-        sorted_ids,
-        cumsum_buffer,
-        max_num_tokens_padded,
-        topk_ids_numel,
-        num_experts,
-        BLOCK_SIZE,
-        next_power_of_2(num_experts + 1),
-    )
-
-    return sorted_ids, cumsum_buffer
-
-
 def moe_align_block_size(
     topk_ids: torch.Tensor, block_size: int, num_experts: int
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -766,42 +577,32 @@ def moe_align_block_size(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    if enable_moe_align_block_size_triton:
-        sorted_ids.fill_(topk_ids.numel())
-        moe_align_block_size_triton(
-            topk_ids,
-            num_experts,
-            block_size,
-            sorted_ids,
-            expert_ids,
-            num_tokens_post_pad,
-        )
-    else:
-        cumsum_buffer = torch.empty(
-            (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
-        )
-        token_cnts_buffer = torch.empty(
-            (num_experts + 1) * num_experts,
-            dtype=torch.int32,
-            device=topk_ids.device,
-        )
 
-
-
-
-
+    cumsum_buffer = torch.empty(
+        (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
+    )
+    token_cnts_buffer = torch.empty(
+        (num_experts + 1) * num_experts,
+        dtype=torch.int32,
+        device=topk_ids.device,
+    )
 
-
-
-
-
-
-
-
-
-
-
+    # Threshold based on benchmark results
+    fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
+    if not fuse_sorted_ids_padding:
+        sorted_ids.fill_(topk_ids.numel())
+
+    sgl_moe_align_block_size(
+        topk_ids,
+        num_experts,
+        block_size,
+        sorted_ids,
+        expert_ids,
+        num_tokens_post_pad,
+        token_cnts_buffer,
+        cumsum_buffer,
+        fuse_sorted_ids_padding,
+    )
     return sorted_ids, expert_ids, num_tokens_post_pad
 
 
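Net effect of the fused_moe.py hunks: the pure-Triton alignment fallback (the four stage kernels, moe_align_block_size_triton, and the ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON switch) is removed, and moe_align_block_size now always calls the fused sgl_kernel op, skipping the host-side sorted_ids.fill_ when the buffer is small enough (<= 4096 slots) for the kernel to fuse the padding itself. For readers unfamiliar with the operation: it groups token slots by expert and pads each expert's segment to a multiple of block_size so the grouped GEMM can process whole blocks. Below is a slow, pure-PyTorch sketch of that invariant, written from the surrounding diff rather than taken from sglang.

import torch

def moe_align_block_size_reference(
    topk_ids: torch.Tensor, block_size: int, num_experts: int
):
    # Reference semantics only; the real work is a single fused CUDA kernel
    # (sgl_kernel.moe_align_block_size). Group token slots by expert, pad each
    # expert's segment up to a multiple of block_size, and record which expert
    # owns each block. Padding slots hold topk_ids.numel() as a sentinel so the
    # GEMM can mask them out.
    flat = topk_ids.flatten()
    sentinel = flat.numel()
    sorted_ids, expert_ids, total = [], [], 0
    for e in range(num_experts):
        slots = (flat == e).nonzero(as_tuple=True)[0].tolist()
        padded_len = -(-len(slots) // block_size) * block_size  # ceil to block
        sorted_ids += slots + [sentinel] * (padded_len - len(slots))
        expert_ids += [e] * (padded_len // block_size)
        total += padded_len
    return (
        torch.tensor(sorted_ids, dtype=torch.int32),
        torch.tensor(expert_ids, dtype=torch.int32),
        torch.tensor([total], dtype=torch.int32),
    )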