sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/_custom_ops.py +29 -1
  3. sglang/srt/configs/internvl.py +3 -0
  4. sglang/srt/configs/model_config.py +5 -1
  5. sglang/srt/constrained/base_grammar_backend.py +10 -2
  6. sglang/srt/constrained/xgrammar_backend.py +7 -5
  7. sglang/srt/conversation.py +17 -2
  8. sglang/srt/debug_utils/__init__.py +0 -0
  9. sglang/srt/debug_utils/dump_comparator.py +131 -0
  10. sglang/srt/debug_utils/dumper.py +108 -0
  11. sglang/srt/debug_utils/text_comparator.py +172 -0
  12. sglang/srt/disaggregation/common/conn.py +34 -6
  13. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  14. sglang/srt/disaggregation/mini_lb.py +3 -2
  15. sglang/srt/disaggregation/mooncake/conn.py +65 -20
  16. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  17. sglang/srt/disaggregation/nixl/conn.py +17 -13
  18. sglang/srt/disaggregation/prefill.py +13 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  21. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  23. sglang/srt/distributed/parallel_state.py +70 -15
  24. sglang/srt/entrypoints/engine.py +5 -9
  25. sglang/srt/entrypoints/http_server.py +20 -32
  26. sglang/srt/entrypoints/openai/protocol.py +3 -3
  27. sglang/srt/entrypoints/openai/serving_chat.py +148 -72
  28. sglang/srt/function_call/base_format_detector.py +74 -12
  29. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  30. sglang/srt/function_call/ebnf_composer.py +105 -66
  31. sglang/srt/function_call/function_call_parser.py +6 -4
  32. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  33. sglang/srt/function_call/kimik2_detector.py +41 -16
  34. sglang/srt/function_call/llama32_detector.py +6 -3
  35. sglang/srt/function_call/mistral_detector.py +11 -3
  36. sglang/srt/function_call/pythonic_detector.py +16 -14
  37. sglang/srt/function_call/qwen25_detector.py +12 -3
  38. sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
  39. sglang/srt/layers/activation.py +11 -3
  40. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  41. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  42. sglang/srt/layers/attention/vision.py +56 -8
  43. sglang/srt/layers/communicator.py +12 -12
  44. sglang/srt/layers/dp_attention.py +72 -24
  45. sglang/srt/layers/layernorm.py +26 -1
  46. sglang/srt/layers/logits_processor.py +46 -25
  47. sglang/srt/layers/moe/ep_moe/layer.py +172 -206
  48. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
  51. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  52. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  53. sglang/srt/layers/moe/topk.py +88 -34
  54. sglang/srt/layers/multimodal.py +11 -8
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
  56. sglang/srt/layers/quantization/fp8.py +25 -247
  57. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  58. sglang/srt/layers/quantization/modelopt_quant.py +33 -14
  59. sglang/srt/layers/quantization/unquant.py +24 -76
  60. sglang/srt/layers/quantization/utils.py +0 -9
  61. sglang/srt/layers/quantization/w4afp8.py +68 -17
  62. sglang/srt/layers/radix_attention.py +5 -3
  63. sglang/srt/lora/lora_manager.py +133 -169
  64. sglang/srt/lora/lora_registry.py +188 -0
  65. sglang/srt/lora/mem_pool.py +2 -2
  66. sglang/srt/managers/cache_controller.py +62 -13
  67. sglang/srt/managers/io_struct.py +19 -1
  68. sglang/srt/managers/mm_utils.py +154 -35
  69. sglang/srt/managers/multimodal_processor.py +3 -14
  70. sglang/srt/managers/schedule_batch.py +27 -11
  71. sglang/srt/managers/scheduler.py +48 -26
  72. sglang/srt/managers/tokenizer_manager.py +62 -28
  73. sglang/srt/managers/tp_worker.py +5 -4
  74. sglang/srt/mem_cache/allocator.py +67 -7
  75. sglang/srt/mem_cache/hicache_storage.py +17 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +35 -18
  77. sglang/srt/mem_cache/memory_pool_host.py +3 -0
  78. sglang/srt/model_executor/cuda_graph_runner.py +61 -25
  79. sglang/srt/model_executor/forward_batch_info.py +201 -29
  80. sglang/srt/model_executor/model_runner.py +109 -37
  81. sglang/srt/models/deepseek_v2.py +63 -30
  82. sglang/srt/models/glm4_moe.py +1035 -0
  83. sglang/srt/models/glm4_moe_nextn.py +167 -0
  84. sglang/srt/models/interns1.py +328 -0
  85. sglang/srt/models/internvl.py +143 -47
  86. sglang/srt/models/llava.py +9 -5
  87. sglang/srt/models/minicpmo.py +4 -1
  88. sglang/srt/models/mllama4.py +10 -3
  89. sglang/srt/models/qwen2_moe.py +2 -6
  90. sglang/srt/models/qwen3_moe.py +6 -8
  91. sglang/srt/multimodal/processors/base_processor.py +20 -6
  92. sglang/srt/multimodal/processors/clip.py +2 -2
  93. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  94. sglang/srt/multimodal/processors/gemma3.py +2 -2
  95. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  96. sglang/srt/multimodal/processors/internvl.py +21 -8
  97. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  98. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  99. sglang/srt/multimodal/processors/llava.py +4 -4
  100. sglang/srt/multimodal/processors/minicpm.py +2 -3
  101. sglang/srt/multimodal/processors/mlama.py +2 -2
  102. sglang/srt/multimodal/processors/mllama4.py +18 -111
  103. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  104. sglang/srt/multimodal/processors/pixtral.py +2 -2
  105. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  106. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  107. sglang/srt/multimodal/processors/vila.py +3 -1
  108. sglang/srt/reasoning_parser.py +48 -5
  109. sglang/srt/sampling/sampling_batch_info.py +6 -5
  110. sglang/srt/server_args.py +132 -60
  111. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  112. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
  113. sglang/srt/speculative/eagle_utils.py +51 -23
  114. sglang/srt/speculative/eagle_worker.py +59 -44
  115. sglang/srt/two_batch_overlap.py +9 -5
  116. sglang/srt/utils.py +113 -69
  117. sglang/srt/weight_sync/utils.py +119 -0
  118. sglang/test/runners.py +4 -0
  119. sglang/test/test_activation.py +50 -1
  120. sglang/test/test_utils.py +65 -5
  121. sglang/utils.py +19 -0
  122. sglang/version.py +1 -1
  123. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
  124. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
  125. sglang/srt/debug_utils.py +0 -74
  126. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
  127. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
  128. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
+ {
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "64": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "96": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "128": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "256": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "512": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ }
+ }
@@ -0,0 +1,146 @@
+ {
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "64": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "96": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "128": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "256": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "512": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 2
+ }
+ }
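
Both @@ -0,0 +1,146 @@ hunks above are the two new fused-MoE Triton tuning tables listed in the file table (entries 48 and 49): each top-level key is a token count M, and its value is the Triton tile/pipeline configuration benchmarked for that batch size. As a rough, hypothetical illustration of how such a table is consumed (the helper names below are illustrative only, not the actual sglang loader):

import json

def load_moe_tuning_table(path: str) -> dict:
    # Keys in the JSON are strings ("1", "2", ..., "4096"); convert them to ints.
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}

def pick_kernel_config(table: dict, num_tokens: int) -> dict:
    # Use the benchmarked batch size closest to the runtime token count.
    best_m = min(table, key=lambda m: abs(m - num_tokens))
    return table[best_m]

# Example: 200 tokens would select the "256" entry of the first table above,
# i.e. BLOCK_SIZE_M=32, BLOCK_SIZE_N=128, BLOCK_SIZE_K=256, GROUP_SIZE_M=16.

The runtime typically also falls back to default block sizes when no tuned file exists for the device and shape; the nearest-key lookup above is only the essential idea.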
@@ -53,9 +53,7 @@ elif _is_hip:
  from aiter import moe_sum
  except ImportError:
  raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
- else:
- from vllm import _custom_ops as vllm_ops
- from vllm._custom_ops import scaled_fp8_quant
+
 
  if _is_cuda or _is_hip:
  from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
@@ -63,9 +61,6 @@ if _is_cuda or _is_hip:
 
  logger = logging.getLogger(__name__)
  padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
- enable_moe_align_block_size_triton = bool(
- int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
- )
 
 
  @triton.jit
@@ -533,190 +528,6 @@ def fused_moe_kernel(
  tl.store(c_ptrs, accumulator, mask=c_mask)
 
 
- @triton.jit
- def moe_align_block_size_stage1(
- topk_ids_ptr,
- tokens_cnts_ptr,
- num_experts: tl.constexpr,
- numel: tl.constexpr,
- tokens_per_thread: tl.constexpr,
- ):
- pid = tl.program_id(0)
-
- start_idx = pid * tokens_per_thread
-
- off_c = (pid + 1) * num_experts
-
- for i in range(tokens_per_thread):
- if start_idx + i < numel:
- idx = tl.load(topk_ids_ptr + start_idx + i)
- token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
- tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
-
-
- @triton.jit
- def moe_align_block_size_stage2(
- tokens_cnts_ptr,
- num_experts: tl.constexpr,
- ):
- pid = tl.program_id(0)
-
- last_cnt = 0
- for i in range(1, num_experts + 1):
- token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
- last_cnt = last_cnt + token_cnt
- tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
-
-
- @triton.jit
- def moe_align_block_size_stage3(
- total_tokens_post_pad_ptr,
- tokens_cnts_ptr,
- cumsum_ptr,
- num_experts: tl.constexpr,
- block_size: tl.constexpr,
- ):
- last_cumsum = 0
- off_cnt = num_experts * num_experts
- for i in range(1, num_experts + 1):
- token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
- last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
- tl.store(cumsum_ptr + i, last_cumsum)
- tl.store(total_tokens_post_pad_ptr, last_cumsum)
-
-
- @triton.jit
- def moe_align_block_size_stage4(
- topk_ids_ptr,
- sorted_token_ids_ptr,
- expert_ids_ptr,
- tokens_cnts_ptr,
- cumsum_ptr,
- num_experts: tl.constexpr,
- block_size: tl.constexpr,
- numel: tl.constexpr,
- tokens_per_thread: tl.constexpr,
- ):
- pid = tl.program_id(0)
- start_idx = tl.load(cumsum_ptr + pid)
- end_idx = tl.load(cumsum_ptr + pid + 1)
-
- for i in range(start_idx, end_idx, block_size):
- tl.store(expert_ids_ptr + i // block_size, pid)
-
- start_idx = pid * tokens_per_thread
- off_t = pid * num_experts
-
- for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
- expert_id = tl.load(topk_ids_ptr + i)
- token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
- rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
- tl.store(sorted_token_ids_ptr + rank_post_pad, i)
- tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
-
-
- def moe_align_block_size_triton(
- topk_ids: torch.Tensor,
- num_experts: int,
- block_size: int,
- sorted_token_ids: torch.Tensor,
- expert_ids: torch.Tensor,
- num_tokens_post_pad: torch.Tensor,
- ) -> None:
- numel = topk_ids.numel()
- grid = (num_experts,)
- tokens_cnts = torch.zeros(
- (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
- )
- cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
- tokens_per_thread = ceil_div(numel, num_experts)
-
- moe_align_block_size_stage1[grid](
- topk_ids,
- tokens_cnts,
- num_experts,
- numel,
- tokens_per_thread,
- )
- moe_align_block_size_stage2[grid](
- tokens_cnts,
- num_experts,
- )
- moe_align_block_size_stage3[(1,)](
- num_tokens_post_pad,
- tokens_cnts,
- cumsum,
- num_experts,
- block_size,
- )
- moe_align_block_size_stage4[grid](
- topk_ids,
- sorted_token_ids,
- expert_ids,
- tokens_cnts,
- cumsum,
- num_experts,
- block_size,
- numel,
- tokens_per_thread,
- )
-
-
- @triton.jit
- def init_sorted_ids_and_cumsum_buffer_kernel(
- sorted_ids_ptr,
- cumsum_buffer_ptr,
- max_num_tokens_padded,
- topk_ids_numel,
- num_experts: tl.constexpr,
- BLOCK_SIZE: tl.constexpr,
- ALIGNED_NUM_EXPERTS_P1: tl.constexpr,
- ):
- pid = tl.program_id(0)
- offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-
- sorted_ids_blocks = tl.cdiv(max_num_tokens_padded, BLOCK_SIZE)
-
- if pid < sorted_ids_blocks:
- mask = offsets < max_num_tokens_padded
- tl.store(
- sorted_ids_ptr + offsets,
- tl.full((BLOCK_SIZE,), topk_ids_numel, dtype=tl.int32),
- mask=mask,
- )
- elif pid == sorted_ids_blocks:
- offset_e = tl.arange(0, ALIGNED_NUM_EXPERTS_P1)
- mask_e = offset_e < num_experts + 1
- tl.store(
- cumsum_buffer_ptr + offset_e,
- tl.zeros((ALIGNED_NUM_EXPERTS_P1,), dtype=tl.int32),
- mask=mask_e,
- )
-
-
- def init_sorted_ids_and_cumsum_buffer(
- max_num_tokens_padded: int, topk_ids_numel: int, num_experts: int, device="cuda"
- ):
- sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=device)
- cumsum_buffer = torch.empty((num_experts + 1,), dtype=torch.int32, device=device)
-
- BLOCK_SIZE = 1024
- sorted_ids_blocks = triton.cdiv(max_num_tokens_padded, BLOCK_SIZE)
- grid = (sorted_ids_blocks + 1,)
-
- init_sorted_ids_and_cumsum_buffer_kernel[grid](
- sorted_ids,
- cumsum_buffer,
- max_num_tokens_padded,
- topk_ids_numel,
- num_experts,
- BLOCK_SIZE,
- next_power_of_2(num_experts + 1),
- )
-
- return sorted_ids, cumsum_buffer
-
-
  def moe_align_block_size(
  topk_ids: torch.Tensor, block_size: int, num_experts: int
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -766,42 +577,32 @@ def moe_align_block_size(
  (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
  )
  num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
- if enable_moe_align_block_size_triton:
- sorted_ids.fill_(topk_ids.numel())
- moe_align_block_size_triton(
- topk_ids,
- num_experts,
- block_size,
- sorted_ids,
- expert_ids,
- num_tokens_post_pad,
- )
- else:
- cumsum_buffer = torch.empty(
- (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
- )
- token_cnts_buffer = torch.empty(
- (num_experts + 1) * num_experts,
- dtype=torch.int32,
- device=topk_ids.device,
- )
 
- # Threshold based on benchmark results
- fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
- if not fuse_sorted_ids_padding:
- sorted_ids.fill_(topk_ids.numel())
+ cumsum_buffer = torch.empty(
+ (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
+ )
+ token_cnts_buffer = torch.empty(
+ (num_experts + 1) * num_experts,
+ dtype=torch.int32,
+ device=topk_ids.device,
+ )
 
- sgl_moe_align_block_size(
- topk_ids,
- num_experts,
- block_size,
- sorted_ids,
- expert_ids,
- num_tokens_post_pad,
- token_cnts_buffer,
- cumsum_buffer,
- fuse_sorted_ids_padding,
- )
+ # Threshold based on benchmark results
+ fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
+ if not fuse_sorted_ids_padding:
+ sorted_ids.fill_(topk_ids.numel())
+
+ sgl_moe_align_block_size(
+ topk_ids,
+ num_experts,
+ block_size,
+ sorted_ids,
+ expert_ids,
+ num_tokens_post_pad,
+ token_cnts_buffer,
+ cumsum_buffer,
+ fuse_sorted_ids_padding,
+ )
  return sorted_ids, expert_ids, num_tokens_post_pad
 
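
Net effect of the fused_moe.py hunks above: the pure-Triton moe_align_block_size_triton path, its helper kernels, and the ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON switch are removed, so moe_align_block_size now always dispatches to the fused sgl_kernel op. A condensed, illustrative sketch of the surviving flow (the wrapper name is hypothetical and the sorted_ids/expert_ids allocation shown in earlier context is abbreviated; this is not the literal file contents):

import torch
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size

def align_sketch(topk_ids, block_size, num_experts, sorted_ids, expert_ids, num_tokens_post_pad):
    # Scratch buffers consumed by the fused kernel.
    cumsum_buffer = torch.empty((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
    token_cnts_buffer = torch.empty(
        (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
    )

    # Small sorted_ids buffers let the kernel fuse the padding fill; larger ones are pre-filled.
    fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
    if not fuse_sorted_ids_padding:
        sorted_ids.fill_(topk_ids.numel())

    sgl_moe_align_block_size(
        topk_ids, num_experts, block_size,
        sorted_ids, expert_ids, num_tokens_post_pad,
        token_cnts_buffer, cumsum_buffer, fuse_sorted_ids_padding,
    )
    return sorted_ids, expert_ids, num_tokens_post_pad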