sglang 0.4.5.post3__py3-none-any.whl → 0.4.6.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. sglang/bench_one_batch.py +19 -3
  2. sglang/bench_serving.py +8 -9
  3. sglang/compile_deep_gemm.py +45 -4
  4. sglang/srt/code_completion_parser.py +1 -1
  5. sglang/srt/configs/deepseekvl2.py +1 -1
  6. sglang/srt/configs/model_config.py +9 -3
  7. sglang/srt/constrained/llguidance_backend.py +78 -61
  8. sglang/srt/conversation.py +34 -1
  9. sglang/srt/disaggregation/decode.py +67 -13
  10. sglang/srt/disaggregation/fake/__init__.py +1 -0
  11. sglang/srt/disaggregation/fake/conn.py +88 -0
  12. sglang/srt/disaggregation/mini_lb.py +45 -8
  13. sglang/srt/disaggregation/mooncake/conn.py +198 -31
  14. sglang/srt/disaggregation/prefill.py +36 -12
  15. sglang/srt/disaggregation/utils.py +16 -2
  16. sglang/srt/entrypoints/engine.py +9 -0
  17. sglang/srt/entrypoints/http_server.py +35 -4
  18. sglang/srt/function_call_parser.py +77 -5
  19. sglang/srt/layers/attention/base_attn_backend.py +3 -0
  20. sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
  21. sglang/srt/layers/attention/flashattention_backend.py +28 -10
  22. sglang/srt/layers/attention/flashmla_backend.py +8 -11
  23. sglang/srt/layers/attention/utils.py +1 -1
  24. sglang/srt/layers/attention/vision.py +2 -0
  25. sglang/srt/layers/layernorm.py +38 -16
  26. sglang/srt/layers/logits_processor.py +2 -2
  27. sglang/srt/layers/moe/fused_moe_native.py +2 -4
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +20 -17
  43. sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
  44. sglang/srt/layers/pooler.py +6 -0
  45. sglang/srt/layers/quantization/awq.py +5 -1
  46. sglang/srt/layers/quantization/deep_gemm.py +17 -10
  47. sglang/srt/layers/quantization/fp8.py +20 -22
  48. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  49. sglang/srt/layers/quantization/int8_kernel.py +32 -1
  50. sglang/srt/layers/radix_attention.py +13 -3
  51. sglang/srt/layers/rotary_embedding.py +170 -126
  52. sglang/srt/managers/data_parallel_controller.py +10 -3
  53. sglang/srt/managers/io_struct.py +7 -0
  54. sglang/srt/managers/mm_utils.py +85 -28
  55. sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
  56. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
  57. sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
  58. sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
  59. sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
  60. sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
  61. sglang/srt/managers/schedule_batch.py +38 -12
  62. sglang/srt/managers/scheduler.py +41 -28
  63. sglang/srt/managers/scheduler_output_processor_mixin.py +25 -9
  64. sglang/srt/managers/tokenizer_manager.py +5 -1
  65. sglang/srt/managers/tp_worker.py +3 -3
  66. sglang/srt/managers/tp_worker_overlap_thread.py +9 -4
  67. sglang/srt/mem_cache/memory_pool.py +87 -0
  68. sglang/srt/model_executor/cuda_graph_runner.py +4 -3
  69. sglang/srt/model_executor/forward_batch_info.py +51 -95
  70. sglang/srt/model_executor/model_runner.py +19 -25
  71. sglang/srt/models/deepseek.py +12 -2
  72. sglang/srt/models/deepseek_nextn.py +101 -6
  73. sglang/srt/models/deepseek_v2.py +144 -70
  74. sglang/srt/models/deepseek_vl2.py +9 -4
  75. sglang/srt/models/gemma3_causal.py +1 -1
  76. sglang/srt/models/llama4.py +0 -1
  77. sglang/srt/models/minicpmo.py +5 -1
  78. sglang/srt/models/mllama4.py +2 -2
  79. sglang/srt/models/qwen2_5_vl.py +3 -6
  80. sglang/srt/models/qwen2_vl.py +3 -7
  81. sglang/srt/models/roberta.py +178 -0
  82. sglang/srt/openai_api/adapter.py +50 -11
  83. sglang/srt/openai_api/protocol.py +2 -0
  84. sglang/srt/reasoning_parser.py +25 -1
  85. sglang/srt/server_args.py +31 -24
  86. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  87. sglang/srt/torch_memory_saver_adapter.py +10 -1
  88. sglang/srt/utils.py +5 -1
  89. sglang/test/runners.py +6 -13
  90. sglang/test/send_one.py +84 -28
  91. sglang/test/test_utils.py +74 -18
  92. sglang/version.py +1 -1
  93. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/METADATA +5 -6
  94. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/RECORD +97 -80
  95. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/WHEEL +1 -1
  96. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/licenses/LICENSE +0 -0
  97. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
+ {
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "64": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "96": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "128": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "256": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "512": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 2
+ }
+ }
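
Note: the new JSON files above and below are Triton fused-MoE tuning tables. Each top-level key appears to be a token-batch size M, and the value is the kernel launch configuration (block sizes, group size, warps, pipeline stages) tuned for that M. The sketch below shows how such a table can be read and consulted; load_moe_config, pick_config, and the nearest-key lookup are illustrative assumptions, not sglang's exact get_moe_configs logic.

import json
from typing import Any, Dict

def load_moe_config(path: str) -> Dict[int, Dict[str, Any]]:
    # The JSON stores batch-size keys as strings; convert them to ints.
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}

def pick_config(configs: Dict[int, Dict[str, Any]], m: int) -> Dict[str, Any]:
    # Choose the tuned entry whose batch size is closest to the actual M
    # (a common fallback when M itself was not benchmarked).
    best_m = min(configs, key=lambda key: abs(key - m))
    return configs[best_m]

# Usage (hypothetical file name):
# configs = load_moe_config("E=128,N=192,device_name=NVIDIA_H20.json")
# kernel_args = pick_config(configs, m=40)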
@@ -1,102 +1,102 @@
  {
  "1": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 64,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
+ "GROUP_SIZE_M": 1,
  "num_warps": 4,
  "num_stages": 4
  },
  "2": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
+ "GROUP_SIZE_M": 16,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  },
  "4": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
+ "GROUP_SIZE_M": 16,
  "num_warps": 4,
  "num_stages": 4
  },
  "8": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
  "GROUP_SIZE_M": 32,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  },
  "16": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
+ "GROUP_SIZE_M": 1,
  "num_warps": 4,
  "num_stages": 3
  },
  "24": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
+ "GROUP_SIZE_M": 1,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  },
  "32": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 32,
+ "GROUP_SIZE_M": 16,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 5
  },
  "48": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 32,
+ "GROUP_SIZE_M": 64,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  },
  "64": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 64,
+ "GROUP_SIZE_M": 32,
  "num_warps": 4,
  "num_stages": 3
  },
  "96": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 64,
+ "GROUP_SIZE_M": 32,
  "num_warps": 4,
  "num_stages": 3
  },
  "128": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
+ "GROUP_SIZE_M": 64,
  "num_warps": 4,
  "num_stages": 3
  },
  "256": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
+ "GROUP_SIZE_M": 64,
  "num_warps": 4,
  "num_stages": 3
  },
  "512": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
  "GROUP_SIZE_M": 16,
@@ -107,9 +107,9 @@
  "BLOCK_SIZE_M": 64,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 32,
+ "GROUP_SIZE_M": 16,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  },
  "1536": {
  "BLOCK_SIZE_M": 64,
@@ -117,21 +117,21 @@
  "BLOCK_SIZE_K": 128,
  "GROUP_SIZE_M": 32,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  },
  "2048": {
  "BLOCK_SIZE_M": 64,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
+ "GROUP_SIZE_M": 32,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  },
  "3072": {
- "BLOCK_SIZE_M": 128,
- "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 32,
+ "GROUP_SIZE_M": 16,
  "num_warps": 4,
  "num_stages": 3
  },
@@ -139,8 +139,8 @@
  "BLOCK_SIZE_M": 64,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 64,
+ "GROUP_SIZE_M": 16,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  }
  }
@@ -0,0 +1,146 @@
+ {
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "64": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "96": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "128": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "256": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "512": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ }
+ }
@@ -13,7 +13,16 @@ import triton
  import triton.language as tl

  from sglang.srt.layers.moe.topk import select_experts
- from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
+ from sglang.srt.layers.quantization.fp8_kernel import (
+ per_token_group_quant_fp8,
+ scaled_fp8_quant,
+ sglang_per_token_group_quant_fp8,
+ )
+ from sglang.srt.layers.quantization.int8_kernel import (
+ per_token_group_quant_int8,
+ per_token_quant_int8,
+ sglang_per_token_group_quant_int8,
+ )
  from sglang.srt.utils import (
  direct_register_custom_op,
  get_bool_env_var,
@@ -36,7 +45,7 @@ if _is_cuda or _is_hip:


  logger = logging.getLogger(__name__)
- padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
+ padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
  enable_moe_align_block_size_triton = bool(
  int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
  )
@@ -746,18 +755,6 @@ def invoke_fused_moe_kernel(
  block_shape: Optional[List[int]] = None,
  no_combine: bool = False,
  ) -> None:
- from sglang.srt.layers.quantization.int8_kernel import (
- per_token_group_quant_int8,
- per_token_quant_int8,
- )
-
- if _is_cuda:
- from sglang.srt.layers.quantization.fp8_kernel import (
- sglang_per_token_group_quant_fp8,
- )
- else:
- from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-
  assert topk_weights.stride(1) == 1
  assert sorted_token_ids.stride(0) == 1

@@ -794,7 +791,10 @@ def invoke_fused_moe_kernel(
  # activation block-wise int8 quantization
  assert len(block_shape) == 2
  block_n, block_k = block_shape[0], block_shape[1]
- A, A_scale = per_token_group_quant_int8(A, block_k)
+ if _is_cuda:
+ A, A_scale = sglang_per_token_group_quant_int8(A, block_k)
+ else:
+ A, A_scale = per_token_group_quant_int8(A, block_k)
  assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
  assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
  assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
@@ -940,7 +940,10 @@ def get_moe_configs(
  )
  if os.path.exists(config_file_path):
  with open(config_file_path) as f:
- logger.info("Using configuration from %s for MoE layer.", config_file_path)
+ logger.info(
+ "Using configuration from %s for MoE layer. Please note that due to the large number of configs under fused_moe_triton/configs potentially not being tuned with the corresponding Triton version in your current environment, using the current configs may result in performance degradation. To achieve best performance, you can consider re-tuning the Triton fused MOE kernel in your current environment. For the tuning method, please refer to: https://github.com/sgl-project/sglang/blob/main/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py. ",
+ config_file_path,
+ )
  # If a configuration has been found, return it
  return {int(key): val for key, val in json.load(f).items()}

@@ -1324,7 +1327,7 @@ def fused_experts_impl(
  if (
  not (use_fp8_w8a8 or use_int8_w8a8)
  or block_shape is not None
- or (_is_hip and get_bool_env_var("CK_MOE"))
+ or (_is_hip and get_bool_env_var("SGLANG_AITER_MOE"))
  ):
  padded_size = 0

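Note: the hunk above lifts the quantization imports to module scope and, for block-wise int8 activation quantization, dispatches between sglang_per_token_group_quant_int8 (CUDA) and per_token_group_quant_int8 (fallback). For context, this style of quantization splits each token's channels into groups of block_k and stores one scale per group. The sketch below is only a reference-style illustration of that idea in plain PyTorch, not the Triton/CUDA kernels the diff selects between; per_token_group_quant_int8_ref is a made-up name.

import torch

def per_token_group_quant_int8_ref(x: torch.Tensor, group_size: int):
    # x: [num_tokens, hidden_dim]; hidden_dim must be divisible by group_size.
    num_tokens, hidden_dim = x.shape
    groups = x.view(num_tokens, hidden_dim // group_size, group_size)
    # One scale per (token, group), chosen so the max magnitude maps to 127.
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / 127.0
    q = torch.clamp(torch.round(groups / scale), -128, 127).to(torch.int8)
    return q.view(num_tokens, hidden_dim), scale.squeeze(-1)

# Usage:
# a_q, a_scale = per_token_group_quant_int8_ref(torch.randn(4, 256), group_size=128)
# a_scale has shape [4, 2]: one scale per 128-channel group of each token.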
@@ -18,7 +18,7 @@ from sglang.srt.layers.quantization.base_config import (
  QuantizationConfig,
  QuantizeMethodBase,
  )
- from sglang.srt.utils import get_bool_env_var, is_hip, permute_weight, set_weight_attrs
+ from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs

  if torch.cuda.is_available():
  from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -30,7 +30,9 @@ import logging
  _is_hip = is_hip()

  if _is_hip:
- from aiter import ck_moe
+ from aiter import ActivationType
+ from aiter.fused_moe_bf16_asm import ck_moe_2stages
+ from aiter.ops.shuffle import shuffle_weight

  logger = logging.getLogger(__name__)

@@ -102,14 +104,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  set_weight_attrs(w2_weight, extra_weight_attrs)

  def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
- if _is_hip and get_bool_env_var("CK_MOE"):
+ if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
  layer.w13_weight = torch.nn.Parameter(
- permute_weight(layer.w13_weight.data),
+ shuffle_weight(layer.w13_weight.data, (16, 16)),
  requires_grad=False,
  )
  torch.cuda.empty_cache()
  layer.w2_weight = torch.nn.Parameter(
- permute_weight(layer.w2_weight.data),
+ shuffle_weight(layer.w2_weight.data, (16, 16)),
  requires_grad=False,
  )
  torch.cuda.empty_cache()
@@ -182,21 +184,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  routed_scaling_factor=routed_scaling_factor,
  )

- if _is_hip and get_bool_env_var("CK_MOE"):
+ if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
  assert not no_combine, "unsupported"
- return ck_moe(
+ return ck_moe_2stages(
  x,
  layer.w13_weight,
  layer.w2_weight,
  topk_weights,
  topk_ids,
- None,
- None,
- None,
- None,
- 32,
- None,
- activation,
+ activation=(
+ ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+ ),
  )
  else:
  return fused_experts(
@@ -527,7 +525,7 @@ class FusedMoE(torch.nn.Module):
  # Case input scale: input_scale loading is only supported for fp8
  if "input_scale" in weight_name:
  # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
- if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+ if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
  loaded_weight = loaded_weight * 2.0

  # this is needed for compressed-tensors only
@@ -569,7 +567,7 @@ class FusedMoE(torch.nn.Module):
  quant_method = getattr(param, "quant_method", None)
  if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
  # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
- if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+ if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
  loaded_weight = loaded_weight * 0.5

  self._load_per_channel_weight_scale(
@@ -592,7 +590,7 @@ class FusedMoE(torch.nn.Module):
  )
  elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
  # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
- if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+ if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
  loaded_weight = loaded_weight * 2.0

  self._load_per_tensor_weight_scale(
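
Note: the hunks above rename the ROCm feature switches under an SGLANG_ prefix (CK_MOE becomes SGLANG_AITER_MOE, USE_INT4_WEIGHT becomes SGLANG_INT4_WEIGHT, and MOE_PADDING becomes SGLANG_MOE_PADDING in the earlier file); the code paths still hinge on a boolean environment-variable check. A minimal sketch of that gating pattern follows, with env_flag as a stand-in helper rather than sglang's own get_bool_env_var.

import os

def env_flag(name: str, default: str = "false") -> bool:
    # Treat "1"/"true"/"yes" (case-insensitive) as enabled; illustrative stand-in.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

USE_AITER_MOE = env_flag("SGLANG_AITER_MOE")

def run_moe(x):
    if USE_AITER_MOE:
        return "aiter 2-stage path"      # e.g. the ck_moe_2stages branch on ROCm
    return "triton fused_experts path"   # default fused MoE kernel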
@@ -12,6 +12,7 @@ from sglang.srt.model_executor.model_runner import ForwardBatch

  class PoolingType(IntEnum):
  LAST = 0
+ CLS = 1


  @dataclass
@@ -41,6 +42,11 @@ class Pooler(nn.Module):
  if self.pooling_type == PoolingType.LAST:
  last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
  pooled_data = hidden_states[last_token_indices]
+ elif self.pooling_type == PoolingType.CLS:
+ prompt_lens = forward_batch.extend_seq_lens
+ first_token_flat_indices = torch.zeros_like(prompt_lens)
+ first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
+ pooled_data = hidden_states[first_token_flat_indices]
  else:
  raise ValueError(f"Invalid pooling type: {self.pooling_type}")

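Note: the new CLS branch above selects the first token of every packed sequence from the flattened hidden_states by converting the per-request prompt lengths into start offsets. A small self-contained illustration of that index arithmetic (the tensors here are made up for demonstration):

import torch

# Three packed sequences of lengths 3, 2, 4 flattened into one [9, hidden] tensor.
prompt_lens = torch.tensor([3, 2, 4])
hidden_states = torch.arange(9).unsqueeze(-1).float()  # token i carries value i

# Same arithmetic as the CLS branch: start offset of each sequence.
first_token_flat_indices = torch.zeros_like(prompt_lens)
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
print(first_token_flat_indices)                             # tensor([0, 3, 5])
print(hidden_states[first_token_flat_indices].squeeze(-1))  # tensor([0., 3., 5.])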
@@ -3,7 +3,6 @@ import logging
  from typing import Any, Dict, List, Optional

  import torch
- from sgl_kernel import awq_dequantize

  from sglang.srt.layers.linear import (
  LinearBase,
@@ -12,6 +11,11 @@ from sglang.srt.layers.linear import (
  )
  from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.utils import is_cuda
+
+ _is_cuda = is_cuda()
+ if _is_cuda:
+ from sgl_kernel import awq_dequantize

  logger = logging.getLogger(__name__)