sglang 0.4.4.post4__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registries. It is provided for informational purposes only.
Files changed (134)
  1. sglang/bench_one_batch.py +21 -0
  2. sglang/bench_serving.py +10 -4
  3. sglang/lang/chat_template.py +24 -0
  4. sglang/srt/configs/model_config.py +40 -4
  5. sglang/srt/constrained/base_grammar_backend.py +26 -5
  6. sglang/srt/constrained/llguidance_backend.py +1 -0
  7. sglang/srt/constrained/outlines_backend.py +1 -0
  8. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  9. sglang/srt/constrained/xgrammar_backend.py +1 -0
  10. sglang/srt/conversation.py +29 -4
  11. sglang/srt/disaggregation/base/__init__.py +8 -0
  12. sglang/srt/disaggregation/base/conn.py +113 -0
  13. sglang/srt/disaggregation/decode.py +18 -5
  14. sglang/srt/disaggregation/mini_lb.py +53 -122
  15. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  16. sglang/srt/disaggregation/mooncake/conn.py +615 -0
  17. sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
  18. sglang/srt/disaggregation/prefill.py +43 -19
  19. sglang/srt/disaggregation/utils.py +31 -0
  20. sglang/srt/entrypoints/EngineBase.py +53 -0
  21. sglang/srt/entrypoints/engine.py +36 -8
  22. sglang/srt/entrypoints/http_server.py +37 -8
  23. sglang/srt/entrypoints/http_server_engine.py +142 -0
  24. sglang/srt/entrypoints/verl_engine.py +37 -10
  25. sglang/srt/hf_transformers_utils.py +4 -0
  26. sglang/srt/layers/attention/flashattention_backend.py +609 -202
  27. sglang/srt/layers/attention/flashinfer_backend.py +13 -7
  28. sglang/srt/layers/attention/vision.py +1 -1
  29. sglang/srt/layers/dp_attention.py +2 -4
  30. sglang/srt/layers/elementwise.py +15 -2
  31. sglang/srt/layers/linear.py +1 -0
  32. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  33. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  49. sglang/srt/layers/moe/router.py +7 -1
  50. sglang/srt/layers/moe/topk.py +37 -16
  51. sglang/srt/layers/quantization/__init__.py +13 -5
  52. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  53. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
  54. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
  55. sglang/srt/layers/quantization/fp8.py +28 -14
  56. sglang/srt/layers/quantization/fp8_kernel.py +130 -4
  57. sglang/srt/layers/quantization/fp8_utils.py +34 -6
  58. sglang/srt/layers/quantization/kv_cache.py +43 -52
  59. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
  62. sglang/srt/layers/quantization/w8a8_int8.py +3 -0
  63. sglang/srt/layers/radix_attention.py +14 -0
  64. sglang/srt/layers/rotary_embedding.py +75 -1
  65. sglang/srt/managers/io_struct.py +254 -97
  66. sglang/srt/managers/mm_utils.py +3 -2
  67. sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
  68. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  69. sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
  70. sglang/srt/managers/schedule_batch.py +62 -21
  71. sglang/srt/managers/scheduler.py +71 -14
  72. sglang/srt/managers/tokenizer_manager.py +17 -3
  73. sglang/srt/managers/tp_worker.py +1 -0
  74. sglang/srt/mem_cache/memory_pool.py +14 -1
  75. sglang/srt/metrics/collector.py +9 -0
  76. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  77. sglang/srt/model_executor/forward_batch_info.py +234 -15
  78. sglang/srt/model_executor/model_runner.py +49 -9
  79. sglang/srt/model_loader/loader.py +31 -4
  80. sglang/srt/model_loader/weight_utils.py +4 -2
  81. sglang/srt/models/baichuan.py +2 -0
  82. sglang/srt/models/chatglm.py +1 -0
  83. sglang/srt/models/commandr.py +1 -0
  84. sglang/srt/models/dbrx.py +1 -0
  85. sglang/srt/models/deepseek.py +1 -0
  86. sglang/srt/models/deepseek_v2.py +248 -61
  87. sglang/srt/models/exaone.py +1 -0
  88. sglang/srt/models/gemma.py +1 -0
  89. sglang/srt/models/gemma2.py +1 -0
  90. sglang/srt/models/gemma3_causal.py +1 -0
  91. sglang/srt/models/gpt2.py +1 -0
  92. sglang/srt/models/gpt_bigcode.py +1 -0
  93. sglang/srt/models/granite.py +1 -0
  94. sglang/srt/models/grok.py +1 -0
  95. sglang/srt/models/internlm2.py +1 -0
  96. sglang/srt/models/llama.py +13 -4
  97. sglang/srt/models/llama4.py +487 -0
  98. sglang/srt/models/minicpm.py +1 -0
  99. sglang/srt/models/minicpm3.py +2 -0
  100. sglang/srt/models/mixtral.py +1 -0
  101. sglang/srt/models/mixtral_quant.py +1 -0
  102. sglang/srt/models/mllama.py +51 -8
  103. sglang/srt/models/mllama4.py +227 -0
  104. sglang/srt/models/olmo.py +1 -0
  105. sglang/srt/models/olmo2.py +1 -0
  106. sglang/srt/models/olmoe.py +1 -0
  107. sglang/srt/models/phi3_small.py +1 -0
  108. sglang/srt/models/qwen.py +1 -0
  109. sglang/srt/models/qwen2.py +1 -0
  110. sglang/srt/models/qwen2_5_vl.py +35 -70
  111. sglang/srt/models/qwen2_moe.py +1 -0
  112. sglang/srt/models/qwen2_vl.py +27 -25
  113. sglang/srt/models/stablelm.py +1 -0
  114. sglang/srt/models/xverse.py +1 -0
  115. sglang/srt/models/xverse_moe.py +1 -0
  116. sglang/srt/openai_api/adapter.py +4 -1
  117. sglang/srt/patch_torch.py +11 -0
  118. sglang/srt/server_args.py +34 -0
  119. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  120. sglang/srt/speculative/eagle_utils.py +1 -11
  121. sglang/srt/speculative/eagle_worker.py +6 -2
  122. sglang/srt/utils.py +120 -9
  123. sglang/test/attention/test_flashattn_backend.py +259 -221
  124. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  125. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  126. sglang/test/test_block_fp8.py +57 -0
  127. sglang/test/test_utils.py +19 -8
  128. sglang/version.py +1 -1
  129. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
  130. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
  131. sglang/srt/disaggregation/conn.py +0 -81
  132. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
  133. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
  134. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
+ {
+   "1": {
+     "BLOCK_SIZE_M": 16,
+     "BLOCK_SIZE_N": 64,
+     "BLOCK_SIZE_K": 64,
+     "GROUP_SIZE_M": 64,
+     "num_warps": 4,
+     "num_stages": 3
+   },
+   "2": {
+     "BLOCK_SIZE_M": 16,
+     "BLOCK_SIZE_N": 64,
+     "BLOCK_SIZE_K": 64,
+     "GROUP_SIZE_M": 64,
+     "num_warps": 4,
+     "num_stages": 4
+   },
+   "4": {
+     "BLOCK_SIZE_M": 16,
+     "BLOCK_SIZE_N": 64,
+     "BLOCK_SIZE_K": 64,
+     "GROUP_SIZE_M": 64,
+     "num_warps": 4,
+     "num_stages": 4
+   },
+   "8": {
+     "BLOCK_SIZE_M": 16,
+     "BLOCK_SIZE_N": 64,
+     "BLOCK_SIZE_K": 64,
+     "GROUP_SIZE_M": 16,
+     "num_warps": 4,
+     "num_stages": 3
+   },
+   "16": {
+     "BLOCK_SIZE_M": 16,
+     "BLOCK_SIZE_N": 64,
+     "BLOCK_SIZE_K": 64,
+     "GROUP_SIZE_M": 16,
+     "num_warps": 4,
+     "num_stages": 3
+   },
+   "24": {
+     "BLOCK_SIZE_M": 16,
+     "BLOCK_SIZE_N": 64,
+     "BLOCK_SIZE_K": 64,
+     "GROUP_SIZE_M": 1,
+     "num_warps": 4,
+     "num_stages": 3
+   },
+   "32": {
+     "BLOCK_SIZE_M": 16,
+     "BLOCK_SIZE_N": 64,
+     "BLOCK_SIZE_K": 64,
+     "GROUP_SIZE_M": 16,
+     "num_warps": 4,
+     "num_stages": 3
+   },
+   "48": {
+     "BLOCK_SIZE_M": 16,
+     "BLOCK_SIZE_N": 64,
+     "BLOCK_SIZE_K": 64,
+     "GROUP_SIZE_M": 1,
+     "num_warps": 4,
+     "num_stages": 3
+   },
+   "64": {
+     "BLOCK_SIZE_M": 16,
+     "BLOCK_SIZE_N": 128,
+     "BLOCK_SIZE_K": 64,
+     "GROUP_SIZE_M": 64,
+     "num_warps": 4,
+     "num_stages": 3
+   },
+   "96": {
+     "BLOCK_SIZE_M": 16,
+     "BLOCK_SIZE_N": 128,
+     "BLOCK_SIZE_K": 64,
+     "GROUP_SIZE_M": 16,
+     "num_warps": 4,
+     "num_stages": 3
+   },
+   "128": {
+     "BLOCK_SIZE_M": 16,
+     "BLOCK_SIZE_N": 128,
+     "BLOCK_SIZE_K": 64,
+     "GROUP_SIZE_M": 32,
+     "num_warps": 4,
+     "num_stages": 3
+   },
+   "256": {
+     "BLOCK_SIZE_M": 16,
+     "BLOCK_SIZE_N": 128,
+     "BLOCK_SIZE_K": 64,
+     "GROUP_SIZE_M": 64,
+     "num_warps": 4,
+     "num_stages": 3
+   },
+   "512": {
+     "BLOCK_SIZE_M": 32,
+     "BLOCK_SIZE_N": 128,
+     "BLOCK_SIZE_K": 64,
+     "GROUP_SIZE_M": 64,
+     "num_warps": 4,
+     "num_stages": 3
+   },
+   "1024": {
+     "BLOCK_SIZE_M": 32,
+     "BLOCK_SIZE_N": 128,
+     "BLOCK_SIZE_K": 64,
+     "GROUP_SIZE_M": 64,
+     "num_warps": 4,
+     "num_stages": 2
+   },
+   "1536": {
+     "BLOCK_SIZE_M": 64,
+     "BLOCK_SIZE_N": 128,
+     "BLOCK_SIZE_K": 64,
+     "GROUP_SIZE_M": 64,
+     "num_warps": 4,
+     "num_stages": 3
+   },
+   "2048": {
+     "BLOCK_SIZE_M": 64,
+     "BLOCK_SIZE_N": 128,
+     "BLOCK_SIZE_K": 64,
+     "GROUP_SIZE_M": 64,
+     "num_warps": 4,
+     "num_stages": 4
+   },
+   "3072": {
+     "BLOCK_SIZE_M": 128,
+     "BLOCK_SIZE_N": 128,
+     "BLOCK_SIZE_K": 64,
+     "GROUP_SIZE_M": 64,
+     "num_warps": 4,
+     "num_stages": 3
+   },
+   "4096": {
+     "BLOCK_SIZE_M": 128,
+     "BLOCK_SIZE_N": 64,
+     "BLOCK_SIZE_K": 64,
+     "GROUP_SIZE_M": 16,
+     "num_warps": 4,
+     "num_stages": 2
+   }
+ }
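Each added config maps a benchmarked token count (M) to a Triton tile configuration. At runtime the entry tuned for the batch size closest to the actual one is used; a minimal sketch of that lookup, assuming the usual closest-key rule (pick_config is an illustrative helper, not part of sglang):

import json

def pick_config(config_path: str, num_tokens: int) -> dict:
    # Keys are benchmarked token counts (M); values are Triton tile configs.
    with open(config_path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    # Choose the entry tuned for the closest token count.
    return configs[min(configs, key=lambda m: abs(m - num_tokens))]

# e.g. num_tokens=100 would select the "96" entry above.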
@@ -342,6 +342,7 @@ def fused_moe_kernel(
  use_fp8_w8a8: tl.constexpr,
  use_int8_w8a8: tl.constexpr,
  use_int8_w8a16: tl.constexpr,
+ per_channel_quant: tl.constexpr,
  even_Ks: tl.constexpr,
  ):
  """
@@ -416,20 +417,7 @@ def fused_moe_kernel(
  )
  b_scale = tl.load(b_scale_ptrs)

- if use_fp8_w8a8:
- # block-wise
- if group_k > 0 and group_n > 0:
- a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
- offs_bsn = offs_bn // group_n
- b_scale_ptrs = (
- b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
- )
- # tensor-wise
- else:
- a_scale = tl.load(a_scale_ptr)
- b_scale = tl.load(b_scale_ptr + off_experts)
-
- if use_int8_w8a8:
+ if use_fp8_w8a8 or use_int8_w8a8:
  # block-wise
  if group_k > 0 and group_n > 0:
  a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
@@ -438,8 +426,7 @@ def fused_moe_kernel(
  b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
  )
  # channel-wise
- else:
- # Load per-column scale for weights
+ elif per_channel_quant:
  b_scale_ptrs = (
  b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
  )
@@ -447,6 +434,10 @@ def fused_moe_kernel(
  # Load per-token scale for activations
  a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
  a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
+ # tensor-wise
+ else:
+ a_scale = tl.load(a_scale_ptr)
+ b_scale = tl.load(b_scale_ptr + off_experts)

  # -----------------------------------------------------------
  # Iterate to compute a block of the C matrix.
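After this restructuring, fp8 and int8 w8a8 share one scale-selection path: block-wise scales when a block shape is configured, per-channel weight scales (with per-token activation scales) when per_channel_quant is set, and a single tensor-wise scale otherwise. A host-side sketch of that branch order, with an illustrative name (select_scale_mode is not part of the package):

def select_scale_mode(use_fp8_w8a8: bool, use_int8_w8a8: bool,
                      group_k: int, group_n: int,
                      per_channel_quant: bool) -> str:
    # Mirrors the branch order in fused_moe_kernel after this change.
    if not (use_fp8_w8a8 or use_int8_w8a8):
        return "no w8a8 scaling"
    if group_k > 0 and group_n > 0:
        return "block-wise"    # scales per [group_k, group_n] weight block
    if per_channel_quant:
        return "channel-wise"  # per-column weight scale, per-token activation scale
    return "tensor-wise"       # one scale per tensor (per expert for weights)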
@@ -711,12 +702,12 @@ def moe_align_block_size(
  num_tokens_post_pad,
  )
  else:
- token_cnts_buffer = torch.zeros(
+ token_cnts_buffer = torch.empty(
  (num_experts + 1) * num_experts,
  dtype=torch.int32,
  device=topk_ids.device,
  )
- cumsum_buffer = torch.zeros(
+ cumsum_buffer = torch.empty(
  num_experts + 1, dtype=torch.int32, device=topk_ids.device
  )

@@ -753,6 +744,7 @@ def invoke_fused_moe_kernel(
  use_int8_w8a8: bool,
  use_int8_w8a16: bool,
  use_int4_w4a16: bool,
+ per_channel_quant: bool,
  block_shape: Optional[List[int]] = None,
  no_combine: bool = False,
  ) -> None:
@@ -765,6 +757,8 @@ def invoke_fused_moe_kernel(
  from sglang.srt.layers.quantization.fp8_kernel import (
  sglang_per_token_group_quant_fp8,
  )
+ else:
+ from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8

  assert topk_weights.stride(1) == 1
  assert sorted_token_ids.stride(0) == 1
@@ -775,10 +769,15 @@ def invoke_fused_moe_kernel(
  if block_shape is None:
  # activation tensor-wise fp8 quantization, dynamic or static
  padded_size = padding_size
+ # activations apply per-token quantization when weights apply per-channel quantization by default
  if _is_cuda:
- A, A_scale = sgl_scaled_fp8_quant(A, A_scale)
+ A, A_scale = sgl_scaled_fp8_quant(
+ A, A_scale, use_per_token_if_dynamic=per_channel_quant
+ )
  else:
- A, A_scale = vllm_ops.scaled_fp8_quant(A, A_scale)
+ A, A_scale = vllm_ops.scaled_fp8_quant(
+ A, A_scale, use_per_token_if_dynamic=per_channel_quant
+ )
  else:
  # activation block-wise fp8 quantization
  assert len(block_shape) == 2
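On the host side, the dynamic fp8 activation quantizer is now asked for per-token scales whenever the weights are per-channel quantized (use_per_token_if_dynamic=per_channel_quant). A rough sketch of what that flag changes, assuming a torch build with float8 support; quantize_fp8_dynamic is illustrative, not the kernel the package calls:

import torch

FP8_MAX = 448.0  # max magnitude representable in float8_e4m3fn

def quantize_fp8_dynamic(A: torch.Tensor, use_per_token: bool):
    if use_per_token:
        # One scale per token (row): shape [num_tokens, 1].
        amax = A.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    else:
        # One scale for the whole tensor.
        amax = A.abs().amax().clamp(min=1e-12)
    scale = amax / FP8_MAX
    q = (A / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q, scale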
@@ -794,6 +793,9 @@ def invoke_fused_moe_kernel(
  assert B_scale is not None
  if block_shape is None:
  # activation channel-wise int8 quantization
+ assert (
+ per_channel_quant
+ ), "int8 quantization only supports channel-wise quantization except for block-wise quantization"
  A, A_scale = per_token_quant_int8(A)
  else:
  # activation block-wise int8 quantization
@@ -902,6 +904,7 @@ def invoke_fused_moe_kernel(
  use_fp8_w8a8=use_fp8_w8a8,
  use_int8_w8a8=use_int8_w8a8,
  use_int8_w8a16=use_int8_w8a16,
+ per_channel_quant=per_channel_quant,
  even_Ks=even_Ks,
  **config,
  )
@@ -953,7 +956,7 @@ def get_moe_configs(
  logger.warning(
  (
  "Using default MoE config. Performance might be sub-optimal! "
- "Config file not found at %s"
+ "Config file not found at %s, you can tune the config with https://github.com/sgl-project/sglang/blob/main/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py."
  ),
  config_file_path,
  )
@@ -1079,10 +1082,12 @@ def inplace_fused_experts(
  topk_weights: torch.Tensor,
  topk_ids: torch.Tensor,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  use_fp8_w8a8: bool = False,
  use_int8_w8a8: bool = False,
  use_int8_w8a16: bool = False,
  use_int4_w4a16: bool = False,
+ per_channel_quant: bool = False,
  w1_scale: Optional[torch.Tensor] = None,
  w2_scale: Optional[torch.Tensor] = None,
  w1_zp: Optional[torch.Tensor] = None,
@@ -1099,10 +1104,12 @@ def inplace_fused_experts(
  topk_ids,
  True,
  activation,
+ apply_router_weight_on_input,
  use_fp8_w8a8,
  use_int8_w8a8,
  use_int8_w8a16,
  use_int4_w4a16,
+ per_channel_quant,
  w1_scale,
  w2_scale,
  w1_zp,
@@ -1120,10 +1127,12 @@ def inplace_fused_experts_fake(
  topk_weights: torch.Tensor,
  topk_ids: torch.Tensor,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  use_fp8_w8a8: bool = False,
  use_int8_w8a8: bool = False,
  use_int8_w8a16: bool = False,
  use_int4_w4a16: bool = False,
+ per_channel_quant: bool = False,
  w1_scale: Optional[torch.Tensor] = None,
  w2_scale: Optional[torch.Tensor] = None,
  w1_zp: Optional[torch.Tensor] = None,
@@ -1150,10 +1159,12 @@ def outplace_fused_experts(
  topk_weights: torch.Tensor,
  topk_ids: torch.Tensor,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  use_fp8_w8a8: bool = False,
  use_int8_w8a8: bool = False,
  use_int8_w8a16: bool = False,
  use_int4_w4a16: bool = False,
+ per_channel_quant: bool = False,
  w1_scale: Optional[torch.Tensor] = None,
  w2_scale: Optional[torch.Tensor] = None,
  w1_zp: Optional[torch.Tensor] = None,
@@ -1171,10 +1182,12 @@ def outplace_fused_experts(
  topk_ids,
  False,
  activation,
+ apply_router_weight_on_input,
  use_fp8_w8a8,
  use_int8_w8a8,
  use_int8_w8a16,
  use_int4_w4a16,
+ per_channel_quant,
  w1_scale,
  w2_scale,
  w1_zp,
@@ -1193,10 +1206,12 @@ def outplace_fused_experts_fake(
  topk_weights: torch.Tensor,
  topk_ids: torch.Tensor,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  use_fp8_w8a8: bool = False,
  use_int8_w8a8: bool = False,
  use_int8_w8a16: bool = False,
  use_int4_w4a16: bool = False,
+ per_channel_quant: bool = False,
  w1_scale: Optional[torch.Tensor] = None,
  w2_scale: Optional[torch.Tensor] = None,
  w1_zp: Optional[torch.Tensor] = None,
@@ -1225,10 +1240,12 @@ def fused_experts(
  topk_ids: torch.Tensor,
  inplace: bool = False,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  use_fp8_w8a8: bool = False,
  use_int8_w8a8: bool = False,
  use_int8_w8a16: bool = False,
  use_int4_w4a16: bool = False,
+ per_channel_quant: bool = False,
  w1_scale: Optional[torch.Tensor] = None,
  w2_scale: Optional[torch.Tensor] = None,
  w1_zp: Optional[torch.Tensor] = None,
@@ -1247,10 +1264,12 @@ def fused_experts(
  topk_weights,
  topk_ids,
  activation,
+ apply_router_weight_on_input,
  use_fp8_w8a8,
  use_int8_w8a8,
  use_int8_w8a16,
  use_int4_w4a16,
+ per_channel_quant,
  w1_scale,
  w2_scale,
  w1_zp,
@@ -1268,10 +1287,12 @@ def fused_experts(
  topk_weights,
  topk_ids,
  activation,
+ apply_router_weight_on_input,
  use_fp8_w8a8,
  use_int8_w8a8,
  use_int8_w8a16,
  use_int4_w4a16,
+ per_channel_quant,
  w1_scale,
  w2_scale,
  w1_zp,
@@ -1291,10 +1312,12 @@ def fused_experts_impl(
  topk_ids: torch.Tensor,
  inplace: bool = False,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  use_fp8_w8a8: bool = False,
  use_int8_w8a8: bool = False,
  use_int8_w8a16: bool = False,
  use_int4_w4a16: bool = False,
+ per_channel_quant: bool = False,
  w1_scale: Optional[torch.Tensor] = None,
  w2_scale: Optional[torch.Tensor] = None,
  w1_zp: Optional[torch.Tensor] = None,
@@ -1423,7 +1446,7 @@ def fused_experts_impl(
  sorted_token_ids,
  expert_ids,
  num_tokens_post_padded,
- False,
+ apply_router_weight_on_input,
  topk_ids.shape[1],
  config,
  compute_type=compute_type,
@@ -1431,6 +1454,7 @@ def fused_experts_impl(
  use_int8_w8a8=use_int8_w8a8,
  use_int8_w8a16=use_int8_w8a16,
  use_int4_w4a16=use_int4_w4a16,
+ per_channel_quant=per_channel_quant,
  block_shape=block_shape,
  )
  if activation == "silu":
@@ -1456,7 +1480,7 @@ def fused_experts_impl(
  (
  intermediate_cache3
  if not no_combine and topk_ids.shape[1] != 1
- else out_hidden_states[begin_chunk_idx:end_chunk_idx]
+ else out_hidden_states[begin_chunk_idx:end_chunk_idx].unsqueeze(0)
  ),
  a2_scale,
  w2_scale,
@@ -1466,7 +1490,7 @@ def fused_experts_impl(
  sorted_token_ids,
  expert_ids,
  num_tokens_post_padded,
- True,
+ not apply_router_weight_on_input,
  1,
  config,
  compute_type=compute_type,
@@ -1474,6 +1498,7 @@ def fused_experts_impl(
  use_int8_w8a8=use_int8_w8a8,
  use_int8_w8a16=use_int8_w8a16,
  use_int4_w4a16=use_int4_w4a16,
+ per_channel_quant=per_channel_quant,
  block_shape=block_shape,
  )

@@ -1520,6 +1545,7 @@ def fused_moe(
  use_int8_w8a8: bool = False,
  use_int8_w8a16: bool = False,
  use_int4_w4a16: bool = False,
+ per_channel_quant: bool = False,
  w1_scale: Optional[torch.Tensor] = None,
  w2_scale: Optional[torch.Tensor] = None,
  w1_zp: Optional[torch.Tensor] = None,
@@ -1596,6 +1622,7 @@ def fused_moe(
  use_int8_w8a8=use_int8_w8a8,
  use_int8_w8a16=use_int8_w8a16,
  use_int4_w4a16=use_int4_w4a16,
+ per_channel_quant=per_channel_quant,
  w1_scale=w1_scale,
  w2_scale=w2_scale,
  w1_zp=w1_zp,
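The new apply_router_weight_on_input flag is threaded to both grouped GEMMs: the first matmul receives the flag and the second receives its negation, so the router weight is multiplied in exactly once, either on the input or on the output. A dense reference sketch of that invariant (moe_reference and expert_ffn are illustrative stand-ins, not the Triton path):

import torch

def moe_reference(x, topk_weights, topk_ids, expert_ffn,
                  apply_router_weight_on_input: bool = False):
    # x: [tokens, hidden]; topk_weights / topk_ids: [tokens, top_k]
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for k in range(topk_ids.shape[1]):
            w = topk_weights[t, k]
            inp = x[t] * w if apply_router_weight_on_input else x[t]
            y = expert_ffn(int(topk_ids[t, k]), inp)
            # The weight is applied exactly once: above on the input, or here on the output.
            out[t] += y if apply_router_weight_on_input else w * y
    return out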
@@ -128,6 +128,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  custom_routing_function: Optional[Callable] = None,
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
  ) -> torch.Tensor:
@@ -143,6 +144,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  custom_routing_function=custom_routing_function,
  correction_bias=correction_bias,
  activation=activation,
+ apply_router_weight_on_input=apply_router_weight_on_input,
  inplace=inplace,
  no_combine=no_combine,
  )
@@ -160,6 +162,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  custom_routing_function: Optional[Callable] = None,
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
  ) -> torch.Tensor:
@@ -200,6 +203,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  topk_ids=topk_ids,
  inplace=inplace and not no_combine,
  activation=activation,
+ apply_router_weight_on_input=apply_router_weight_on_input,
  no_combine=no_combine,
  )

@@ -276,6 +280,7 @@ class FusedMoE(torch.nn.Module):
  custom_routing_function: Optional[Callable] = None,
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  use_presharded_weights: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
@@ -302,6 +307,7 @@ class FusedMoE(torch.nn.Module):
  self.custom_routing_function = custom_routing_function
  self.correction_bias = correction_bias
  self.activation = activation
+ self.apply_router_weight_on_input = apply_router_weight_on_input
  self.use_presharded_weights = use_presharded_weights
  self.inplace = inplace
  self.no_combine = no_combine
@@ -630,6 +636,7 @@ class FusedMoE(torch.nn.Module):
  custom_routing_function=self.custom_routing_function,
  correction_bias=self.correction_bias,
  activation=self.activation,
+ apply_router_weight_on_input=self.apply_router_weight_on_input,
  )

  if self.reduce_results and self.tp_size > 1:
@@ -5,6 +5,9 @@ import triton
  import triton.language as tl

  from sglang.srt.layers.moe.topk import fused_topk
+ from sglang.srt.utils import is_hip
+
+ _is_hip = is_hip()


  @triton.jit
@@ -116,10 +119,13 @@ def fused_moe_router_impl(
  topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)

  grid = lambda meta: (bs,)
+
+ min_num_warps = 16 if _is_hip else 32
+
  config = {
  "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
  "num_warps": max(
- min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+ min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
  ),
  }

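The warp cap for the router kernel is now 16 on HIP and 32 on CUDA. Worked out by hand for two hidden sizes (the helper below is illustrative and simply mirrors the formula in the diff):

import math

def router_num_warps(hidden_dim: int, is_hip: bool) -> int:
    min_num_warps = 16 if is_hip else 32
    cdiv = -(-hidden_dim // 256)                     # ceil(hidden_dim / 256), like triton.cdiv
    npow2 = 1 << max(0, math.ceil(math.log2(cdiv)))  # smallest power of two >= cdiv, like triton.next_power_of_2
    return max(min(npow2, min_num_warps), 4)

# hidden_dim=4096  -> 16 warps on both CUDA and HIP
# hidden_dim=16384 -> 32 warps on CUDA, capped to 16 on HIP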
@@ -12,6 +12,7 @@
  # limitations under the License.
  # ==============================================================================

+ import math
  import os
  from typing import Callable, Optional

@@ -25,6 +26,8 @@ from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
  _is_cuda = is_cuda()
  _is_hip = is_hip()

+ if _is_cuda:
+ from sgl_kernel import moe_fused_gate

  expert_distribution_recorder = ExpertDistributionRecorder()

@@ -209,6 +212,10 @@ def biased_grouped_topk_impl(
  return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


+ def is_power_of_two(n):
+ return n > 0 and math.log2(n).is_integer()
+
+
  def biased_grouped_topk(
  hidden_states: torch.Tensor,
  gating_output: torch.Tensor,
@@ -220,23 +227,37 @@ def biased_grouped_topk(
  compiled: bool = True,
  n_share_experts_fusion: int = 0,
  ):
- biased_grouped_topk_fn = (
- torch.compile(
- biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
+ # TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now.
+ if (
+ _is_cuda
+ and n_share_experts_fusion == 0
+ and is_power_of_two(correction_bias.shape[0])
+ ):
+ return moe_fused_gate(
+ gating_output,
+ correction_bias,
+ num_expert_group,
+ topk_group,
+ topk,
+ )
+ else:
+ biased_grouped_topk_fn = (
+ torch.compile(
+ biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
+ )
+ if compiled
+ else biased_grouped_topk_impl
+ )
+ return biased_grouped_topk_fn(
+ hidden_states,
+ gating_output,
+ correction_bias,
+ topk,
+ renormalize,
+ num_expert_group,
+ topk_group,
+ n_share_experts_fusion=n_share_experts_fusion,
  )
- if compiled
- else biased_grouped_topk_impl
- )
- return biased_grouped_topk_fn(
- hidden_states,
- gating_output,
- correction_bias,
- topk,
- renormalize,
- num_expert_group,
- topk_group,
- n_share_experts_fusion=n_share_experts_fusion,
- )


  def select_experts(
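biased_grouped_topk now short-circuits to the sgl_kernel moe_fused_gate kernel only when its preconditions hold, and otherwise falls back to the (optionally torch.compile'd) Python implementation. The predicate, written out on its own (can_use_fused_gate is an illustrative name, not part of the package):

def can_use_fused_gate(is_cuda: bool, n_share_experts_fusion: int, num_experts: int) -> bool:
    # CUDA build, no shared-expert fusion, and a power-of-two expert count
    # (num_experts corresponds to correction_bias.shape[0] in the diff).
    is_pow2 = num_experts > 0 and (num_experts & (num_experts - 1)) == 0
    return is_cuda and n_share_experts_fusion == 0 and is_pow2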
@@ -59,20 +59,20 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
  )
  from sglang.srt.layers.quantization.fp8 import Fp8Config
  from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
- from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
+ from sglang.srt.layers.quantization.modelopt_quant import (
+ ModelOptFp4Config,
+ ModelOptFp8Config,
+ )
  from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
  from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
  from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
- from sglang.srt.layers.vocab_parallel_embedding import (
- ParallelLMHead,
- UnquantizedEmbeddingMethod,
- )

  # Base quantization methods that don't depend on vllm
  BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
  "fp8": Fp8Config,
  "blockwise_int8": BlockInt8Config,
  "modelopt": ModelOptFp8Config,
+ "modelopt_fp4": ModelOptFp4Config,
  "w8a8_int8": W8A8Int8Config,
  "w8a8_fp8": W8A8Fp8Config,
  "moe_wna16": MoeWNA16Config,
@@ -176,6 +176,13 @@ def get_linear_quant_method(
  prefix: str,
  linear_method_cls: type,
  ):
+ # Move import here to avoid circular import. This is only used in monkey patching
+ # of vllm's QuantizationConfig.
+ from sglang.srt.layers.vocab_parallel_embedding import (
+ ParallelLMHead,
+ UnquantizedEmbeddingMethod,
+ )
+
  cloned_config = deepcopy(config)
  parallel_lm_head_quantized = (
  isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
@@ -280,6 +287,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
  custom_routing_function: Optional[Callable] = None,
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
  ):
@@ -370,6 +370,7 @@ class BlockInt8MoEMethod:
  custom_routing_function: Optional[Callable] = None,
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
  inplace: bool = True,
  no_combine: bool = False,
  ) -> torch.Tensor:
@@ -398,6 +399,7 @@ class BlockInt8MoEMethod:
  topk_ids=topk_ids,
  inplace=inplace,
  activation=activation,
+ apply_router_weight_on_input=apply_router_weight_on_input,
  use_int8_w8a8=True,
  w1_scale=(layer.w13_weight_scale_inv),
  w2_scale=(layer.w2_weight_scale_inv),