sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +220 -378
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +143 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +681 -259
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +224 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +44 -18
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +94 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +208 -28
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +136 -52
  181. sglang/srt/speculative/build_eagle_tree.py +2 -8
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  183. sglang/srt/speculative/eagle_utils.py +92 -58
  184. sglang/srt/speculative/eagle_worker.py +186 -94
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
@@ -5,7 +5,7 @@
   "BLOCK_SIZE_K": 256,
   "GROUP_SIZE_M": 1,
   "num_warps": 4,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 4,
   "matrix_instr_nonkdim": 16,
   "kpack": 2
@@ -16,7 +16,7 @@
   "BLOCK_SIZE_K": 256,
   "GROUP_SIZE_M": 1,
   "num_warps": 4,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 1,
   "matrix_instr_nonkdim": 16,
   "kpack": 1
@@ -27,7 +27,7 @@
   "BLOCK_SIZE_K": 256,
   "GROUP_SIZE_M": 1,
   "num_warps": 8,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 2,
   "matrix_instr_nonkdim": 16,
   "kpack": 2
@@ -38,7 +38,7 @@
   "BLOCK_SIZE_K": 256,
   "GROUP_SIZE_M": 1,
   "num_warps": 8,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 1,
   "matrix_instr_nonkdim": 16,
   "kpack": 2
@@ -49,7 +49,7 @@
   "BLOCK_SIZE_K": 256,
   "GROUP_SIZE_M": 1,
   "num_warps": 4,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 2,
   "matrix_instr_nonkdim": 16,
   "kpack": 2
@@ -60,7 +60,7 @@
   "BLOCK_SIZE_K": 256,
   "GROUP_SIZE_M": 1,
   "num_warps": 4,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 1,
   "matrix_instr_nonkdim": 16,
   "kpack": 1
@@ -79,7 +79,7 @@
   "BLOCK_SIZE_K": 256,
   "GROUP_SIZE_M": 1,
   "num_warps": 4,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 2,
   "matrix_instr_nonkdim": 16,
   "kpack": 2
@@ -90,7 +90,7 @@
   "BLOCK_SIZE_K": 128,
   "GROUP_SIZE_M": 1,
   "num_warps": 8,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 4,
   "matrix_instr_nonkdim": 16,
   "kpack": 2
@@ -101,7 +101,7 @@
   "BLOCK_SIZE_K": 128,
   "GROUP_SIZE_M": 1,
   "num_warps": 8,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 2,
   "matrix_instr_nonkdim": 16,
   "kpack": 2
@@ -112,7 +112,7 @@
   "BLOCK_SIZE_K": 128,
   "GROUP_SIZE_M": 1,
   "num_warps": 8,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 2,
   "matrix_instr_nonkdim": 16,
   "kpack": 2
@@ -123,7 +123,7 @@
   "BLOCK_SIZE_K": 64,
   "GROUP_SIZE_M": 1,
   "num_warps": 8,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 2,
   "matrix_instr_nonkdim": 16,
   "kpack": 1
@@ -134,7 +134,7 @@
   "BLOCK_SIZE_K": 64,
   "GROUP_SIZE_M": 1,
   "num_warps": 8,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 1,
   "matrix_instr_nonkdim": 16,
   "kpack": 1
@@ -145,7 +145,7 @@
   "BLOCK_SIZE_K": 64,
   "GROUP_SIZE_M": 1,
   "num_warps": 8,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 0,
   "matrix_instr_nonkdim": 16,
   "kpack": 1
@@ -156,7 +156,7 @@
   "BLOCK_SIZE_K": 64,
   "GROUP_SIZE_M": 1,
   "num_warps": 8,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 1,
   "matrix_instr_nonkdim": 16,
   "kpack": 1
@@ -167,7 +167,7 @@
   "BLOCK_SIZE_K": 64,
   "GROUP_SIZE_M": 1,
   "num_warps": 4,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 2,
   "matrix_instr_nonkdim": 16,
   "kpack": 2
@@ -5,7 +5,7 @@
   "BLOCK_SIZE_K": 256,
   "GROUP_SIZE_M": 1,
   "num_warps": 4,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 4,
   "matrix_instr_nonkdim": 16,
   "kpack": 2
@@ -16,7 +16,7 @@
   "BLOCK_SIZE_K": 256,
   "GROUP_SIZE_M": 1,
   "num_warps": 4,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 1,
   "matrix_instr_nonkdim": 16,
   "kpack": 1
@@ -27,7 +27,7 @@
   "BLOCK_SIZE_K": 256,
   "GROUP_SIZE_M": 1,
   "num_warps": 8,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 2,
   "matrix_instr_nonkdim": 16,
   "kpack": 2
@@ -38,7 +38,7 @@
   "BLOCK_SIZE_K": 256,
   "GROUP_SIZE_M": 1,
   "num_warps": 8,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 1,
   "matrix_instr_nonkdim": 16,
   "kpack": 2
@@ -49,7 +49,7 @@
   "BLOCK_SIZE_K": 256,
   "GROUP_SIZE_M": 1,
   "num_warps": 4,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 2,
   "matrix_instr_nonkdim": 16,
   "kpack": 2
@@ -60,7 +60,7 @@
   "BLOCK_SIZE_K": 256,
   "GROUP_SIZE_M": 1,
   "num_warps": 4,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 1,
   "matrix_instr_nonkdim": 16,
   "kpack": 1
@@ -79,7 +79,7 @@
   "BLOCK_SIZE_K": 256,
   "GROUP_SIZE_M": 1,
   "num_warps": 4,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 2,
   "matrix_instr_nonkdim": 16,
   "kpack": 2
@@ -90,7 +90,7 @@
   "BLOCK_SIZE_K": 128,
   "GROUP_SIZE_M": 1,
   "num_warps": 8,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 4,
   "matrix_instr_nonkdim": 16,
   "kpack": 2
@@ -101,7 +101,7 @@
   "BLOCK_SIZE_K": 128,
   "GROUP_SIZE_M": 1,
   "num_warps": 8,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 2,
   "matrix_instr_nonkdim": 16,
   "kpack": 2
@@ -112,7 +112,7 @@
   "BLOCK_SIZE_K": 128,
   "GROUP_SIZE_M": 1,
   "num_warps": 8,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 2,
   "matrix_instr_nonkdim": 16,
   "kpack": 2
@@ -123,7 +123,7 @@
   "BLOCK_SIZE_K": 64,
   "GROUP_SIZE_M": 1,
   "num_warps": 8,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 2,
   "matrix_instr_nonkdim": 16,
   "kpack": 1
@@ -134,7 +134,7 @@
   "BLOCK_SIZE_K": 64,
   "GROUP_SIZE_M": 1,
   "num_warps": 8,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 1,
   "matrix_instr_nonkdim": 16,
   "kpack": 1
@@ -145,7 +145,7 @@
   "BLOCK_SIZE_K": 64,
   "GROUP_SIZE_M": 1,
   "num_warps": 8,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 0,
   "matrix_instr_nonkdim": 16,
   "kpack": 1
@@ -156,7 +156,7 @@
   "BLOCK_SIZE_K": 64,
   "GROUP_SIZE_M": 1,
   "num_warps": 8,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 1,
   "matrix_instr_nonkdim": 16,
   "kpack": 1
@@ -167,7 +167,7 @@
   "BLOCK_SIZE_K": 64,
   "GROUP_SIZE_M": 1,
   "num_warps": 4,
-  "num_stages": 0,
+  "num_stages": 2,
   "waves_per_eu": 2,
   "matrix_instr_nonkdim": 16,
   "kpack": 2
@@ -15,9 +15,16 @@ from vllm import _custom_ops as ops
 
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
+from sglang.srt.layers.quantization.int8_kernel import per_token_group_quant_int8
+from sglang.srt.utils import (
+    direct_register_custom_op,
+    get_bool_env_var,
+    get_device_name,
+    is_cuda_available,
+    is_hip,
+)
 
-is_hip_flag = is_hip()
+is_hip_ = is_hip()
 
 
 logger = logging.getLogger(__name__)
@@ -86,6 +93,7 @@ def fused_moe_kernel(
     top_k: tl.constexpr,
     compute_type: tl.constexpr,
     use_fp8_w8a8: tl.constexpr,
+    use_int8_w8a8: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
     even_Ks: tl.constexpr,
 ):
@@ -159,7 +167,7 @@
         )
         b_scale = tl.load(b_scale_ptrs)
 
-    if use_fp8_w8a8:
+    if use_fp8_w8a8 or use_int8_w8a8:
         if group_k > 0 and group_n > 0:
             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
             offs_bsn = offs_bn // group_n
@@ -198,7 +206,7 @@
         # We accumulate along the K dimension.
         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
-        elif use_fp8_w8a8:
+        elif use_fp8_w8a8 or use_int8_w8a8:
             if group_k > 0 and group_n > 0:
                 k_start = k * BLOCK_SIZE_K
                 offs_ks = k_start // group_k
@@ -221,7 +229,7 @@
         accumulator = accumulator * moe_weight[:, None]
     if use_int8_w8a16:
         accumulator = (accumulator * b_scale).to(compute_type)
-    elif use_fp8_w8a8:
+    elif use_fp8_w8a8 or use_int8_w8a8:
         if group_k > 0 and group_n > 0:
             accumulator = accumulator.to(compute_type)
         else:
@@ -477,8 +485,10 @@ def invoke_fused_moe_kernel(
     config: Dict[str, Any],
     compute_type: tl.dtype,
     use_fp8_w8a8: bool,
+    use_int8_w8a8: bool,
    use_int8_w8a16: bool,
     block_shape: Optional[List[int]] = None,
+    no_combine: bool = False,
 ) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -499,6 +509,18 @@
             assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
             assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
             assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
+    elif use_int8_w8a8:
+        assert B_scale is not None
+        if block_shape is None:
+            padded_size = padding_size
+            A, A_scale = ops.scaled_int8_quant(A, A_scale)
+        else:
+            assert len(block_shape) == 2
+            block_n, block_k = block_shape[0], block_shape[1]
+            A, A_scale = per_token_group_quant_int8(A, block_k)
+            assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
+            assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
+            assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
     elif use_int8_w8a16:
         assert B_scale is not None
     else:
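
The new int8 W8A8 branch above quantizes the activations on the fly: per-tensor via ops.scaled_int8_quant when no block_shape is given, otherwise per token group via per_token_group_quant_int8. The following is a hedged, plain-PyTorch illustration of what per-token-group int8 quantization does conceptually; it is not the Triton kernel shipped in sglang/srt/layers/quantization/int8_kernel.py, and the function name _ref marks it as a reference sketch.

# Hedged sketch: quantize each contiguous group of `group_size` channels of
# every token to int8 with its own scale.
import torch

def per_token_group_quant_int8_ref(x: torch.Tensor, group_size: int):
    assert x.shape[-1] % group_size == 0
    x_groups = x.float().view(*x.shape[:-1], -1, group_size)
    # One scale per (token, group): map the max magnitude to the int8 range.
    scales = x_groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.round(x_groups / scales).clamp(-128, 127).to(torch.int8)
    return q.view_as(x), scales.squeeze(-1)

x = torch.randn(4, 256, dtype=torch.bfloat16)
q, s = per_token_group_quant_int8_ref(x, group_size=128)
print(q.shape, s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])
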
@@ -548,6 +570,7 @@
         top_k=top_k,
         compute_type=compute_type,
         use_fp8_w8a8=use_fp8_w8a8,
+        use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         even_Ks=even_Ks,
         **config,
@@ -625,7 +648,7 @@ def get_default_config(
                 "BLOCK_SIZE_K": 128,
                 "GROUP_SIZE_M": 32,
                 "num_warps": 8,
-                "num_stages": 2 if is_hip_flag else 4,
+                "num_stages": 2 if is_hip_ else 4,
             }
             if M <= E:
                 config = {
@@ -634,7 +657,7 @@
                     "BLOCK_SIZE_K": 128,
                     "GROUP_SIZE_M": 1,
                     "num_warps": 4,
-                    "num_stages": 2 if is_hip_flag else 4,
+                    "num_stages": 2 if is_hip_ else 4,
                 }
         else:
             # Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
@@ -644,7 +667,7 @@
                 "BLOCK_SIZE_K": block_shape[1],
                 "GROUP_SIZE_M": 32,
                 "num_warps": 4,
-                "num_stages": 2 if is_hip_flag else 3,
+                "num_stages": 2 if is_hip_ else 3,
             }
     else:
         config = {
@@ -701,9 +724,12 @@ def get_config_dtype_str(
     dtype: torch.dtype,
     use_int8_w8a16: Optional[bool] = False,
     use_fp8_w8a8: Optional[bool] = False,
+    use_int8_w8a8: Optional[bool] = False,
 ):
     if use_fp8_w8a8:
         return "fp8_w8a8"
+    elif use_int8_w8a8:
+        return "int8_w8a8"
     elif use_int8_w8a16:
         return "int8_w8a16"
     elif dtype == torch.float:
@@ -721,6 +747,7 @@ def inplace_fused_experts(
     topk_ids: torch.Tensor,
     activation: str = "silu",
     use_fp8_w8a8: bool = False,
+    use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
@@ -737,6 +764,7 @@
         True,
         activation,
         use_fp8_w8a8,
+        use_int8_w8a8,
         use_int8_w8a16,
         w1_scale,
         w2_scale,
@@ -754,6 +782,7 @@ def inplace_fused_experts_fake(
     topk_ids: torch.Tensor,
     activation: str = "silu",
     use_fp8_w8a8: bool = False,
+    use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
@@ -780,12 +809,14 @@ def outplace_fused_experts(
     topk_ids: torch.Tensor,
     activation: str = "silu",
     use_fp8_w8a8: bool = False,
+    use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
+    no_combine: bool = False,
 ) -> torch.Tensor:
     return fused_experts_impl(
         hidden_states,
@@ -796,12 +827,14 @@
         False,
         activation,
         use_fp8_w8a8,
+        use_int8_w8a8,
         use_int8_w8a16,
         w1_scale,
         w2_scale,
         a1_scale,
         a2_scale,
         block_shape,
+        no_combine=no_combine,
     )
 
 
@@ -813,12 +846,14 @@ def outplace_fused_experts_fake(
     topk_ids: torch.Tensor,
     activation: str = "silu",
     use_fp8_w8a8: bool = False,
+    use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
+    no_combine: bool = False,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)
 
@@ -840,14 +875,17 @@ def fused_experts(
     inplace: bool = False,
     activation: str = "silu",
     use_fp8_w8a8: bool = False,
+    use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
+    no_combine: bool = False,
 ):
     if inplace:
+        assert not no_combine, "no combine + inplace makes no sense"
         torch.ops.sglang.inplace_fused_experts(
             hidden_states,
             w1,
@@ -856,6 +894,7 @@
             topk_ids,
             activation,
             use_fp8_w8a8,
+            use_int8_w8a8,
             use_int8_w8a16,
             w1_scale,
             w2_scale,
@@ -873,12 +912,14 @@
             topk_ids,
             activation,
             use_fp8_w8a8,
+            use_int8_w8a8,
             use_int8_w8a16,
             w1_scale,
             w2_scale,
             a1_scale,
             a2_scale,
             block_shape,
+            no_combine=no_combine,
         )
 
 
@@ -891,15 +932,21 @@ def fused_experts_impl(
     inplace: bool = False,
     activation: str = "silu",
     use_fp8_w8a8: bool = False,
+    use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
+    no_combine: bool = False,
 ):
     padded_size = padding_size
-    if not use_fp8_w8a8 or block_shape is not None:
+    if (
+        not (use_fp8_w8a8 or use_int8_w8a8)
+        or block_shape is not None
+        or (is_hip_ and get_bool_env_var("CK_MOE"))
+    ):
         padded_size = 0
 
     # Check constraints.
@@ -918,6 +965,7 @@
     M = min(num_tokens, CHUNK_SIZE)
     config_dtype = get_config_dtype_str(
         use_fp8_w8a8=use_fp8_w8a8,
+        use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         dtype=hidden_states.dtype,
     )
@@ -933,25 +981,33 @@
 
     config = get_config_func(M)
 
-    intermediate_cache1 = torch.empty(
-        (M, topk_ids.shape[1], N),
+    cache = torch.empty(
+        M * topk_ids.shape[1] * max(N, w2.shape[1]),
         device=hidden_states.device,
         dtype=hidden_states.dtype,
     )
+    intermediate_cache1 = cache[: M * topk_ids.shape[1] * N].view(
+        (M, topk_ids.shape[1], N),
+    )
     intermediate_cache2 = torch.empty(
         (M * topk_ids.shape[1], N // 2),
         device=hidden_states.device,
         dtype=hidden_states.dtype,
     )
-    intermediate_cache3 = torch.empty(
+    intermediate_cache3 = cache[: M * topk_ids.shape[1] * w2.shape[1]].view(
         (M, topk_ids.shape[1], w2.shape[1]),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
     )
 
     compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
 
-    if inplace:
+    if no_combine:
+        assert not inplace
+        out_hidden_states = torch.empty(
+            (num_tokens, topk_ids.shape[1], w2.shape[1]),
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+    elif inplace:
         out_hidden_states = hidden_states
     else:
         out_hidden_states = torch.empty_like(hidden_states)
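
The hunk above makes intermediate_cache1 and intermediate_cache3 views into one flat allocation: the gate/up projection output is fully consumed before the down projection output is written, so the two buffers can safely overlap. A hedged stand-alone illustration of that aliasing trick follows; the shapes and names are illustrative, not the sglang internals.

# Hedged sketch: carve two differently shaped views out of one flat buffer
# when the tensors are never live at the same time.
import torch

M, topk, N, hidden = 8, 2, 512, 256
flat = torch.empty(M * topk * max(N, hidden), dtype=torch.bfloat16)

cache1 = flat[: M * topk * N].view(M, topk, N)            # gate/up projection output
cache3 = flat[: M * topk * hidden].view(M, topk, hidden)  # down projection output

# cache1 is consumed (via the activation into intermediate_cache2) before
# cache3 is written, so overlapping storage is safe and saves one allocation.
assert cache1.data_ptr() == cache3.data_ptr() == flat.data_ptr()
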
@@ -1000,6 +1056,7 @@
             config,
             compute_type=compute_type,
             use_fp8_w8a8=use_fp8_w8a8,
+            use_int8_w8a8=use_int8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
             block_shape=block_shape,
         )
@@ -1020,7 +1077,11 @@
         invoke_fused_moe_kernel(
             intermediate_cache2,
             w2,
-            intermediate_cache3,
+            (
+                intermediate_cache3
+                if not no_combine and topk_ids.shape[1] != 1
+                else out_hidden_states[begin_chunk_idx:end_chunk_idx]
+            ),
             a2_scale,
             w2_scale,
             curr_topk_weights,
@@ -1033,20 +1094,21 @@
             config,
             compute_type=compute_type,
             use_fp8_w8a8=use_fp8_w8a8,
+            use_int8_w8a8=use_int8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
             block_shape=block_shape,
         )
 
-        if is_hip_flag:
+        if no_combine:
+            pass
+        elif is_hip_:
             ops.moe_sum(
                 intermediate_cache3.view(*intermediate_cache3.shape),
                 out_hidden_states[begin_chunk_idx:end_chunk_idx],
             )
         else:
             if topk_ids.shape[1] == 1:
-                out_hidden_states[begin_chunk_idx:end_chunk_idx].copy_(
-                    intermediate_cache3[:, 0]
-                )
+                pass  # we write directly into out_hidden_states
             elif topk_ids.shape[1] == 2:
                 torch.add(
                     intermediate_cache3[:, 0],
@@ -1077,12 +1139,14 @@ def fused_moe(
     topk_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
     use_fp8_w8a8: bool = False,
+    use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
+    no_combine: bool = False,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1104,6 +1168,8 @@
     note: Deepseek V2/V3/R1 series models use grouped_topk
     - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
         products for w1 and w2. Defaults to False.
+    - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
+        products for w1 and w2. Defaults to False.
     - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
         products for w1 and w2. Defaults to False.
     - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
@@ -1143,10 +1209,12 @@
         inplace=inplace,
         activation=activation,
         use_fp8_w8a8=use_fp8_w8a8,
+        use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
         a1_scale=a1_scale,
         a2_scale=a2_scale,
         block_shape=block_shape,
+        no_combine=no_combine,
     )
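
The hunks above thread the new use_int8_w8a8 and no_combine arguments through the public fused_moe entry point. Below is a hedged usage sketch of the int8 W8A8 path; the tensor shapes and the int8/scale layout are illustrative assumptions, since real callers (for example the blockwise_int8 quantization backend added in this release) prepare weights and scales from a quantized checkpoint.

# Hedged sketch: calling fused_moe with blockwise int8 weights and activations.
import torch
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

E, hidden, inter, topk = 8, 256, 512, 2
x = torch.randn(4, hidden, dtype=torch.bfloat16, device="cuda")
gating = torch.randn(4, E, dtype=torch.float32, device="cuda")

# int8 weights with per-[128, 128]-block scales, matching block_shape below.
w1 = torch.randint(-128, 127, (E, 2 * inter, hidden), dtype=torch.int8, device="cuda")
w2 = torch.randint(-128, 127, (E, hidden, inter), dtype=torch.int8, device="cuda")
w1_scale = torch.rand(E, 2 * inter // 128, hidden // 128, device="cuda")
w2_scale = torch.rand(E, hidden // 128, inter // 128, device="cuda")

out = fused_moe(
    x, w1, w2, gating, topk,
    renormalize=True,
    use_int8_w8a8=True,      # new flag in this release
    w1_scale=w1_scale,
    w2_scale=w2_scale,
    block_shape=[128, 128],  # per-block quantization granularity
)
print(out.shape)  # torch.Size([4, 256])
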