sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +220 -378
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +143 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +681 -259
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +224 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +44 -18
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +94 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +208 -28
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +136 -52
  181. sglang/srt/speculative/build_eagle_tree.py +2 -8
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  183. sglang/srt/speculative/eagle_utils.py +92 -58
  184. sglang/srt/speculative/eagle_worker.py +186 -94
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -29,6 +29,9 @@ import logging
 
 is_hip_ = is_hip()
 
+if is_hip_:
+    from aiter import ck_moe
+
 logger = logging.getLogger(__name__)
 
 
@@ -125,6 +128,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
     ) -> torch.Tensor:
         return self.forward(
             x=x,
@@ -138,6 +143,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
             activation=activation,
+            inplace=inplace,
+            no_combine=no_combine,
         )
 
     def forward_cuda(
@@ -153,6 +160,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
     ) -> torch.Tensor:
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -167,17 +176,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         )
 
         if is_hip_ and get_bool_env_var("CK_MOE"):
-            import ater
-            from ater.fused_moe import fused_experts_ck
-
-            assert activation == "silu", f"{activation=} is not supported."
-
-            return fused_experts_ck(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
+            assert not no_combine, "unsupported"
+            return ck_moe(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                None,
+                None,
+                None,
+                None,
+                32,
+                None,
+                activation,
             )
         else:
             return fused_experts(
@@ -186,8 +198,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 w2=layer.w2_weight,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
-                inplace=True,
+                inplace=inplace and not no_combine,
                 activation=activation,
+                no_combine=no_combine,
             )
 
     def forward_cpu(
@@ -202,6 +215,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        inplace: bool = True,
    ) -> torch.Tensor:
         return moe_forward_native(
             layer,
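
The guard `inplace=inplace and not no_combine` in the fused_experts call above is the subtle part of these hunks: no_combine asks the kernel to return the per-expert outputs without summing them back into an input-shaped buffer, so the result cannot alias x and in-place writing must be disabled. A minimal sketch of that buffer decision, assuming the combined output is [tokens, hidden] and the un-combined output is [tokens, topk, hidden] (illustrative only, not the actual kernel code):

import torch

def moe_output_buffer(x: torch.Tensor, topk: int, inplace: bool, no_combine: bool):
    # Hypothetical helper mirroring the inplace/no_combine interaction.
    if no_combine:
        # The un-combined output carries an extra top-k dimension, so it can
        # never reuse x's storage: in-place is implicitly turned off.
        return x.new_empty(x.shape[0], topk, x.shape[1])
    # The combined output has x's shape; reuse x only if the caller allows it.
    return x if inplace else torch.empty_like(x)
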
@@ -241,6 +255,7 @@ class FusedMoE(torch.nn.Module):
         reduce_results: Whether to all all_reduce on the output of the layer
         renomalize: Whether to renormalize the logits in the fused_moe kernel
         quant_config: Quantization configure.
+        inplace: suggestion to compute inplace (modify input activation).
     """
 
     def __init__(
@@ -262,6 +277,8 @@ class FusedMoE(torch.nn.Module):
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
         use_presharded_weights: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
     ):
         super().__init__()
 
@@ -285,6 +302,9 @@ class FusedMoE(torch.nn.Module):
         self.custom_routing_function = custom_routing_function
         self.correction_bias = correction_bias
         self.activation = activation
+        self.use_presharded_weights = use_presharded_weights
+        self.inplace = inplace
+        self.no_combine = no_combine
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -304,7 +324,6 @@ class FusedMoE(torch.nn.Module):
             params_dtype=params_dtype,
             weight_loader=self.weight_loader,
         )
-        self.use_presharded_weights = use_presharded_weights
 
     def _load_per_tensor_weight_scale(
         self,
@@ -598,6 +617,8 @@ class FusedMoE(torch.nn.Module):
             custom_routing_function=self.custom_routing_function,
             correction_bias=self.correction_bias,
             activation=self.activation,
+            inplace=self.inplace,
+            no_combine=self.no_combine,
         )
 
         if self.reduce_results and self.tp_size > 1:
sglang/srt/layers/moe/topk.py

@@ -75,7 +75,6 @@ def fused_topk(
     return topk_weights, topk_ids
 
 
-# This is used by the Deepseek V2/V3/R1 series models
 @torch.compile(dynamic=True, backend=get_compiler_backend())
 def grouped_topk(
     hidden_states: torch.Tensor,
@@ -84,10 +83,17 @@ def grouped_topk(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
+    scoring_func: str = "softmax",
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
-    scores = torch.softmax(gating_output, dim=-1)
+    if scoring_func == "softmax":
+        scores = torch.softmax(gating_output, dim=-1)
+    elif scoring_func == "sigmoid":
+        scores = gating_output.sigmoid()
+    else:
+        raise ValueError(f"Scoring function '{scoring_func}' is not supported.")
+
     num_token = scores.shape[0]
     group_scores = (
         scores.view(num_token, num_expert_group, -1).max(dim=-1).values
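
The two scoring functions differ in normalization: softmax couples the experts (each token's scores sum to one), while sigmoid scores every expert independently, the convention DeepSeek-V3-style routers combine with a correction bias. A self-contained sketch of the branch added above:

import torch

def route_scores(gating_output: torch.Tensor, scoring_func: str = "softmax"):
    # gating_output: [num_tokens, num_experts] raw router logits
    if scoring_func == "softmax":
        return torch.softmax(gating_output, dim=-1)  # rows sum to 1
    elif scoring_func == "sigmoid":
        return gating_output.sigmoid()  # independent scores in (0, 1)
    raise ValueError(f"Scoring function '{scoring_func}' is not supported.")

logits = torch.randn(4, 8)
assert torch.allclose(route_scores(logits).sum(-1), torch.ones(4), atol=1e-6)
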
@@ -111,6 +117,7 @@ def grouped_topk(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
 
+# DeepSeek V2/V3/R1 uses biased_grouped_topk
 @torch.compile(dynamic=True, backend=get_compiler_backend())
 def biased_grouped_topk(
     hidden_states: torch.Tensor,
@@ -141,7 +148,9 @@ def biased_grouped_topk(
         .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
         .reshape(num_token, -1)
     )  # [n, e]
-    tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    tmp_scores = scores_for_choice.masked_fill(
+        ~score_mask.bool(), float("-inf")
+    )  # [n, e]
     _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
     topk_weights = scores.gather(1, topk_ids)
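
Replacing the 0.0 fill with float("-inf") is a correctness fix, not a cosmetic one: with a correction bias (or sigmoid scoring) the eligible experts can all carry negative scores, and a masked-out expert filled with 0.0 would then outrank them in the top-k. A small repro of the failure mode:

import torch

scores = torch.tensor([[-0.3, -0.1, -0.2, -0.4]])  # all-negative biased scores
mask = torch.tensor([[True, True, False, False]])  # only experts 0 and 1 eligible

bad = scores.masked_fill(~mask, 0.0)
good = scores.masked_fill(~mask, float("-inf"))

print(torch.topk(bad, k=2).indices)   # tensor([[2, 3]]) -- masked experts win
print(torch.topk(good, k=2).indices)  # tensor([[1, 0]]) -- correct selection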
 
@@ -163,7 +172,7 @@ def select_experts(
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
 ):
-    # DeekSeekv2 uses grouped_top_k
+    # DeepSeek V2/V3/R1 uses biased_grouped_topk
     if use_grouped_topk:
         assert topk_group is not None
         assert num_expert_group is not None
sglang/srt/layers/quantization/__init__.py

@@ -1,5 +1,7 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
-from typing import Callable, Dict, Optional, Type
+import re
+from copy import deepcopy
+from typing import Callable, Dict, Optional, Type, Union
 
 import torch
 from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
@@ -16,15 +18,15 @@ from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
 from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
 from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
 from vllm.model_executor.layers.quantization.gguf import GGUFConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
 from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.layers.quantization.qqq import QQQConfig
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.fp8 import Fp8Config
+from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
 
@@ -34,6 +36,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "deepspeedfp": DeepSpeedFPConfig,
     "tpu_int8": Int8TpuConfig,
     "fp8": Fp8Config,
+    "blockwise_int8": BlockInt8Config,
     "fbgemm_fp8": FBGEMMFp8Config,
     "marlin": MarlinConfig,
     "modelopt": ModelOptFp8Config,
@@ -59,19 +62,119 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]
 
 
+# Match dynamic rules with module name (prefix) and override quantize
+# config if module (prefix) matches a rule
+def override_config(config: QuantizationConfig, prefix: str):
+    weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
+    if isinstance(weight_bits, int):
+        config.weight_bits = weight_bits
+    group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
+    if isinstance(group_size, int):
+        config.group_size = group_size
+    desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
+    if isinstance(desc_act, bool):
+        config.desc_act = desc_act
+
+    config.pack_factor = 32 // config.weight_bits  # packed into int32
+    if config.get_name() == "gptq_marlin":
+        is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
+        if isinstance(is_sym, bool):
+            config.is_sym = is_sym
+
+        if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
+            raise ValueError(
+                "Unsupported quantization config: "
+                f"bits={config.weight_bits}, sym={config.is_sym}"
+            )
+
+        config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
+    elif config.get_name() == "gptq":
+        if config.weight_bits not in [2, 3, 4, 8]:
+            raise ValueError(
+                "Currently, only 2/3/4/8-bit weight quantization is "
+                f"supported for GPTQ, but got {config.weight_bits} bits."
+            )
+
+
+def get_dynamic_override(
+    config: QuantizationConfig,
+    layer_name: str,
+    key: Optional[str] = None,
+    default_value: Union[int, bool, None] = None,
+) -> Union[Dict, int, bool, None]:
+    for pattern, pattern_dict in config.dynamic.items():
+        # Negative match: matched modules are excluded from quantized init
+        if pattern.startswith("-:"):
+            if re.match(pattern.removeprefix("-:"), layer_name):
+                return False
+        # Positive match: matched modules have quant properties overrides
+        # base quant config
+        elif re.match(pattern.removeprefix("+:"), layer_name):
+            if key is None:
+                return pattern_dict
+            else:
+                return pattern_dict.get(key, default_value)
+    return default_value
+
+
+def get_linear_quant_method(
+    config: QuantizationConfig,
+    layer: torch.nn.Module,
+    prefix: str,
+    linear_method_cls: type,
+):
+
+    from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+    from sglang.srt.layers.vocab_parallel_embedding import (
+        ParallelLMHead,
+        UnquantizedEmbeddingMethod,
+    )
+
+    cloned_config = deepcopy(config)
+    parallel_lm_head_quantized = (
+        isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
+    )
+
+    if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
+        # False = skip module, None = no override, else = Positive match
+        if (
+            get_dynamic_override(  # noqa: E712
+                cloned_config, layer_name=prefix  # noqa: E712
+            )
+            == False
+        ):  # noqa: E712
+            if parallel_lm_head_quantized:
+                return UnquantizedEmbeddingMethod()
+            return UnquantizedLinearMethod()
+
+        if prefix:
+            # Dynamic per module/layer rules may override base config
+            override_config(cloned_config, prefix=prefix)
+
+        return linear_method_cls(cloned_config)
+    return None
+
+
 def gptq_get_quant_method(self, layer, prefix):
+    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
     from vllm.model_executor.layers.quantization.gptq_marlin import (
         GPTQMarlinLinearMethod,
         GPTQMarlinMoEMethod,
     )
 
-    from sglang.srt.layers.linear import LinearBase
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
-    if isinstance(layer, LinearBase):
-        return GPTQMarlinLinearMethod(self)
-    elif isinstance(layer, FusedMoE):
+    if isinstance(layer, FusedMoE):
         return GPTQMarlinMoEMethod(self)
+
+    if isinstance(self, GPTQConfig):
+        return get_linear_quant_method(
+            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
+        )
+    elif isinstance(self, GPTQMarlinConfig):
+        return get_linear_quant_method(
+            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
+        )
     return None
 
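How the dynamic rules resolve is easiest to see on a concrete rules dict. The patterns below are hypothetical; the "-:"/"+:" prefixes and regex matching follow get_dynamic_override above:

# Hypothetical `dynamic` field on a GPTQ quantization config.
dynamic = {
    # Negative rule: layer 0 modules skip quantized init entirely.
    r"-:model\.layers\.0\..*": {},
    # Positive rule: MLP projections elsewhere are overridden to 8 bits.
    r"+:model\.layers\.\d+\.mlp\..*": {"bits": 8},
}

# get_dynamic_override(config, "model.layers.0.self_attn.qkv_proj") -> False
#   (get_linear_quant_method then returns an unquantized method)
# get_dynamic_override(config, "model.layers.3.mlp.gate_proj", "bits", 4) -> 8
#   (override_config rewrites config.weight_bits for this prefix)
# get_dynamic_override(config, "lm_head", "bits", 4) -> 4
#   (no rule matches, so the base config's value is kept)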
 
@@ -153,6 +256,7 @@ def apply_monkey_patches():
     from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
 
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
+    setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
     setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
     setattr(AWQMoEMethod, "apply", awq_moe_method_apply)
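
gptq_get_quant_method is deliberately a plain function taking self: assigning it onto the config classes with setattr makes it a bound method on every instance, and the isinstance checks on self then pick GPTQLinearMethod vs. GPTQMarlinLinearMethod. The pattern in miniature (toy classes, not vLLM's):

class Config:  # stand-in for a vLLM quantization config class
    def get_quant_method(self, layer, prefix):
        return "vllm default"

def patched_get_quant_method(self, layer, prefix):
    # `self` binds to the config instance, so the patch can dispatch on type.
    return f"sglang override for {type(self).__name__}"

setattr(Config, "get_quant_method", patched_get_quant_method)
print(Config().get_quant_method(None, ""))  # sglang override for Config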