sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
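Each key in these tuning tables is a token-batch size M, and each value is the Triton launch configuration benchmarked as fastest for that size. A minimal sketch of how such a table is typically consumed, picking the entry whose key is closest to the actual batch size; the helper names here are illustrative, not sglang's exact API:

```python
import json

def load_moe_config(path: str) -> dict[int, dict]:
    # Keys are serialized as strings in the JSON file; convert back to int.
    with open(path) as f:
        return {int(num_tokens): cfg for num_tokens, cfg in json.load(f).items()}

def pick_kernel_config(configs: dict[int, dict], m: int) -> dict:
    # Nearest-key heuristic commonly used by Triton fused-MoE tuners:
    # choose the benchmarked batch size closest to the runtime M.
    return configs[min(configs.keys(), key=lambda k: abs(k - m))]

# Example: a batch of 100 tokens selects the "96" entry above.
# cfg = pick_kernel_config(load_moe_config("E=160,N=320,....json"), 100)
```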
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
@@ -6,13 +6,13 @@ import functools
 import json
 import logging
 import os
-from typing import Any,
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.layers.moe.topk import
+from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     scaled_fp8_quant,
@@ -39,14 +39,21 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
 elif _is_cpu and _is_cpu_amx_available:
     pass
-
-from vllm import _custom_ops as vllm_ops
-
+elif _is_hip:
+    from vllm import _custom_ops as vllm_ops  # gelu_and_mul, silu_and_mul
+
+    if _use_aiter:
+        try:
+            from aiter import moe_sum
+        except ImportError:
+            raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
+
 
 if _is_cuda or _is_hip:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
@@ -54,9 +61,6 @@ if _is_cuda or _is_hip:
 
 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
-enable_moe_align_block_size_triton = bool(
-    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
-)
 
 
 @triton.jit
@@ -524,190 +528,6 @@ def fused_moe_kernel(
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
 
-@triton.jit
-def moe_align_block_size_stage1(
-    topk_ids_ptr,
-    tokens_cnts_ptr,
-    num_experts: tl.constexpr,
-    numel: tl.constexpr,
-    tokens_per_thread: tl.constexpr,
-):
-    pid = tl.program_id(0)
-
-    start_idx = pid * tokens_per_thread
-
-    off_c = (pid + 1) * num_experts
-
-    for i in range(tokens_per_thread):
-        if start_idx + i < numel:
-            idx = tl.load(topk_ids_ptr + start_idx + i)
-            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
-            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
-
-
-@triton.jit
-def moe_align_block_size_stage2(
-    tokens_cnts_ptr,
-    num_experts: tl.constexpr,
-):
-    pid = tl.program_id(0)
-
-    last_cnt = 0
-    for i in range(1, num_experts + 1):
-        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
-        last_cnt = last_cnt + token_cnt
-        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
-
-
-@triton.jit
-def moe_align_block_size_stage3(
-    total_tokens_post_pad_ptr,
-    tokens_cnts_ptr,
-    cumsum_ptr,
-    num_experts: tl.constexpr,
-    block_size: tl.constexpr,
-):
-    last_cumsum = 0
-    off_cnt = num_experts * num_experts
-    for i in range(1, num_experts + 1):
-        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
-        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
-        tl.store(cumsum_ptr + i, last_cumsum)
-    tl.store(total_tokens_post_pad_ptr, last_cumsum)
-
-
-@triton.jit
-def moe_align_block_size_stage4(
-    topk_ids_ptr,
-    sorted_token_ids_ptr,
-    expert_ids_ptr,
-    tokens_cnts_ptr,
-    cumsum_ptr,
-    num_experts: tl.constexpr,
-    block_size: tl.constexpr,
-    numel: tl.constexpr,
-    tokens_per_thread: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    start_idx = tl.load(cumsum_ptr + pid)
-    end_idx = tl.load(cumsum_ptr + pid + 1)
-
-    for i in range(start_idx, end_idx, block_size):
-        tl.store(expert_ids_ptr + i // block_size, pid)
-
-    start_idx = pid * tokens_per_thread
-    off_t = pid * num_experts
-
-    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
-        expert_id = tl.load(topk_ids_ptr + i)
-        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
-        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
-        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
-        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
-
-
-def moe_align_block_size_triton(
-    topk_ids: torch.Tensor,
-    num_experts: int,
-    block_size: int,
-    sorted_token_ids: torch.Tensor,
-    expert_ids: torch.Tensor,
-    num_tokens_post_pad: torch.Tensor,
-) -> None:
-    numel = topk_ids.numel()
-    grid = (num_experts,)
-    tokens_cnts = torch.zeros(
-        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
-    )
-    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
-    tokens_per_thread = ceil_div(numel, num_experts)
-
-    moe_align_block_size_stage1[grid](
-        topk_ids,
-        tokens_cnts,
-        num_experts,
-        numel,
-        tokens_per_thread,
-    )
-    moe_align_block_size_stage2[grid](
-        tokens_cnts,
-        num_experts,
-    )
-    moe_align_block_size_stage3[(1,)](
-        num_tokens_post_pad,
-        tokens_cnts,
-        cumsum,
-        num_experts,
-        block_size,
-    )
-    moe_align_block_size_stage4[grid](
-        topk_ids,
-        sorted_token_ids,
-        expert_ids,
-        tokens_cnts,
-        cumsum,
-        num_experts,
-        block_size,
-        numel,
-        tokens_per_thread,
-    )
-
-
-@triton.jit
-def init_sorted_ids_and_cumsum_buffer_kernel(
-    sorted_ids_ptr,
-    cumsum_buffer_ptr,
-    max_num_tokens_padded,
-    topk_ids_numel,
-    num_experts: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
-    ALIGNED_NUM_EXPERTS_P1: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-
-    sorted_ids_blocks = tl.cdiv(max_num_tokens_padded, BLOCK_SIZE)
-
-    if pid < sorted_ids_blocks:
-        mask = offsets < max_num_tokens_padded
-        tl.store(
-            sorted_ids_ptr + offsets,
-            tl.full((BLOCK_SIZE,), topk_ids_numel, dtype=tl.int32),
-            mask=mask,
-        )
-    elif pid == sorted_ids_blocks:
-        offset_e = tl.arange(0, ALIGNED_NUM_EXPERTS_P1)
-        mask_e = offset_e < num_experts + 1
-        tl.store(
-            cumsum_buffer_ptr + offset_e,
-            tl.zeros((ALIGNED_NUM_EXPERTS_P1,), dtype=tl.int32),
-            mask=mask_e,
-        )
-
-
-def init_sorted_ids_and_cumsum_buffer(
-    max_num_tokens_padded: int, topk_ids_numel: int, num_experts: int, device="cuda"
-):
-    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=device)
-    cumsum_buffer = torch.empty((num_experts + 1,), dtype=torch.int32, device=device)
-
-    BLOCK_SIZE = 1024
-    sorted_ids_blocks = triton.cdiv(max_num_tokens_padded, BLOCK_SIZE)
-    grid = (sorted_ids_blocks + 1,)
-
-    init_sorted_ids_and_cumsum_buffer_kernel[grid](
-        sorted_ids,
-        cumsum_buffer,
-        max_num_tokens_padded,
-        topk_ids_numel,
-        num_experts,
-        BLOCK_SIZE,
-        next_power_of_2(num_experts + 1),
-    )
-
-    return sorted_ids, cumsum_buffer
-
-
 def moe_align_block_size(
     topk_ids: torch.Tensor, block_size: int, num_experts: int
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
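For reference, the deleted Triton fallback implemented the same contract as the `sgl_moe_align_block_size` kernel that remains: sort token slots by the expert they route to and pad each expert's segment to a multiple of the block size. A pure-PyTorch sketch of that contract, assuming the usual sentinel-padding semantics (an illustration, not either kernel):

```python
import torch

def moe_align_block_size_reference(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    # Group token slots by expert, pad each expert's segment to a multiple
    # of block_size, and record which expert each block belongs to.
    flat = topk_ids.flatten()
    pad = flat.numel()  # sentinel id written into padding slots
    sorted_ids, expert_ids = [], []
    for e in range(num_experts):
        tok = torch.nonzero(flat == e, as_tuple=True)[0].tolist()
        n_blocks = (len(tok) + block_size - 1) // block_size
        sorted_ids += tok + [pad] * (n_blocks * block_size - len(tok))
        expert_ids += [e] * n_blocks
    num_tokens_post_pad = len(sorted_ids)
    return sorted_ids, expert_ids, num_tokens_post_pad
```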
@@ -752,42 +572,37 @@ def moe_align_block_size(
     sorted_ids = torch.empty(
         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
-    sorted_ids.fill_(topk_ids.numel())
-
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
     expert_ids = torch.empty(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    if enable_moe_align_block_size_triton:
-        moe_align_block_size_triton(
-            topk_ids,
-            num_experts,
-            block_size,
-            sorted_ids,
-            expert_ids,
-            num_tokens_post_pad,
-        )
-    else:
-        cumsum_buffer = torch.empty(
-            (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
-        )
-        token_cnts_buffer = torch.empty(
-            (num_experts + 1) * num_experts,
-            dtype=torch.int32,
-            device=topk_ids.device,
-        )
 
-
-
-
-
-
-
-
-
-
+    cumsum_buffer = torch.empty(
+        (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
+    )
+    token_cnts_buffer = torch.empty(
+        (num_experts + 1) * num_experts,
+        dtype=torch.int32,
+        device=topk_ids.device,
+    )
+
+    # Threshold based on benchmark results
+    fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
+    if not fuse_sorted_ids_padding:
+        sorted_ids.fill_(topk_ids.numel())
+
+    sgl_moe_align_block_size(
+        topk_ids,
+        num_experts,
+        block_size,
+        sorted_ids,
+        expert_ids,
+        num_tokens_post_pad,
+        token_cnts_buffer,
+        cumsum_buffer,
+        fuse_sorted_ids_padding,
+    )
     return sorted_ids, expert_ids, num_tokens_post_pad
 
 
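The rewritten body drops the unconditional host-side `fill_` of `sorted_ids`: for small buffers the sentinel padding is now fused into `sgl_moe_align_block_size` itself. A quick arithmetic sketch, assuming the usual `max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)` sizing, which is not shown in this hunk:

```python
# Hypothetical sizes: 512 tokens routed top-2 across 8 experts, block_size 16.
numel, num_experts, block_size = 512 * 2, 8, 16
max_num_tokens_padded = numel + num_experts * (block_size - 1)  # 1024 + 120 = 1144
fuse_sorted_ids_padding = max_num_tokens_padded <= 4096  # True: skip the separate fill_
```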
@@ -1328,8 +1143,7 @@ def fused_experts(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
-
-    topk_ids: torch.Tensor,
+    topk_output: TopKOutput,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
@@ -1348,7 +1162,7 @@ def fused_experts(
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
 ):
-
+    topk_weights, topk_ids, _ = topk_output
     if inplace:
         assert not no_combine, "no combine + inplace makes no sense"
         torch.ops.sglang.inplace_fused_experts(
@@ -1517,11 +1331,7 @@ def fused_experts_impl(
     routed_scaling_factor: Optional[float] = None,
 ):
     padded_size = padding_size
-    if (
-        not (use_fp8_w8a8 or use_int8_w8a8)
-        or block_shape is not None
-        or (_is_hip and get_bool_env_var("SGLANG_USE_AITER"))
-    ):
+    if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
         padded_size = 0
 
     # Check constraints.
@@ -1719,6 +1529,17 @@ def fused_experts_impl(
                 out_hidden_states[begin_chunk_idx:end_chunk_idx],
                 routed_scaling_factor,
             )
+        elif _is_hip:
+            if _use_aiter:
+                moe_sum(
+                    intermediate_cache3.view(*intermediate_cache3.shape),
+                    out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                )
+            else:
+                vllm_ops.moe_sum(
+                    intermediate_cache3.view(*intermediate_cache3.shape),
+                    out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                )
         else:
             vllm_ops.moe_sum(
                 intermediate_cache3.view(*intermediate_cache3.shape),
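Both the aiter and vLLM branches compute the same reduction: the top-k per-expert partial outputs for each token are summed into the final hidden states. A plain-PyTorch statement of that contract (a sketch, not either library's kernel):

```python
import torch

def moe_sum_reference(intermediate_cache3: torch.Tensor, out: torch.Tensor) -> None:
    # intermediate_cache3: [num_tokens, topk, hidden_size] per-expert partial outputs
    # out: [num_tokens, hidden_size] combined MoE output
    torch.sum(intermediate_cache3, dim=1, out=out)
```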
@@ -1732,17 +1553,10 @@ def fused_moe(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
-
-    topk: int,
-    renormalize: bool,
+    topk_output: TopKOutput,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
-    use_grouped_topk: bool = False,
-    num_expert_group: Optional[int] = None,
-    num_fused_shared_experts: int = 0,
-    topk_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
@@ -1766,16 +1580,9 @@ def fused_moe(
     - hidden_states (torch.Tensor): The input tensor to the MoE layer.
     - w1 (torch.Tensor): The first set of expert weights.
     - w2 (torch.Tensor): The second set of expert weights.
-    -
-        (before softmax).
-    - topk (int): The number of top-k experts to select.
-    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
+    - topk_output (TopKOutput): The top-k output of the experts.
     - inplace (bool): If True, perform the operation in-place.
         Defaults to False.
-    - num_expert_group: Optional[int]: additional parameter for grouped_topk
-    - topk_group: Optional[int]: additional parameter for grouped_topk
-    - use_grouped_topk: If True, use grouped_topk instead of fused_topk
-        note: Deepseek V2/V3/R1 series models use grouped_topk
     - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
         products for w1 and w2. Defaults to False.
     - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
@@ -1799,28 +1606,12 @@ def fused_moe(
     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
     """
-    # Check constraints.
-    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
-
-    topk_weights, topk_ids = select_experts(
-        hidden_states=hidden_states,
-        router_logits=gating_output,
-        use_grouped_topk=use_grouped_topk,
-        top_k=topk,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-        num_fused_shared_experts=num_fused_shared_experts,
-        custom_routing_function=custom_routing_function,
-        routed_scaling_factor=routed_scaling_factor,
-    )
 
     return fused_experts(
         hidden_states,
         w1,
         w2,
-
-        topk_ids,
+        topk_output,
         inplace=inplace,
         activation=activation,
         apply_router_weight_on_input=apply_router_weight_on_input,