sglang 0.5.1.post2__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +79 -53
- sglang/bench_serving.py +186 -14
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +12 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/conversation.py +38 -5
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/launch_lb.py +0 -13
- sglang/srt/disaggregation/mini_lb.py +33 -8
- sglang/srt/disaggregation/prefill.py +1 -1
- sglang/srt/distributed/parallel_state.py +24 -14
- sglang/srt/entrypoints/engine.py +19 -12
- sglang/srt/entrypoints/http_server.py +174 -34
- sglang/srt/entrypoints/openai/protocol.py +87 -24
- sglang/srt/entrypoints/openai/serving_chat.py +50 -9
- sglang/srt/entrypoints/openai/serving_completions.py +15 -0
- sglang/srt/eplb/eplb_manager.py +26 -2
- sglang/srt/eplb/expert_distribution.py +29 -2
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/harmony_parser.py +588 -0
- sglang/srt/hf_transformers_utils.py +26 -7
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention/ascend_backend.py +374 -136
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +5 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
- sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
- sglang/srt/layers/communicator.py +1 -2
- sglang/srt/layers/layernorm.py +28 -3
- sglang/srt/layers/linear.py +3 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +13 -13
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/topk.py +35 -12
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
- sglang/srt/layers/quantization/fp8.py +2 -1
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +7 -0
- sglang/srt/layers/quantization/mxfp4.py +25 -27
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -3
- sglang/srt/layers/rotary_embedding.py +28 -1
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/managers/cache_controller.py +237 -204
- sglang/srt/managers/detokenizer_manager.py +48 -2
- sglang/srt/managers/io_struct.py +57 -0
- sglang/srt/managers/mm_utils.py +5 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
- sglang/srt/managers/scheduler.py +94 -9
- sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/tokenizer_manager.py +122 -42
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +51 -23
- sglang/srt/mem_cache/hiradix_cache.py +87 -71
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +77 -14
- sglang/srt/mem_cache/memory_pool_host.py +4 -5
- sglang/srt/mem_cache/radix_cache.py +6 -4
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +38 -20
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +87 -82
- sglang/srt/mem_cache/swa_radix_cache.py +1 -1
- sglang/srt/model_executor/model_runner.py +6 -5
- sglang/srt/model_loader/loader.py +15 -24
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/models/deepseek_v2.py +38 -13
- sglang/srt/models/gpt_oss.py +2 -15
- sglang/srt/models/llama_eagle3.py +4 -0
- sglang/srt/models/longcat_flash.py +1015 -0
- sglang/srt/models/longcat_flash_nextn.py +691 -0
- sglang/srt/models/qwen2.py +26 -3
- sglang/srt/models/qwen2_5_vl.py +66 -41
- sglang/srt/models/qwen2_moe.py +22 -2
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/reasoning_parser.py +56 -300
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/server_args.py +122 -56
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +73 -5
- sglang/test/attention/test_trtllm_mla_backend.py +12 -3
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +7 -6
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +107 -99
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/kernels.py
CHANGED
@@ -1362,3 +1362,77 @@ def moe_ep_deepgemm_preprocess(
         gateup_input,
         gateup_input_scale,
     )
+
+
+@triton.jit
+def compute_identity_kernel(
+    top_k,
+    hidden_states_ptr,
+    expert_scales_ptr,
+    num_tokens,
+    output_ptr,
+    hidden_dim,
+    scales_stride,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    batch_id = pid // (hidden_dim // BLOCK_SIZE)
+    dim_offset = pid % (hidden_dim // BLOCK_SIZE) * BLOCK_SIZE
+
+    if batch_id >= num_tokens or dim_offset >= hidden_dim:
+        return
+
+    h = tl.load(
+        hidden_states_ptr
+        + batch_id * hidden_dim
+        + dim_offset
+        + tl.arange(0, BLOCK_SIZE),
+        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
+    )
+
+    result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for i in range(top_k):
+        scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)
+        result += h * scale
+
+    tl.store(
+        output_ptr + batch_id * hidden_dim + dim_offset + tl.arange(0, BLOCK_SIZE),
+        result,
+        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
+    )
+
+
+def zero_experts_compute_triton(
+    expert_indices, expert_scales, num_experts, zero_expert_type, hidden_states
+):
+    N = expert_indices.numel()
+    top_k = expert_indices.size(-1)
+    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
+
+    if zero_expert_type == "identity":
+        zero_expert_mask = expert_indices < num_experts
+        zero_expert_scales = expert_scales.clone()
+        zero_expert_scales[zero_expert_mask] = 0.0
+
+    normal_expert_mask = expert_indices >= num_experts
+    expert_indices[normal_expert_mask] = 0
+    expert_scales[normal_expert_mask] = 0.0
+
+    output = torch.zeros_like(hidden_states).to(hidden_states.device)
+    hidden_dim = hidden_states.size(-1)
+    num_tokens = hidden_states.size(0)
+
+    grid = lambda meta: (num_tokens * (hidden_dim // meta["BLOCK_SIZE"]),)
+    compute_identity_kernel[grid](
+        top_k,
+        hidden_states,
+        zero_expert_scales,
+        num_tokens,
+        output,
+        hidden_dim,
+        zero_expert_scales.stride(0),
+        BLOCK_SIZE=256,
+    )
+
+    return output
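For orientation, the new `zero_experts_compute_triton` helper in the hunk above effectively scales each token's hidden state by the summed weights of its "zero" (identity) experts, i.e. the top-k slots whose expert index falls beyond the real expert range. A minimal pure-PyTorch equivalent of that identity path, with invented shapes for illustration (not part of the diff):

```python
import torch

# Invented sizes: 4 tokens, hidden_dim 8, 16 real experts, top_k 2; ids >= 16 are "zero" experts.
num_experts, top_k, hidden_dim = 16, 2, 8
hidden_states = torch.randn(4, hidden_dim)
expert_indices = torch.randint(0, num_experts + 4, (4, top_k))
expert_scales = torch.rand(4, top_k)

# Keep only the scales belonging to zero experts, then fold them into one per-token factor.
zero_scales = expert_scales.clone()
zero_scales[expert_indices < num_experts] = 0.0
reference_output = hidden_states * zero_scales.sum(dim=-1, keepdim=True)
```

This matches the Triton kernel's per-token loop, which accumulates `h * scale` over the top-k slots after the scales of normal experts have been zeroed out.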
sglang/srt/layers/moe/ep_moe/layer.py
CHANGED
@@ -248,7 +248,6 @@ class EPMoE(FusedMoE):
             gateup_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del gateup_input
         del gateup_input_fp8
@@ -304,7 +303,6 @@ class EPMoE(FusedMoE):
             down_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
        )
        del down_input
        del down_input_fp8
@@ -667,7 +665,6 @@ class DeepEPMoE(EPMoE):
             gateup_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         dispose_tensor(hidden_states_fp8[0])

@@ -708,9 +705,7 @@ class DeepEPMoE(EPMoE):
             (
                 down_input_scale
                 if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.
-                down_input_scale
-                )
+                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
             ),
         )
         down_output = torch.empty(
@@ -722,7 +717,6 @@ class DeepEPMoE(EPMoE):
             down_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )

         return down_output
@@ -752,19 +746,25 @@ class DeepEPMoE(EPMoE):
         hidden_states = torch_npu.npu_grouped_matmul(
             x=[hidden_states],
             weight=[self.w13_weight],
-            scale=[self.w13_weight_scale.to(output_dtype)],
-            per_token_scale=[pertoken_scale],
             split_item=2,
             group_list_type=group_list_type,
             group_type=0,
             group_list=seg_indptr,
-            output_dtype=
+            output_dtype=torch.int32,
         )[0]

         # act_fn: swiglu
-        hidden_states = torch_npu.
-
-
+        hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+            x=hidden_states,
+            weight_scale=self.w13_weight_scale.to(torch.float32),
+            activation_scale=pertoken_scale,
+            bias=None,
+            quant_scale=None,
+            quant_offset=None,
+            group_index=seg_indptr,
+            activate_left=True,
+            quant_mode=1,
+        )

         # gmm2: down_proj
         hidden_states = torch_npu.npu_grouped_matmul(
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json
ADDED
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    }
+}
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json
ADDED
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    }
+}
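These per-shape tuning tables are keyed by token count (M). At runtime the fused-MoE Triton kernel typically loads the JSON whose expert count, N, device name, and dtype match the current model, then uses the entry whose key is closest to the batch's M. A minimal sketch of that lookup, with a hypothetical helper name (the actual loader in sglang may differ):

```python
import json

def load_moe_config(path: str, num_tokens: int) -> dict:
    """Pick the tuning entry whose M key is closest to the current token count."""
    with open(path) as f:
        table = json.load(f)  # keys are token counts as strings, e.g. "1", "16", "4096"
    best_m = min(table, key=lambda m: abs(int(m) - num_tokens))
    return table[best_m]

# Hypothetical usage: 200 tokens would fall back to the "256" entry of the B200 table above,
# i.e. {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 256, ...}.
# cfg = load_moe_config(".../E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json", 200)
```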
sglang/srt/layers/moe/topk.py
CHANGED
@@ -304,12 +304,12 @@ class TopK(CustomOp):
         global_num_experts = router_logits.shape[-1]

         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if global_num_experts == 256
+        if global_num_experts == 256:

             routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
             router_logits = router_logits.to(torch.float32)

-
+            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
                 router_logits,
                 k=self.topk_config.top_k,
                 bias=self.topk_config.correction_bias.to(torch.float32),
@@ -321,6 +321,16 @@ class TopK(CustomOp):
                 routed_scaling_factor=routed_scaling_factor,
                 eps=float(1e-20),
             )
+
+            if self.topk_config.renormalize:
+                topk_weights_sum = (
+                    topk_weights.sum(dim=-1, keepdim=True)
+                    if self.topk_config.num_fused_shared_experts == 0
+                    else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+                )
+                topk_weights = topk_weights / topk_weights_sum
+
+            return StandardTopKOutput(topk_weights, topk_ids, _)
         else:
             self.topk_config.torch_native = True
             return select_experts(
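When a fused shared expert occupies the last top-k slot, the renormalization added above excludes it from the sum, so only the routed experts' weights are rescaled (the shared expert's weight is divided by the same factor but not counted). A small standalone illustration with invented values:

```python
import torch

topk_weights = torch.tensor([[0.5, 0.3, 0.2]])  # last column: fused shared expert

# num_fused_shared_experts == 0: normalize over every slot.
all_norm = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

# num_fused_shared_experts > 0: normalize over the routed slots only.
routed_sum = topk_weights[:, :-1].sum(dim=-1, keepdim=True)  # 0.8
routed_norm = topk_weights / routed_sum                       # [[0.625, 0.375, 0.25]]
```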
@@ -347,17 +357,28 @@ def fused_topk_torch_native(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    correction_bias: torch.Tensor = None,
 ):
-    assert (
-        hidden_states.shape[0] == gating_output.shape[0]
-    ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
-    M, _ = hidden_states.shape
-    topk_weights = torch.empty(
-        M, topk, dtype=torch.float32, device=hidden_states.device
-    )
-    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    topk_weights = F.softmax(gating_output.float(), dim=-1)
-    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+    if correction_bias is not None:
+        n_routed_experts = gating_output.shape[-1]
+        scores = gating_output.softmax(dim=-1)
+        scores_for_choice = scores.view(
+            -1, n_routed_experts
+        ) + correction_bias.unsqueeze(0)
+        topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
+        topk_weights = scores.gather(1, topk_ids)
+    else:
+        assert (
+            hidden_states.shape[0] == gating_output.shape[0]
+        ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
+        M, _ = hidden_states.shape
+        topk_weights = torch.empty(
+            M, topk, dtype=torch.float32, device=hidden_states.device
+        )
+        topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+        topk_weights = F.softmax(gating_output.float(), dim=-1)
+        topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
     return topk_weights, topk_ids
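The new `correction_bias` path mirrors DeepSeek-style aux-loss-free routing: the bias only influences which experts are chosen, while the returned weights are the unbiased softmax scores of the chosen experts. A toy standalone illustration (values invented):

```python
import torch

gating_output = torch.tensor([[2.0, 1.0, 0.5, 0.1]])   # one token, four experts
correction_bias = torch.tensor([0.0, 0.0, 2.0, 0.0])    # load-balancing bias toward expert 2

scores = gating_output.softmax(dim=-1)
topk_ids = torch.topk(scores + correction_bias, k=2, dim=-1, sorted=False)[1]
topk_weights = scores.gather(1, topk_ids)  # weights come from the unbiased scores

# Without the bias the top-2 would be experts {0, 1}; with it, expert 2 displaces expert 1,
# yet its weight is still its plain softmax probability.
```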
@@ -370,6 +391,7 @@ def fused_topk_cpu(
     renormalize: bool,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    correction_bias: torch.Tensor = None,
 ):
     topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
         hidden_states=hidden_states,
@@ -815,6 +837,7 @@ def select_experts(
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
+           correction_bias=correction_bias,
        )
    elif custom_routing_function is None:
        assert not apply_routed_scaling_factor_on_output, "Not implemented"