sglang 0.5.1.post2__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +79 -53
  3. sglang/bench_serving.py +186 -14
  4. sglang/profiler.py +0 -1
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/longcat_flash.py +104 -0
  7. sglang/srt/configs/model_config.py +12 -0
  8. sglang/srt/connector/__init__.py +1 -1
  9. sglang/srt/connector/base_connector.py +1 -2
  10. sglang/srt/connector/redis.py +2 -2
  11. sglang/srt/connector/serde/__init__.py +1 -1
  12. sglang/srt/connector/serde/safe_serde.py +4 -3
  13. sglang/srt/conversation.py +38 -5
  14. sglang/srt/disaggregation/ascend/conn.py +75 -0
  15. sglang/srt/disaggregation/launch_lb.py +0 -13
  16. sglang/srt/disaggregation/mini_lb.py +33 -8
  17. sglang/srt/disaggregation/prefill.py +1 -1
  18. sglang/srt/distributed/parallel_state.py +24 -14
  19. sglang/srt/entrypoints/engine.py +19 -12
  20. sglang/srt/entrypoints/http_server.py +174 -34
  21. sglang/srt/entrypoints/openai/protocol.py +87 -24
  22. sglang/srt/entrypoints/openai/serving_chat.py +50 -9
  23. sglang/srt/entrypoints/openai/serving_completions.py +15 -0
  24. sglang/srt/eplb/eplb_manager.py +26 -2
  25. sglang/srt/eplb/expert_distribution.py +29 -2
  26. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  27. sglang/srt/function_call/function_call_parser.py +2 -0
  28. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  29. sglang/srt/harmony_parser.py +588 -0
  30. sglang/srt/hf_transformers_utils.py +26 -7
  31. sglang/srt/layers/activation.py +12 -0
  32. sglang/srt/layers/attention/ascend_backend.py +374 -136
  33. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  34. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  35. sglang/srt/layers/attention/flashinfer_mla_backend.py +5 -2
  36. sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
  37. sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
  38. sglang/srt/layers/communicator.py +1 -2
  39. sglang/srt/layers/layernorm.py +28 -3
  40. sglang/srt/layers/linear.py +3 -2
  41. sglang/srt/layers/logits_processor.py +1 -1
  42. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  43. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  44. sglang/srt/layers/moe/ep_moe/layer.py +13 -13
  45. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  47. sglang/srt/layers/moe/topk.py +35 -12
  48. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
  49. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  50. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
  51. sglang/srt/layers/quantization/fp8.py +2 -1
  52. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  53. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  54. sglang/srt/layers/quantization/modelopt_quant.py +7 -0
  55. sglang/srt/layers/quantization/mxfp4.py +25 -27
  56. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  57. sglang/srt/layers/quantization/utils.py +13 -0
  58. sglang/srt/layers/quantization/w8a8_int8.py +7 -3
  59. sglang/srt/layers/rotary_embedding.py +28 -1
  60. sglang/srt/layers/sampler.py +29 -5
  61. sglang/srt/layers/utils.py +0 -14
  62. sglang/srt/managers/cache_controller.py +237 -204
  63. sglang/srt/managers/detokenizer_manager.py +48 -2
  64. sglang/srt/managers/io_struct.py +57 -0
  65. sglang/srt/managers/mm_utils.py +5 -1
  66. sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
  67. sglang/srt/managers/scheduler.py +94 -9
  68. sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
  69. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  70. sglang/srt/managers/tokenizer_manager.py +122 -42
  71. sglang/srt/mem_cache/chunk_cache.py +1 -1
  72. sglang/srt/mem_cache/hicache_storage.py +51 -23
  73. sglang/srt/mem_cache/hiradix_cache.py +87 -71
  74. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  75. sglang/srt/mem_cache/memory_pool.py +77 -14
  76. sglang/srt/mem_cache/memory_pool_host.py +4 -5
  77. sglang/srt/mem_cache/radix_cache.py +6 -4
  78. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  79. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +38 -20
  80. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +87 -82
  81. sglang/srt/mem_cache/swa_radix_cache.py +1 -1
  82. sglang/srt/model_executor/model_runner.py +6 -5
  83. sglang/srt/model_loader/loader.py +15 -24
  84. sglang/srt/model_loader/utils.py +12 -0
  85. sglang/srt/models/deepseek_v2.py +38 -13
  86. sglang/srt/models/gpt_oss.py +2 -15
  87. sglang/srt/models/llama_eagle3.py +4 -0
  88. sglang/srt/models/longcat_flash.py +1015 -0
  89. sglang/srt/models/longcat_flash_nextn.py +691 -0
  90. sglang/srt/models/qwen2.py +26 -3
  91. sglang/srt/models/qwen2_5_vl.py +66 -41
  92. sglang/srt/models/qwen2_moe.py +22 -2
  93. sglang/srt/models/transformers.py +1 -1
  94. sglang/srt/multimodal/processors/base_processor.py +4 -2
  95. sglang/srt/reasoning_parser.py +56 -300
  96. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  97. sglang/srt/server_args.py +122 -56
  98. sglang/srt/speculative/eagle_worker.py +28 -8
  99. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  100. sglang/srt/utils.py +73 -5
  101. sglang/test/attention/test_trtllm_mla_backend.py +12 -3
  102. sglang/version.py +1 -1
  103. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +7 -6
  104. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +107 -99
  105. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
  106. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
  107. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/kernels.py
@@ -1362,3 +1362,77 @@ def moe_ep_deepgemm_preprocess(
         gateup_input,
         gateup_input_scale,
     )
+
+
+@triton.jit
+def compute_identity_kernel(
+    top_k,
+    hidden_states_ptr,
+    expert_scales_ptr,
+    num_tokens,
+    output_ptr,
+    hidden_dim,
+    scales_stride,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    batch_id = pid // (hidden_dim // BLOCK_SIZE)
+    dim_offset = pid % (hidden_dim // BLOCK_SIZE) * BLOCK_SIZE
+
+    if batch_id >= num_tokens or dim_offset >= hidden_dim:
+        return
+
+    h = tl.load(
+        hidden_states_ptr
+        + batch_id * hidden_dim
+        + dim_offset
+        + tl.arange(0, BLOCK_SIZE),
+        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
+    )
+
+    result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for i in range(top_k):
+        scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)
+        result += h * scale
+
+    tl.store(
+        output_ptr + batch_id * hidden_dim + dim_offset + tl.arange(0, BLOCK_SIZE),
+        result,
+        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
+    )
+
+
+def zero_experts_compute_triton(
+    expert_indices, expert_scales, num_experts, zero_expert_type, hidden_states
+):
+    N = expert_indices.numel()
+    top_k = expert_indices.size(-1)
+    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
+
+    if zero_expert_type == "identity":
+        zero_expert_mask = expert_indices < num_experts
+        zero_expert_scales = expert_scales.clone()
+        zero_expert_scales[zero_expert_mask] = 0.0
+
+    normal_expert_mask = expert_indices >= num_experts
+    expert_indices[normal_expert_mask] = 0
+    expert_scales[normal_expert_mask] = 0.0
+
+    output = torch.zeros_like(hidden_states).to(hidden_states.device)
+    hidden_dim = hidden_states.size(-1)
+    num_tokens = hidden_states.size(0)
+
+    grid = lambda meta: (num_tokens * (hidden_dim // meta["BLOCK_SIZE"]),)
+    compute_identity_kernel[grid](
+        top_k,
+        hidden_states,
+        zero_expert_scales,
+        num_tokens,
+        output,
+        hidden_dim,
+        zero_expert_scales.stride(0),
+        BLOCK_SIZE=256,
+    )
+
+    return output
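
For orientation, the per-token math of the new kernel reduces to scaling each hidden vector by the sum of its (already zeroed-out) routing scales. The following is a minimal PyTorch sketch of that computation, illustrative only and not part of the package:

    import torch

    def zero_experts_reference(hidden_states: torch.Tensor, zero_expert_scales: torch.Tensor) -> torch.Tensor:
        # hidden_states: [num_tokens, hidden_dim]
        # zero_expert_scales: [num_tokens, top_k], entries for normal experts already set to 0.0
        # compute_identity_kernel accumulates h * scale_i over the top_k slots, i.e. h * sum(scales)
        return hidden_states * zero_expert_scales.sum(dim=-1, keepdim=True)

    h = torch.randn(4, 512)
    scales = torch.zeros(4, 8)
    scales[:, 0] = 0.25  # pretend one selected slot per token is a zero ("identity") expert
    out = zero_experts_reference(h, scales)
    assert out.shape == h.shape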

sglang/srt/layers/moe/ep_moe/layer.py
@@ -248,7 +248,6 @@ class EPMoE(FusedMoE):
             gateup_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del gateup_input
         del gateup_input_fp8
@@ -304,7 +303,6 @@ class EPMoE(FusedMoE):
             down_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del down_input
         del down_input_fp8
@@ -667,7 +665,6 @@ class DeepEPMoE(EPMoE):
             gateup_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         dispose_tensor(hidden_states_fp8[0])
 
@@ -708,9 +705,7 @@ class DeepEPMoE(EPMoE):
             (
                 down_input_scale
                 if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
-                    down_input_scale
-                )
+                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
             ),
         )
         down_output = torch.empty(
@@ -722,7 +717,6 @@ class DeepEPMoE(EPMoE):
             down_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
 
         return down_output
@@ -752,19 +746,25 @@ class DeepEPMoE(EPMoE):
         hidden_states = torch_npu.npu_grouped_matmul(
             x=[hidden_states],
             weight=[self.w13_weight],
-            scale=[self.w13_weight_scale.to(output_dtype)],
-            per_token_scale=[pertoken_scale],
             split_item=2,
             group_list_type=group_list_type,
             group_type=0,
             group_list=seg_indptr,
-            output_dtype=output_dtype,
+            output_dtype=torch.int32,
         )[0]
 
         # act_fn: swiglu
-        hidden_states = torch_npu.npu_swiglu(hidden_states)
-
-        hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
+        hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+            x=hidden_states,
+            weight_scale=self.w13_weight_scale.to(torch.float32),
+            activation_scale=pertoken_scale,
+            bias=None,
+            quant_scale=None,
+            quant_offset=None,
+            group_index=seg_indptr,
+            activate_left=True,
+            quant_mode=1,
+        )
 
         # gmm2: down_proj
         hidden_states = torch_npu.npu_grouped_matmul(
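
Here the separate npu_swiglu + npu_dynamic_quant steps are folded into one npu_dequant_swiglu_quant call that dequantizes the int32 grouped-matmul output with the weight and per-token activation scales, applies the gated activation, and re-quantizes per token. As a rough plain-PyTorch sketch of just the activation step (assuming activate_left=True gates the first half of the feature dimension; the torch_npu op's exact layout may differ):

    import torch
    import torch.nn.functional as F

    def swiglu_left(x: torch.Tensor) -> torch.Tensor:
        # split the gate_up projection into gate | up halves and gate with SiLU
        gate, up = x.chunk(2, dim=-1)
        return F.silu(gate) * up

    x = torch.randn(4, 256)  # 4 tokens, 2 * intermediate_size = 256
    assert swiglu_left(x).shape == (4, 128)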

sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    }
+}

sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    }
+}
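
These JSON files map a token-batch size (the string keys) to a tuned Triton launch configuration for the fused MoE kernel. One plausible way such a table is consumed, shown as a sketch only since the real selection logic lives in fused_moe_triton and may differ, is to pick the tuned key closest to the runtime batch size:

    import json

    def pick_config(path: str, num_tokens: int) -> dict:
        with open(path) as f:
            table = json.load(f)  # {"1": {...}, "2": {...}, ...}
        best_key = min(table, key=lambda k: abs(int(k) - num_tokens))
        return table[best_key]

    # e.g. pick_config(".../E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json", 200)
    # would return the "256" entry from the B200 table above.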

sglang/srt/layers/moe/topk.py
@@ -304,12 +304,12 @@ class TopK(CustomOp):
         global_num_experts = router_logits.shape[-1]
 
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
-        if global_num_experts == 256 and self.topk_config.renormalize is False:
+        if global_num_experts == 256:
 
             routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
             router_logits = router_logits.to(torch.float32)
 
-            return torch_npu.npu_moe_gating_top_k(
+            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
                 router_logits,
                 k=self.topk_config.top_k,
                 bias=self.topk_config.correction_bias.to(torch.float32),
@@ -321,6 +321,16 @@ class TopK(CustomOp):
                 routed_scaling_factor=routed_scaling_factor,
                 eps=float(1e-20),
             )
+
+            if self.topk_config.renormalize:
+                topk_weights_sum = (
+                    topk_weights.sum(dim=-1, keepdim=True)
+                    if self.topk_config.num_fused_shared_experts == 0
+                    else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+                )
+                topk_weights = topk_weights / topk_weights_sum
+
+            return StandardTopKOutput(topk_weights, topk_ids, _)
         else:
             self.topk_config.torch_native = True
             return select_experts(
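
The NPU branch now also handles renormalization after npu_moe_gating_top_k: when a fused shared expert occupies the last top-k slot, its weight is kept but excluded from the normalizing sum. A minimal sketch of that arithmetic with made-up values (illustrative only):

    import torch

    topk_weights = torch.tensor([[0.2, 0.1, 0.5]])  # last column = fused shared expert
    num_fused_shared_experts = 1

    denom = (
        topk_weights.sum(dim=-1, keepdim=True)
        if num_fused_shared_experts == 0
        else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
    )
    print(topk_weights / denom)  # tensor([[0.6667, 0.3333, 1.6667]])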
@@ -347,17 +357,28 @@ def fused_topk_torch_native(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    correction_bias: torch.Tensor = None,
 ):
-    assert (
-        hidden_states.shape[0] == gating_output.shape[0]
-    ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
-    M, _ = hidden_states.shape
-    topk_weights = torch.empty(
-        M, topk, dtype=torch.float32, device=hidden_states.device
-    )
-    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    topk_weights = F.softmax(gating_output.float(), dim=-1)
-    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+    if correction_bias is not None:
+        n_routed_experts = gating_output.shape[-1]
+        scores = gating_output.softmax(dim=-1)
+        scores_for_choice = scores.view(
+            -1, n_routed_experts
+        ) + correction_bias.unsqueeze(0)
+        topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
+        topk_weights = scores.gather(1, topk_ids)
+    else:
+        assert (
+            hidden_states.shape[0] == gating_output.shape[0]
+        ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
+        M, _ = hidden_states.shape
+        topk_weights = torch.empty(
+            M, topk, dtype=torch.float32, device=hidden_states.device
+        )
+        topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+        topk_weights = F.softmax(gating_output.float(), dim=-1)
+        topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
     return topk_weights, topk_ids
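
In the new correction_bias branch, expert selection uses the bias-shifted softmax scores, while the returned weights are the original, unshifted scores gathered at the chosen experts. A small standalone sketch with illustrative values:

    import torch

    gating_output = torch.tensor([[2.0, 1.0, 0.0]])
    correction_bias = torch.tensor([0.0, 0.0, 5.0])  # strongly favors expert 2 for selection
    topk = 2

    scores = gating_output.softmax(dim=-1)
    scores_for_choice = scores + correction_bias.unsqueeze(0)
    topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
    topk_weights = scores.gather(1, topk_ids)  # weights come from the unbiased scores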
@@ -370,6 +391,7 @@ def fused_topk_cpu(
     renormalize: bool,
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    correction_bias: torch.Tensor = None,
 ):
     topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
         hidden_states=hidden_states,
@@ -815,6 +837,7 @@ def select_experts(
             gating_output=router_logits,
             topk=top_k,
             renormalize=renormalize,
+            correction_bias=correction_bias,
         )
     elif custom_routing_function is None:
         assert not apply_routed_scaling_factor_on_output, "Not implemented"