sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +23 -2
  3. sglang/bench_serving.py +6 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/configs/model_config.py +37 -5
  12. sglang/srt/constrained/base_grammar_backend.py +26 -5
  13. sglang/srt/constrained/llguidance_backend.py +1 -0
  14. sglang/srt/constrained/outlines_backend.py +1 -0
  15. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  16. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  17. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  18. sglang/srt/constrained/xgrammar_backend.py +27 -4
  19. sglang/srt/custom_op.py +0 -62
  20. sglang/srt/disaggregation/base/__init__.py +8 -0
  21. sglang/srt/disaggregation/base/conn.py +113 -0
  22. sglang/srt/disaggregation/decode.py +80 -11
  23. sglang/srt/disaggregation/mini_lb.py +58 -123
  24. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  25. sglang/srt/disaggregation/mooncake/conn.py +585 -0
  26. sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
  27. sglang/srt/disaggregation/prefill.py +82 -22
  28. sglang/srt/disaggregation/utils.py +46 -0
  29. sglang/srt/entrypoints/EngineBase.py +53 -0
  30. sglang/srt/entrypoints/engine.py +36 -8
  31. sglang/srt/entrypoints/http_server.py +37 -8
  32. sglang/srt/entrypoints/http_server_engine.py +142 -0
  33. sglang/srt/entrypoints/verl_engine.py +42 -13
  34. sglang/srt/hf_transformers_utils.py +4 -0
  35. sglang/srt/layers/activation.py +6 -8
  36. sglang/srt/layers/attention/flashattention_backend.py +430 -257
  37. sglang/srt/layers/attention/flashinfer_backend.py +18 -9
  38. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  39. sglang/srt/layers/attention/triton_backend.py +6 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  41. sglang/srt/layers/attention/vision.py +1 -1
  42. sglang/srt/layers/dp_attention.py +2 -4
  43. sglang/srt/layers/elementwise.py +15 -2
  44. sglang/srt/layers/layernorm.py +1 -1
  45. sglang/srt/layers/linear.py +18 -3
  46. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  48. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
  56. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  57. sglang/srt/layers/moe/router.py +7 -1
  58. sglang/srt/layers/moe/topk.py +63 -45
  59. sglang/srt/layers/parameter.py +0 -2
  60. sglang/srt/layers/quantization/__init__.py +13 -5
  61. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  62. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
  64. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  65. sglang/srt/layers/quantization/fp8.py +131 -136
  66. sglang/srt/layers/quantization/fp8_kernel.py +328 -46
  67. sglang/srt/layers/quantization/fp8_utils.py +206 -253
  68. sglang/srt/layers/quantization/kv_cache.py +43 -52
  69. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  70. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  71. sglang/srt/layers/quantization/utils.py +5 -11
  72. sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
  73. sglang/srt/layers/quantization/w8a8_int8.py +8 -7
  74. sglang/srt/layers/radix_attention.py +28 -1
  75. sglang/srt/layers/rotary_embedding.py +15 -3
  76. sglang/srt/layers/sampler.py +5 -10
  77. sglang/srt/lora/backend/base_backend.py +18 -2
  78. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  79. sglang/srt/lora/backend/triton_backend.py +1 -1
  80. sglang/srt/lora/layers.py +1 -1
  81. sglang/srt/lora/lora.py +1 -1
  82. sglang/srt/lora/lora_manager.py +1 -1
  83. sglang/srt/managers/detokenizer_manager.py +0 -1
  84. sglang/srt/managers/io_struct.py +255 -97
  85. sglang/srt/managers/mm_utils.py +7 -5
  86. sglang/srt/managers/multimodal_processor.py +0 -2
  87. sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
  88. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  89. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  90. sglang/srt/managers/schedule_batch.py +64 -25
  91. sglang/srt/managers/scheduler.py +80 -82
  92. sglang/srt/managers/tokenizer_manager.py +18 -3
  93. sglang/srt/managers/tp_worker.py +1 -0
  94. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  95. sglang/srt/mem_cache/memory_pool.py +21 -3
  96. sglang/srt/metrics/collector.py +9 -0
  97. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  98. sglang/srt/model_executor/forward_batch_info.py +234 -15
  99. sglang/srt/model_executor/model_runner.py +67 -35
  100. sglang/srt/model_loader/loader.py +31 -4
  101. sglang/srt/model_loader/weight_utils.py +4 -2
  102. sglang/srt/models/baichuan.py +2 -0
  103. sglang/srt/models/bert.py +398 -0
  104. sglang/srt/models/chatglm.py +1 -0
  105. sglang/srt/models/commandr.py +1 -0
  106. sglang/srt/models/dbrx.py +1 -0
  107. sglang/srt/models/deepseek.py +2 -1
  108. sglang/srt/models/deepseek_nextn.py +74 -70
  109. sglang/srt/models/deepseek_v2.py +494 -366
  110. sglang/srt/models/exaone.py +1 -0
  111. sglang/srt/models/gemma.py +1 -0
  112. sglang/srt/models/gemma2.py +1 -0
  113. sglang/srt/models/gemma3_causal.py +1 -0
  114. sglang/srt/models/gpt2.py +1 -0
  115. sglang/srt/models/gpt_bigcode.py +1 -0
  116. sglang/srt/models/granite.py +1 -0
  117. sglang/srt/models/grok.py +1 -0
  118. sglang/srt/models/internlm2.py +1 -0
  119. sglang/srt/models/llama.py +6 -5
  120. sglang/srt/models/llama4.py +101 -34
  121. sglang/srt/models/minicpm.py +1 -0
  122. sglang/srt/models/minicpm3.py +30 -200
  123. sglang/srt/models/mixtral.py +1 -0
  124. sglang/srt/models/mixtral_quant.py +1 -0
  125. sglang/srt/models/mllama.py +51 -8
  126. sglang/srt/models/mllama4.py +102 -29
  127. sglang/srt/models/olmo.py +1 -0
  128. sglang/srt/models/olmo2.py +1 -0
  129. sglang/srt/models/olmoe.py +1 -0
  130. sglang/srt/models/phi3_small.py +1 -0
  131. sglang/srt/models/qwen.py +1 -0
  132. sglang/srt/models/qwen2.py +5 -1
  133. sglang/srt/models/qwen2_5_vl.py +35 -70
  134. sglang/srt/models/qwen2_moe.py +15 -13
  135. sglang/srt/models/qwen2_vl.py +27 -25
  136. sglang/srt/models/qwen3.py +335 -0
  137. sglang/srt/models/qwen3_moe.py +423 -0
  138. sglang/srt/models/stablelm.py +1 -0
  139. sglang/srt/models/xverse.py +1 -0
  140. sglang/srt/models/xverse_moe.py +1 -0
  141. sglang/srt/openai_api/adapter.py +4 -1
  142. sglang/srt/patch_torch.py +11 -0
  143. sglang/srt/reasoning_parser.py +0 -1
  144. sglang/srt/sampling/sampling_batch_info.py +2 -3
  145. sglang/srt/server_args.py +55 -19
  146. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  147. sglang/srt/speculative/eagle_utils.py +1 -11
  148. sglang/srt/speculative/eagle_worker.py +10 -9
  149. sglang/srt/utils.py +136 -10
  150. sglang/test/attention/test_flashattn_backend.py +259 -221
  151. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  152. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  153. sglang/test/runners.py +5 -1
  154. sglang/test/test_block_fp8.py +224 -0
  155. sglang/test/test_custom_ops.py +1 -1
  156. sglang/test/test_utils.py +19 -8
  157. sglang/version.py +1 -1
  158. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
  159. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
  160. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  161. sglang/lang/__init__.py +0 -0
  162. sglang/srt/disaggregation/conn.py +0 -81
  163. sglang/srt/lora/backend/__init__.py +0 -25
  164. sglang/srt/server.py +0 -18
  165. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  166. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    }
+}
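
For orientation: each of the new JSON files listed above is a Triton tuning table for the fused MoE kernel. The top-level keys are token batch sizes (M); the values are the launch parameters (tile sizes, group size, warp count, pipeline stages) that benchmarked fastest at that size. A minimal sketch of how such a table is consumed, assuming a nearest-key fallback like the lookup in get_moe_configs (the pick_moe_config helper below is hypothetical):

import json

def pick_moe_config(config_path: str, m: int) -> dict:
    # Load a tuning table like the one above; keys are token counts M.
    with open(config_path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    # Exact hit, otherwise fall back to the closest tuned batch size.
    return configs.get(m) or configs[min(configs, key=lambda k: abs(k - m))]
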
@@ -13,6 +13,7 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
@@ -22,28 +23,25 @@ from sglang.srt.utils import (
 )
 
 _is_hip = is_hip()
-
-
-logger = logging.getLogger(__name__)
-padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
-
-enable_moe_align_block_size_triton = bool(
-    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
-)
-
 _is_cuda = is_cuda()
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
-
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
 else:
     from vllm import _custom_ops as vllm_ops
+    from vllm._custom_ops import scaled_fp8_quant
 
 if _is_cuda or _is_hip:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 
 
+logger = logging.getLogger(__name__)
+padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
+enable_moe_align_block_size_triton = bool(
+    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
+)
+
+
 @triton.jit
 def write_zeros_to_output(
     c_ptr,
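
The relocation above is behavior-preserving: the module-level switches simply move below the conditional imports. Both remain environment-variable toggles read once at import time, e.g. (illustrative):

import os

# Must be set before the fused_moe module is first imported.
os.environ["MOE_PADDING"] = "1"                          # sets padding_size = 128 (see the constant above)
os.environ["ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON"] = "1"   # prefer the Triton alignment kernel
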
@@ -342,6 +340,7 @@ def fused_moe_kernel(
     use_fp8_w8a8: tl.constexpr,
     use_int8_w8a8: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
+    per_channel_quant: tl.constexpr,
     even_Ks: tl.constexpr,
 ):
     """
@@ -416,20 +415,7 @@
         )
         b_scale = tl.load(b_scale_ptrs)
 
-    if use_fp8_w8a8:
-        # block-wise
-        if group_k > 0 and group_n > 0:
-            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
-            offs_bsn = offs_bn // group_n
-            b_scale_ptrs = (
-                b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
-            )
-        # tensor-wise
-        else:
-            a_scale = tl.load(a_scale_ptr)
-            b_scale = tl.load(b_scale_ptr + off_experts)
-
-    if use_int8_w8a8:
+    if use_fp8_w8a8 or use_int8_w8a8:
         # block-wise
         if group_k > 0 and group_n > 0:
             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
@@ -438,8 +424,7 @@
                 b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
             )
         # channel-wise
-        else:
-            # Load per-column scale for weights
+        elif per_channel_quant:
             b_scale_ptrs = (
                 b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
             )
@@ -447,6 +432,10 @@
             # Load per-token scale for activations
             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
             a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
+        # tensor-wise
+        else:
+            a_scale = tl.load(a_scale_ptr)
+            b_scale = tl.load(b_scale_ptr + off_experts)
 
     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
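
Net effect of these three hunks: the previously duplicated fp8/int8 scale-loading logic becomes one branch, ordered block-wise, then channel-wise (gated by the new per_channel_quant flag), then tensor-wise as the fallback. A rough Python mirror of the selection logic, for reading convenience only (the real code is the Triton kernel above):

def scale_granularity(group_k: int, group_n: int, per_channel_quant: bool) -> str:
    if group_k > 0 and group_n > 0:
        return "block-wise"    # scales per [group_k x group_n] weight block, per token group for activations
    elif per_channel_quant:
        return "channel-wise"  # one scale per weight output channel, one per activation token
    else:
        return "tensor-wise"   # one scale per tensor (per expert for weights)
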
@@ -711,12 +700,12 @@ def moe_align_block_size(
             num_tokens_post_pad,
         )
     else:
-        token_cnts_buffer = torch.zeros(
+        token_cnts_buffer = torch.empty(
             (num_experts + 1) * num_experts,
             dtype=torch.int32,
             device=topk_ids.device,
         )
-        cumsum_buffer = torch.zeros(
+        cumsum_buffer = torch.empty(
             num_experts + 1, dtype=torch.int32, device=topk_ids.device
         )
 
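The zeros-to-empty switch drops two small device memsets per call, which is only safe if the alignment kernel initializes the scratch buffers itself before reading them; presumably the sgl_kernel implementation now does. For readers unfamiliar with the op, a hedged pure-PyTorch sketch of what moe_align_block_size computes (assumed semantics, not the packaged kernel):

import torch

def moe_align_block_size_ref(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    # Count how many token slots each expert receives...
    counts = torch.bincount(topk_ids.flatten(), minlength=num_experts)
    # ...and pad each count up to a multiple of block_size so every expert
    # owns a whole number of kernel tiles.
    padded = (counts + block_size - 1) // block_size * block_size
    return padded, int(padded.sum())  # per-expert padded counts, total padded slots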
 
@@ -753,6 +742,7 @@ def invoke_fused_moe_kernel(
     use_int8_w8a8: bool,
     use_int8_w8a16: bool,
     use_int4_w4a16: bool,
+    per_channel_quant: bool,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
 ) -> None:
@@ -765,6 +755,8 @@
         from sglang.srt.layers.quantization.fp8_kernel import (
             sglang_per_token_group_quant_fp8,
         )
+    else:
+        from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -775,10 +767,10 @@
         if block_shape is None:
             # activation tensor-wise fp8 quantization, dynamic or static
             padded_size = padding_size
-            if _is_cuda:
-                A, A_scale = sgl_scaled_fp8_quant(A, A_scale)
-            else:
-                A, A_scale = vllm_ops.scaled_fp8_quant(A, A_scale)
+            # activations apply per-token quantization when weights apply per-channel quantization by default
+            A, A_scale = scaled_fp8_quant(
+                A, A_scale, use_per_token_if_dynamic=per_channel_quant
+            )
         else:
             # activation block-wise fp8 quantization
             assert len(block_shape) == 2
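
With both backends unified behind a single scaled_fp8_quant import, the new use_per_token_if_dynamic=per_channel_quant argument switches dynamic activation quantization from one scale per tensor to one scale per token, matching per-channel weight scales. A rough sketch of the difference (illustrative, not the packaged kernel):

import torch

def dynamic_fp8_scale(a: torch.Tensor, per_token: bool) -> torch.Tensor:
    fmax = torch.finfo(torch.float8_e4m3fn).max
    if per_token:
        # One scale per row/token: shape [T, 1].
        return a.abs().amax(dim=-1, keepdim=True).float() / fmax
    # One scale for the whole tensor: a scalar.
    return a.abs().amax().float() / fmax
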
@@ -794,6 +786,9 @@
         assert B_scale is not None
         if block_shape is None:
             # activation channel-wise int8 quantization
+            assert (
+                per_channel_quant
+            ), "int8 quantization only supports channel-wise quantization except for block-wise quantization"
             A, A_scale = per_token_quant_int8(A)
         else:
             # activation block-wise int8 quantization
902
897
  use_fp8_w8a8=use_fp8_w8a8,
903
898
  use_int8_w8a8=use_int8_w8a8,
904
899
  use_int8_w8a16=use_int8_w8a16,
900
+ per_channel_quant=per_channel_quant,
905
901
  even_Ks=even_Ks,
906
902
  **config,
907
903
  )
@@ -953,7 +949,7 @@
     logger.warning(
         (
             "Using default MoE config. Performance might be sub-optimal! "
-            "Config file not found at %s"
+            "Config file not found at %s, you can tune the config with https://github.com/sgl-project/sglang/blob/main/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py."
         ),
         config_file_path,
     )
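
The expanded warning now links the tuning script; tuned results are written under the naming scheme visible in the file list above. A hedged sketch of that convention (the helper name is hypothetical; a block_shape selector, when present, is appended the same way):

def moe_config_file_name(E: int, N: int, device_name: str, dtype: str | None) -> str:
    # e.g. "E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json"
    dtype_selector = f",dtype={dtype}" if dtype else ""
    return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
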
@@ -1084,6 +1080,7 @@
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1105,6 +1102,7 @@
         use_int8_w8a8,
         use_int8_w8a16,
         use_int4_w4a16,
+        per_channel_quant,
         w1_scale,
         w2_scale,
         w1_zp,
@@ -1127,6 +1125,7 @@
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1158,6 +1157,7 @@
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1180,6 +1180,7 @@
         use_int8_w8a8,
         use_int8_w8a16,
         use_int4_w4a16,
+        per_channel_quant,
         w1_scale,
         w2_scale,
         w1_zp,
@@ -1203,6 +1204,7 @@
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1236,6 +1238,7 @@
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1259,6 +1262,7 @@
            use_int8_w8a8,
            use_int8_w8a16,
            use_int4_w4a16,
+           per_channel_quant,
            w1_scale,
            w2_scale,
            w1_zp,
@@ -1281,6 +1285,7 @@
            use_int8_w8a8,
            use_int8_w8a16,
            use_int4_w4a16,
+           per_channel_quant,
            w1_scale,
            w2_scale,
            w1_zp,
@@ -1305,6 +1310,7 @@
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1441,6 +1447,7 @@
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
+           per_channel_quant=per_channel_quant,
            block_shape=block_shape,
        )
        if activation == "silu":
@@ -1484,6 +1491,7 @@
            use_int8_w8a8=use_int8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            use_int4_w4a16=use_int4_w4a16,
+           per_channel_quant=per_channel_quant,
            block_shape=block_shape,
        )
 
@@ -1530,6 +1538,7 @@ def fused_moe(
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     use_int4_w4a16: bool = False,
+    per_channel_quant: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     w1_zp: Optional[torch.Tensor] = None,
@@ -1538,6 +1547,7 @@
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1592,6 +1602,7 @@
         topk_group=topk_group,
         num_expert_group=num_expert_group,
         custom_routing_function=custom_routing_function,
+        routed_scaling_factor=routed_scaling_factor,
     )
 
     return fused_experts(
@@ -1606,6 +1617,7 @@
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         use_int4_w4a16=use_int4_w4a16,
+        per_channel_quant=per_channel_quant,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
         w1_zp=w1_zp,
@@ -131,6 +131,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         return self.forward(
             x=x,
@@ -147,6 +148,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             apply_router_weight_on_input=apply_router_weight_on_input,
             inplace=inplace,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
     def forward_cuda(
@@ -165,6 +167,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -176,6 +179,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         if _is_hip and get_bool_env_var("CK_MOE"):
@@ -284,6 +288,7 @@ class FusedMoE(torch.nn.Module):
         use_presharded_weights: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ):
         super().__init__()
 
@@ -293,6 +298,7 @@ class FusedMoE(torch.nn.Module):
         self.tp_size = (
             tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
         )
+        self.routed_scaling_factor = routed_scaling_factor
         self.top_k = top_k
         self.num_experts = num_experts
         assert intermediate_size % self.tp_size == 0
@@ -637,6 +643,7 @@
             correction_bias=self.correction_bias,
             activation=self.activation,
             apply_router_weight_on_input=self.apply_router_weight_on_input,
+            routed_scaling_factor=self.routed_scaling_factor,
         )
 
         if self.reduce_results and self.tp_size > 1:
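
The long run of hunks above is mechanical plumbing, apparently spanning fused_moe.py and fused_moe_triton/layer.py per the file list: per_channel_quant and routed_scaling_factor are threaded through every signature from fused_moe() down to the kernel invocation, and FusedMoE stores the factor and forwards it to select_experts. routed_scaling_factor is a scalar that DeepSeek-style MoE models apply to the routed experts' contribution; a hedged sketch of the intended effect (illustrative; the actual application lives in select_experts in topk.py):

import torch

def apply_routed_scaling(topk_weights: torch.Tensor,
                         routed_scaling_factor: float | None) -> torch.Tensor:
    # Scale the router's combine weights before expert outputs are summed.
    if routed_scaling_factor is not None:
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights
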
@@ -5,6 +5,9 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.moe.topk import fused_topk
+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()
 
 
 @triton.jit
@@ -116,10 +119,13 @@ def fused_moe_router_impl(
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
 
     grid = lambda meta: (bs,)
+
+    min_num_warps = 16 if _is_hip else 32
+
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
         ),
     }
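
The final hunks, apparently from sglang/srt/layers/moe/router.py, cap the router kernel's num_warps at 16 on HIP while keeping 32 on CUDA, plausibly because AMD wavefronts are 64 lanes wide, so 16 warps already fill a 1024-thread workgroup. A worked example of the heuristic (illustrative sketch):

def router_num_warps(hidden_dim: int, is_hip: bool) -> int:
    # Mirrors: max(min(next_power_of_2(cdiv(hidden_dim, 256)), cap), 4)
    cap = 16 if is_hip else 32
    cdiv = -(-hidden_dim // 256)                    # ceiling division
    npo2 = 1 << max(0, (cdiv - 1).bit_length())     # next power of two >= cdiv
    return max(min(npo2, cap), 4)

# hidden_dim=4096 -> 16 warps everywhere; hidden_dim=8192 -> 32 on CUDA, 16 on HIP.
print(router_num_warps(8192, False), router_num_warps(8192, True))  # 32 16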