sglang 0.4.5__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (121)
  1. sglang/bench_one_batch.py +21 -0
  2. sglang/bench_serving.py +10 -4
  3. sglang/srt/configs/model_config.py +37 -5
  4. sglang/srt/constrained/base_grammar_backend.py +26 -5
  5. sglang/srt/constrained/llguidance_backend.py +1 -0
  6. sglang/srt/constrained/outlines_backend.py +1 -0
  7. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  8. sglang/srt/constrained/xgrammar_backend.py +1 -0
  9. sglang/srt/disaggregation/base/__init__.py +8 -0
  10. sglang/srt/disaggregation/base/conn.py +113 -0
  11. sglang/srt/disaggregation/decode.py +18 -5
  12. sglang/srt/disaggregation/mini_lb.py +53 -122
  13. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  14. sglang/srt/disaggregation/mooncake/conn.py +615 -0
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
  16. sglang/srt/disaggregation/prefill.py +43 -19
  17. sglang/srt/disaggregation/utils.py +31 -0
  18. sglang/srt/entrypoints/EngineBase.py +53 -0
  19. sglang/srt/entrypoints/engine.py +36 -8
  20. sglang/srt/entrypoints/http_server.py +37 -8
  21. sglang/srt/entrypoints/http_server_engine.py +142 -0
  22. sglang/srt/entrypoints/verl_engine.py +37 -10
  23. sglang/srt/hf_transformers_utils.py +4 -0
  24. sglang/srt/layers/attention/flashattention_backend.py +330 -200
  25. sglang/srt/layers/attention/flashinfer_backend.py +13 -7
  26. sglang/srt/layers/attention/vision.py +1 -1
  27. sglang/srt/layers/dp_attention.py +2 -4
  28. sglang/srt/layers/elementwise.py +15 -2
  29. sglang/srt/layers/linear.py +1 -0
  30. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +38 -21
  38. sglang/srt/layers/moe/router.py +7 -1
  39. sglang/srt/layers/moe/topk.py +37 -16
  40. sglang/srt/layers/quantization/__init__.py +12 -5
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
  42. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
  43. sglang/srt/layers/quantization/fp8.py +25 -13
  44. sglang/srt/layers/quantization/fp8_kernel.py +130 -4
  45. sglang/srt/layers/quantization/fp8_utils.py +34 -6
  46. sglang/srt/layers/quantization/kv_cache.py +43 -52
  47. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  48. sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
  49. sglang/srt/layers/quantization/w8a8_int8.py +1 -0
  50. sglang/srt/layers/radix_attention.py +13 -1
  51. sglang/srt/layers/rotary_embedding.py +12 -1
  52. sglang/srt/managers/io_struct.py +254 -97
  53. sglang/srt/managers/mm_utils.py +3 -2
  54. sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
  55. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  56. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  57. sglang/srt/managers/schedule_batch.py +62 -21
  58. sglang/srt/managers/scheduler.py +71 -14
  59. sglang/srt/managers/tokenizer_manager.py +17 -3
  60. sglang/srt/managers/tp_worker.py +1 -0
  61. sglang/srt/mem_cache/memory_pool.py +14 -1
  62. sglang/srt/metrics/collector.py +9 -0
  63. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  64. sglang/srt/model_executor/forward_batch_info.py +234 -15
  65. sglang/srt/model_executor/model_runner.py +48 -9
  66. sglang/srt/model_loader/loader.py +31 -4
  67. sglang/srt/model_loader/weight_utils.py +4 -2
  68. sglang/srt/models/baichuan.py +2 -0
  69. sglang/srt/models/chatglm.py +1 -0
  70. sglang/srt/models/commandr.py +1 -0
  71. sglang/srt/models/dbrx.py +1 -0
  72. sglang/srt/models/deepseek.py +1 -0
  73. sglang/srt/models/deepseek_v2.py +248 -61
  74. sglang/srt/models/exaone.py +1 -0
  75. sglang/srt/models/gemma.py +1 -0
  76. sglang/srt/models/gemma2.py +1 -0
  77. sglang/srt/models/gemma3_causal.py +1 -0
  78. sglang/srt/models/gpt2.py +1 -0
  79. sglang/srt/models/gpt_bigcode.py +1 -0
  80. sglang/srt/models/granite.py +1 -0
  81. sglang/srt/models/grok.py +1 -0
  82. sglang/srt/models/internlm2.py +1 -0
  83. sglang/srt/models/llama.py +1 -0
  84. sglang/srt/models/llama4.py +101 -34
  85. sglang/srt/models/minicpm.py +1 -0
  86. sglang/srt/models/minicpm3.py +2 -0
  87. sglang/srt/models/mixtral.py +1 -0
  88. sglang/srt/models/mixtral_quant.py +1 -0
  89. sglang/srt/models/mllama.py +51 -8
  90. sglang/srt/models/mllama4.py +102 -29
  91. sglang/srt/models/olmo.py +1 -0
  92. sglang/srt/models/olmo2.py +1 -0
  93. sglang/srt/models/olmoe.py +1 -0
  94. sglang/srt/models/phi3_small.py +1 -0
  95. sglang/srt/models/qwen.py +1 -0
  96. sglang/srt/models/qwen2.py +1 -0
  97. sglang/srt/models/qwen2_5_vl.py +35 -70
  98. sglang/srt/models/qwen2_moe.py +1 -0
  99. sglang/srt/models/qwen2_vl.py +27 -25
  100. sglang/srt/models/stablelm.py +1 -0
  101. sglang/srt/models/xverse.py +1 -0
  102. sglang/srt/models/xverse_moe.py +1 -0
  103. sglang/srt/openai_api/adapter.py +4 -1
  104. sglang/srt/patch_torch.py +11 -0
  105. sglang/srt/server_args.py +34 -0
  106. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  107. sglang/srt/speculative/eagle_utils.py +1 -11
  108. sglang/srt/speculative/eagle_worker.py +6 -2
  109. sglang/srt/utils.py +120 -9
  110. sglang/test/attention/test_flashattn_backend.py +259 -221
  111. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  112. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  113. sglang/test/test_block_fp8.py +57 -0
  114. sglang/test/test_utils.py +19 -8
  115. sglang/version.py +1 -1
  116. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
  117. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +120 -106
  118. sglang/srt/disaggregation/conn.py +0 -81
  119. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
  120. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
  121. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
+{
+  "1": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "2": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "4": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "8": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 32,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "16": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "24": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "32": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "48": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "64": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 32,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "96": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "128": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 32,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "256": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "512": {
+    "BLOCK_SIZE_M": 32,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "1024": {
+    "BLOCK_SIZE_M": 32,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 2
+  },
+  "1536": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "2048": {
+    "BLOCK_SIZE_M": 32,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 2
+  },
+  "3072": {
+    "BLOCK_SIZE_M": 128,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "4096": {
+    "BLOCK_SIZE_M": 128,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 2
+  }
+}
@@ -0,0 +1,146 @@
+{
+  "1": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "2": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "4": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "8": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "16": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "24": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "32": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "48": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "64": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "96": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "128": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 32,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "256": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "512": {
+    "BLOCK_SIZE_M": 32,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "1024": {
+    "BLOCK_SIZE_M": 32,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 2
+  },
+  "1536": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "2048": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "3072": {
+    "BLOCK_SIZE_M": 128,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "4096": {
+    "BLOCK_SIZE_M": 128,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 2
+  }
+}
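Each of the two new JSON files above maps a token count ("1" through "4096") to a tuned Triton launch configuration (BLOCK_SIZE_M/N/K, GROUP_SIZE_M, num_warps, num_stages) for the fused MoE kernel. As a rough illustration, a lookup over such a dict can pick the entry whose key is closest to the actual number of tokens M; this is only a sketch with hypothetical helper names, not sglang's get_moe_configs implementation:

import json
from typing import Any, Dict

def load_moe_config(path: str) -> Dict[int, Dict[str, Any]]:
    # Keys are stored as strings in the JSON file, so convert them to ints.
    with open(path) as f:
        return {int(num_tokens): cfg for num_tokens, cfg in json.load(f).items()}

def pick_kernel_config(configs: Dict[int, Dict[str, Any]], m: int) -> Dict[str, Any]:
    # Use the tuned entry whose token-count bucket is closest to the actual M.
    return configs[min(configs, key=lambda k: abs(k - m))]

# Example: m=1000 lands in the "1024" bucket of either file above.
# cfg = pick_kernel_config(load_moe_config("E=...,N=...,device_name=....json"), 1000)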
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -342,6 +342,7 @@ def fused_moe_kernel(
     use_fp8_w8a8: tl.constexpr,
     use_int8_w8a8: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
+    per_channel_quant: tl.constexpr,
     even_Ks: tl.constexpr,
 ):
     """
@@ -416,20 +417,7 @@ def fused_moe_kernel(
         )
         b_scale = tl.load(b_scale_ptrs)

-    if use_fp8_w8a8:
-        # block-wise
-        if group_k > 0 and group_n > 0:
-            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
-            offs_bsn = offs_bn // group_n
-            b_scale_ptrs = (
-                b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
-            )
-        # tensor-wise
-        else:
-            a_scale = tl.load(a_scale_ptr)
-            b_scale = tl.load(b_scale_ptr + off_experts)
-
-    if use_int8_w8a8:
+    if use_fp8_w8a8 or use_int8_w8a8:
         # block-wise
         if group_k > 0 and group_n > 0:
             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
@@ -438,8 +426,7 @@ def fused_moe_kernel(
                 b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
             )
         # channel-wise
-        else:
-            # Load per-column scale for weights
+        elif per_channel_quant:
             b_scale_ptrs = (
                 b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
             )
@@ -447,6 +434,10 @@ def fused_moe_kernel(
             # Load per-token scale for activations
             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
             a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
+        # tensor-wise
+        else:
+            a_scale = tl.load(a_scale_ptr)
+            b_scale = tl.load(b_scale_ptr + off_experts)

     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
@@ -711,12 +702,12 @@ def moe_align_block_size(
             num_tokens_post_pad,
         )
     else:
-        token_cnts_buffer = torch.zeros(
+        token_cnts_buffer = torch.empty(
             (num_experts + 1) * num_experts,
             dtype=torch.int32,
             device=topk_ids.device,
         )
-        cumsum_buffer = torch.zeros(
+        cumsum_buffer = torch.empty(
             num_experts + 1, dtype=torch.int32, device=topk_ids.device
         )

@@ -753,6 +744,7 @@ def invoke_fused_moe_kernel(
     use_int8_w8a8: bool,
     use_int8_w8a16: bool,
     use_int4_w4a16: bool,
+    per_channel_quant: bool,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
 ) -> None:
@@ -765,6 +757,8 @@ def invoke_fused_moe_kernel(
         from sglang.srt.layers.quantization.fp8_kernel import (
             sglang_per_token_group_quant_fp8,
         )
+    else:
+        from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8

     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -775,10 +769,15 @@ def invoke_fused_moe_kernel(
         if block_shape is None:
             # activation tensor-wise fp8 quantization, dynamic or static
             padded_size = padding_size
+            # activations apply per-token quantization when weights apply per-channel quantization by default
             if _is_cuda:
-                A, A_scale = sgl_scaled_fp8_quant(A, A_scale)
+                A, A_scale = sgl_scaled_fp8_quant(
+                    A, A_scale, use_per_token_if_dynamic=per_channel_quant
+                )
             else:
-                A, A_scale = vllm_ops.scaled_fp8_quant(A, A_scale)
+                A, A_scale = vllm_ops.scaled_fp8_quant(
+                    A, A_scale, use_per_token_if_dynamic=per_channel_quant
+                )
         else:
             # activation block-wise fp8 quantization
             assert len(block_shape) == 2
@@ -794,6 +793,9 @@ def invoke_fused_moe_kernel(
         assert B_scale is not None
         if block_shape is None:
             # activation channel-wise int8 quantization
+            assert (
+                per_channel_quant
+            ), "int8 quantization only supports channel-wise quantization except for block-wise quantization"
             A, A_scale = per_token_quant_int8(A)
         else:
             # activation block-wise int8 quantization
@@ -902,6 +904,7 @@ def invoke_fused_moe_kernel(
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
+           per_channel_quant=per_channel_quant,
            even_Ks=even_Ks,
            **config,
        )
@@ -953,7 +956,7 @@ def get_moe_configs(
     logger.warning(
         (
             "Using default MoE config. Performance might be sub-optimal! "
-            "Config file not found at %s"
+            "Config file not found at %s, you can tune the config with https://github.com/sgl-project/sglang/blob/main/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py."
         ),
         config_file_path,
     )
@@ -1084,6 +1087,7 @@ def inplace_fused_experts(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1105,6 +1109,7 @@ def inplace_fused_experts(
         use_int8_w8a8,
         use_int8_w8a16,
         use_int4_w4a16,
+        per_channel_quant,
         w1_scale,
         w2_scale,
         w1_zp,
@@ -1127,6 +1132,7 @@ def inplace_fused_experts_fake(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1158,6 +1164,7 @@ def outplace_fused_experts(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1180,6 +1187,7 @@ def outplace_fused_experts(
         use_int8_w8a8,
         use_int8_w8a16,
         use_int4_w4a16,
+        per_channel_quant,
         w1_scale,
         w2_scale,
         w1_zp,
@@ -1203,6 +1211,7 @@ def outplace_fused_experts_fake(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1236,6 +1245,7 @@ def fused_experts(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1259,6 +1269,7 @@ def fused_experts(
            use_int8_w8a8,
            use_int8_w8a16,
            use_int4_w4a16,
+           per_channel_quant,
            w1_scale,
            w2_scale,
            w1_zp,
@@ -1281,6 +1292,7 @@ def fused_experts(
            use_int8_w8a8,
            use_int8_w8a16,
            use_int4_w4a16,
+           per_channel_quant,
            w1_scale,
            w2_scale,
            w1_zp,
@@ -1305,6 +1317,7 @@ def fused_experts_impl(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1441,6 +1454,7 @@ def fused_experts_impl(
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
+           per_channel_quant=per_channel_quant,
            block_shape=block_shape,
        )
        if activation == "silu":
@@ -1484,6 +1498,7 @@ def fused_experts_impl(
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
+           per_channel_quant=per_channel_quant,
            block_shape=block_shape,
        )

@@ -1530,6 +1545,7 @@ def fused_moe(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1606,6 +1622,7 @@ def fused_moe(
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         use_int4_w4a16=use_int4_w4a16,
+        per_channel_quant=per_channel_quant,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
         w1_zp=w1_zp,
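Taken together, the fused_moe.py hunks above thread a new per_channel_quant flag from the public fused_moe / fused_experts entry points down to the Triton kernel, pairing per-channel weight scales with per-token activation quantization instead of tensor-wise scales. A hedged usage sketch follows; the tensor shapes, scale layouts, and the leading positional arguments are assumptions based on the surrounding signatures rather than facts taken from this diff:

import torch
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

# Illustrative sizes only: 64 tokens, hidden 1024, intermediate 2048, 8 experts, top-2 routing.
num_experts, hidden, inter, topk = 8, 1024, 2048, 2
x = torch.randn(64, hidden, device="cuda", dtype=torch.bfloat16)
gating = torch.randn(64, num_experts, device="cuda", dtype=torch.float32)

# int8 weights with one scale per expert and output channel (assumed to be precomputed offline).
w1 = torch.randint(-128, 127, (num_experts, 2 * inter, hidden), device="cuda", dtype=torch.int8)
w2 = torch.randint(-128, 127, (num_experts, hidden, inter), device="cuda", dtype=torch.int8)
w1_scale = torch.rand(num_experts, 2 * inter, device="cuda", dtype=torch.float32)
w2_scale = torch.rand(num_experts, hidden, device="cuda", dtype=torch.float32)

out = fused_moe(
    x, w1, w2, gating, topk, True,  # assumed order: hidden_states, w1, w2, gating_output, topk, renormalize
    use_int8_w8a8=True,             # int8 weight / int8 activation path (flag shown in the hunks above)
    per_channel_quant=True,         # new flag: per-channel weight scales, per-token activation scales
    w1_scale=w1_scale,
    w2_scale=w2_scale,
)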
sglang/srt/layers/moe/router.py
@@ -5,6 +5,9 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.moe.topk import fused_topk
+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()


 @triton.jit
@@ -116,10 +119,13 @@ def fused_moe_router_impl(
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)

     grid = lambda meta: (bs,)
+
+    min_num_warps = 16 if _is_hip else 32
+
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
         ),
     }

sglang/srt/layers/moe/topk.py
@@ -12,6 +12,7 @@
 # limitations under the License.
 # ==============================================================================

+import math
 import os
 from typing import Callable, Optional

@@ -25,6 +26,8 @@ from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
 _is_cuda = is_cuda()
 _is_hip = is_hip()

+if _is_cuda:
+    from sgl_kernel import moe_fused_gate

 expert_distribution_recorder = ExpertDistributionRecorder()

@@ -209,6 +212,10 @@ def biased_grouped_topk_impl(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


+def is_power_of_two(n):
+    return n > 0 and math.log2(n).is_integer()
+
+
 def biased_grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -220,23 +227,37 @@ def biased_grouped_topk(
     compiled: bool = True,
     n_share_experts_fusion: int = 0,
 ):
-    biased_grouped_topk_fn = (
-        torch.compile(
-            biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
+    # TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now.
+    if (
+        _is_cuda
+        and n_share_experts_fusion == 0
+        and is_power_of_two(correction_bias.shape[0])
+    ):
+        return moe_fused_gate(
+            gating_output,
+            correction_bias,
+            num_expert_group,
+            topk_group,
+            topk,
+        )
+    else:
+        biased_grouped_topk_fn = (
+            torch.compile(
+                biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
+            )
+            if compiled
+            else biased_grouped_topk_impl
+        )
+        return biased_grouped_topk_fn(
+            hidden_states,
+            gating_output,
+            correction_bias,
+            topk,
+            renormalize,
+            num_expert_group,
+            topk_group,
+            n_share_experts_fusion=n_share_experts_fusion,
         )
-        if compiled
-        else biased_grouped_topk_impl
-    )
-    return biased_grouped_topk_fn(
-        hidden_states,
-        gating_output,
-        correction_bias,
-        topk,
-        renormalize,
-        num_expert_group,
-        topk_group,
-        n_share_experts_fusion=n_share_experts_fusion,
-    )


 def select_experts(
sglang/srt/layers/quantization/__init__.py
@@ -59,20 +59,20 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
-from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
+from sglang.srt.layers.quantization.modelopt_quant import (
+    ModelOptFp4Config,
+    ModelOptFp8Config,
+)
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
-from sglang.srt.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    UnquantizedEmbeddingMethod,
-)

 # Base quantization methods that don't depend on vllm
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
     "blockwise_int8": BlockInt8Config,
     "modelopt": ModelOptFp8Config,
+    "modelopt_fp4": ModelOptFp4Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
     "moe_wna16": MoeWNA16Config,
@@ -176,6 +176,13 @@ def get_linear_quant_method(
     prefix: str,
     linear_method_cls: type,
 ):
+    # Move import here to avoid circular import. This is only used in monkey patching
+    # of vllm's QuantizationConfig.
+    from sglang.srt.layers.vocab_parallel_embedding import (
+        ParallelLMHead,
+        UnquantizedEmbeddingMethod,
+    )
+
     cloned_config = deepcopy(config)
     parallel_lm_head_quantized = (
         isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -77,6 +77,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         sparsity_ignore_list: List[str],
         kv_cache_scheme: Optional[Dict[str, Any]] = None,
         config: Optional[Dict[str, Any]] = None,
+        packed_modules_mapping: Dict[str, List[str]] = {},
     ):
         super().__init__()
         self.ignore = ignore
@@ -87,6 +88,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         self.sparsity_scheme_map = sparsity_scheme_map
         self.sparsity_ignore_list = sparsity_ignore_list
         self.config = config
+        self.packed_modules_mapping = packed_modules_mapping

     def get_linear_method(self) -> "CompressedTensorsLinearMethod":
         return CompressedTensorsLinearMethod(self)
@@ -136,6 +138,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
             config=config
         )
+        packed_modules_mapping = config.get("packed_modules_mapping", {})

         return cls(
             target_scheme_map=target_scheme_map,
@@ -144,6 +147,7 @@ class CompressedTensorsConfig(QuantizationConfig):
             sparsity_scheme_map=sparsity_scheme_map,
             sparsity_ignore_list=sparsity_ignore_list,
             config=config,
+            packed_modules_mapping=packed_modules_mapping,
         )

     @classmethod