sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/utils.py
@@ -1,20 +1,21 @@
  # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
 
+ from __future__ import annotations
+
+ import re
+ from copy import deepcopy
  from types import MappingProxyType
- from typing import List, Mapping, Tuple, Union
+ from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union
 
+ import numpy
  import torch
 
  from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
- from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
-
- _is_cuda = is_cuda()
- _is_npu = is_npu()
- _is_cpu_amx_available = cpu_has_amx_support()
- _is_cpu = is_cpu()
+ from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
+ from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
 
- if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
-     from vllm._custom_ops import scaled_fp8_quant
+ if TYPE_CHECKING:
+     from sglang.srt.layers.quantization.base_config import QuantizationConfig
 
 
  def is_layer_skipped(
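The import changes above drop the vllm scaled_fp8_quant fallback and pull ScalarType / scalar_types from the new scalar_type.py module (item 76 in the file list). A minimal sketch of the quant type the GPTQ/marlin helpers below rely on; only bias, min(), max(), and is_integer() appear in this diff, so treat the exact attribute surface as an assumption carried over from the vllm ScalarType this module is adapted from:

    from sglang.srt.layers.quantization.scalar_type import scalar_types

    # uint4b8 = 4 stored bits with a bias of 8, i.e. representable values -8..7.
    t = scalar_types.uint4b8
    print(t.bias, t.min(), t.max(), t.is_integer())  # expected: 8 -8 7 True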
@@ -143,3 +144,333 @@ def replace_parameter(
      if not isinstance(new, torch.nn.Parameter):
          new = torch.nn.Parameter(new, requires_grad=False)
      mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
+
+
+ # Match dynamic rules with module name (prefix) and override quantize
+ # config if module (prefix) matches a rule
+ def override_config(config: QuantizationConfig, prefix: str):
+     weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
+     if isinstance(weight_bits, int):
+         config.weight_bits = weight_bits
+     group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
+     if isinstance(group_size, int):
+         config.group_size = group_size
+     desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
+     if isinstance(desc_act, bool):
+         config.desc_act = desc_act
+
+     config.pack_factor = 32 // config.weight_bits  # packed into int32
+     if config.get_name() == "gptq_marlin":
+         is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
+         if isinstance(is_sym, bool):
+             config.is_sym = is_sym
+
+         if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
+             raise ValueError(
+                 "Unsupported quantization config: "
+                 f"bits={config.weight_bits}, sym={config.is_sym}"
+             )
+
+         config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
+     elif config.get_name() == "gptq":
+         if config.weight_bits not in [2, 3, 4, 8]:
+             raise ValueError(
+                 "Currently, only 2/3/4/8-bit weight quantization is "
+                 f"supported for GPTQ, but got {config.weight_bits} bits."
+             )
+
+
+ def get_dynamic_override(
+     config: QuantizationConfig,
+     layer_name: str,
+     key: Optional[str] = None,
+     default_value: Union[int, bool, None] = None,
+ ) -> Union[Dict, int, bool, None]:
+     for pattern, pattern_dict in config.dynamic.items():
+         # Negative match: matched modules are excluded from quantized init
+         if pattern.startswith("-:"):
+             if re.match(pattern.removeprefix("-:"), layer_name):
+                 return False
+         # Positive match: matched modules have quant properties overrides
+         # base quant config
+         elif re.match(pattern.removeprefix("+:"), layer_name):
+             if key is None:
+                 return pattern_dict
+             else:
+                 return pattern_dict.get(key, default_value)
+     return default_value
+
+
+ def get_linear_quant_method(
+     config: QuantizationConfig,
+     layer: torch.nn.Module,
+     prefix: str,
+     linear_method_cls: type,
+ ):
+     from sglang.srt.layers.linear import LinearBase
+     from sglang.srt.layers.quantization.unquant import (
+         UnquantizedEmbeddingMethod,
+         UnquantizedLinearMethod,
+     )
+     from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+
+     cloned_config = deepcopy(config)
+     parallel_lm_head_quantized = (
+         isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
+     )
+
+     if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
+         # False = skip module, None = no override, else = Positive match
+         if get_dynamic_override(cloned_config, layer_name=prefix) is False:
+             if parallel_lm_head_quantized:
+                 return UnquantizedEmbeddingMethod()
+             return UnquantizedLinearMethod()
+
+         if prefix:
+             # Dynamic per module/layer rules may override base config
+             override_config(cloned_config, prefix=prefix)
+
+         return linear_method_cls(cloned_config)
+     return None
+
+
+ def get_pack_factor(num_bits):
+     assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
+     return 32 // num_bits
+
+
+ def permute_rows(
+     q_w: torch.Tensor,
+     w_ref: torch.Tensor,
+     group_size: int,
+     test_perm: Optional[torch.Tensor] = None,
+ ):
+     assert q_w.shape == w_ref.shape
+
+     orig_device = q_w.device
+     k_size, _ = q_w.shape
+
+     g_idx = torch.zeros((k_size,), dtype=torch.int32)
+     for i in range(k_size):
+         g_idx[i] = i // group_size
+
+     # Simulate act_order by doing a random permutation on K
+     rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)
+
+     g_idx = g_idx[rand_perm].contiguous()
+     q_w = q_w[rand_perm, :].contiguous()
+     w_ref = w_ref[rand_perm, :].contiguous()
+
+     return (
+         w_ref.to(device=orig_device),
+         q_w.to(device=orig_device),
+         g_idx.to(device=orig_device),
+         rand_perm.to(device=orig_device),
+     )
+
+
+ def pack_cols(
+     q_w: torch.Tensor,
+     num_bits: int,
+     size_k: int,
+     size_n: int,
+ ):
+     assert q_w.shape == (size_k, size_n)
+
+     pack_factor = get_pack_factor(num_bits)
+     assert size_n % pack_factor == 0
+
+     orig_device = q_w.device
+
+     q_w = q_w.cpu().numpy().astype(numpy.uint32)
+
+     q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
+
+     for i in range(pack_factor):
+         q_res |= q_w[:, i::pack_factor] << num_bits * i
+
+     q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+     q_res = q_res.contiguous()
+
+     return q_res
+
+
+ def unpack_cols(
+     packed_q_w: torch.Tensor,
+     num_bits: int,
+     size_k: int,
+     size_n: int,
+ ):
+     pack_factor = get_pack_factor(num_bits)
+     assert size_n % pack_factor == 0
+     assert packed_q_w.shape == (
+         size_k,
+         size_n // pack_factor,
+     ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
+         packed_q_w.shape, size_k, size_n, pack_factor
+     )
+
+     orig_device = packed_q_w.device
+
+     packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
+     q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
+
+     mask = (1 << num_bits) - 1
+     for i in range(pack_factor):
+         vals = packed_q_w_cpu & mask
+         packed_q_w_cpu >>= num_bits
+         q_res[:, i::pack_factor] = vals
+
+     q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+     q_res = q_res.contiguous()
+
+     return q_res
+
+
+ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
+ def quantize_weights(
+     w: torch.Tensor,
+     quant_type: ScalarType,
+     group_size: Optional[int],
+     zero_points: bool = False,
+     ref_zero_points_after_scales: bool = False,
+ ):
+     assert (
+         quant_type.is_integer()
+     ), "Floating point quantization may work but has not been tested"
+     assert not zero_points or group_size is not None, (
+         "to have group zero points, group_size must be provided "
+         "(-1 group_size is channelwise)"
+     )
+
+     orig_device = w.device
+     orig_type = w.dtype
+     size_k, size_n = w.shape
+
+     assert w.is_floating_point(), "w must be float"
+
+     if group_size == -1:
+         group_size = size_k
+
+     # Reshape to [groupsize, -1]
+     if group_size is not None and group_size < size_k:
+         w = w.reshape((-1, group_size, size_n))
+         w = w.permute(1, 0, 2)
+         w = w.reshape((group_size, -1))
+
+     # Compute scale for each group
+     max_val = torch.max(w, 0, keepdim=True).values
+     min_val = torch.min(w, 0, keepdim=True).values
+
+     max_q_val = quant_type.max()
+     min_q_val = quant_type.min()
+
+     w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
+     maybe_w_zp = None
+     if group_size is not None:
+         if zero_points:
+             assert not quant_type.is_signed() and quant_type.max() > 0
+             w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
+             maybe_w_zp = (
+                 torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
+             )
+         else:
+             # If the bias is such that there are no possible negative/positive
+             # values, set the max value to inf to avoid divide by 0
+             w_s = torch.max(
+                 abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
+                 abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
+             )
+
+     # Quantize
+     w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
+     w_q = torch.clamp(w_q, min_q_val, max_q_val)
+
+     # Compute ref (dequantized)
+     # For some kernels (namely Machete) the zero-points are applied after the
+     # scales are applied, for this case computing the reference in similar way
+     # allows us to use tighter error tolerances in our unit tests.
+     if ref_zero_points_after_scales and maybe_w_zp is not None:
+         w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
+     else:
+         w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
+
+     if quant_type.has_bias():
+         w_q += quant_type.bias
+
+     # Restore original shapes
+     if group_size is not None and group_size < size_k:
+
+         def reshape_w(w):
+             w = w.reshape((group_size, -1, size_n))
+             w = w.permute(1, 0, 2)
+             w = w.reshape((size_k, size_n)).contiguous()
+             return w
+
+         w_q = reshape_w(w_q)
+         w_ref = reshape_w(w_ref)
+         w_s = w_s.reshape((-1, size_n)).contiguous()
+
+     if maybe_w_zp is not None:
+         maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
+         maybe_w_zp = maybe_w_zp.to(device=orig_device)
+
+     return (
+         w_ref.to(device=orig_device),
+         w_q.to(device=orig_device),
+         w_s if group_size is not None else None,
+         maybe_w_zp,
+     )
+
+
+ SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
+ SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
+
+
+ def gptq_quantize_weights(
+     w: torch.Tensor,
+     quant_type: ScalarType,
+     group_size: int,
+     act_order: bool,
+     test_perm: Optional[torch.Tensor] = None,
+ ):
+     size_k, _ = w.shape
+
+     assert w.is_floating_point(), "w must be float"
+     assert (
+         quant_type in SUPPORTED_GPTQ_QUANT_TYPES
+     ), f"Unsupported gptq type = {quant_type}"
+     assert group_size in SUPPORTED_GROUP_SIZES + [
+         size_k
+     ], f"Unsupported groupsize = {group_size}"
+
+     w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)
+
+     # Apply act_order
+     g_idx = torch.empty(0, dtype=torch.int, device=w.device)
+     rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
+     if act_order:
+         assert (
+             group_size < size_k
+         ), "For act_order, groupsize = {} must be less than size_k = {}".format(
+             group_size, size_k
+         )
+
+         w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)
+
+     return w_ref, w_q, w_s, g_idx, rand_perm
+
+
+ def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
+     orig_device = q_w.device
+
+     sort_indices = torch.argsort(g_idx).to(dtype=torch.int32)  # Sort based on g_idx
+
+     g_idx = g_idx[sort_indices].contiguous()
+     q_w = q_w[sort_indices, :].contiguous()
+
+     return (
+         q_w.to(device=orig_device),
+         g_idx.to(device=orig_device),
+         sort_indices.to(device=orig_device),
+     )
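For orientation, a minimal sketch of what the new quantize_weights helper returns for a GPTQ-style 4-bit grouped quantization. The shapes follow directly from the code above; the import path assumes these helpers ship in sglang/srt/layers/quantization/utils.py, as the file list indicates.

    import torch
    from sglang.srt.layers.quantization.scalar_type import scalar_types
    from sglang.srt.layers.quantization.utils import quantize_weights

    w = torch.randn(256, 128, dtype=torch.float16)  # (size_k, size_n)
    w_ref, w_q, w_s, w_zp = quantize_weights(
        w, scalar_types.uint4b8, group_size=128, zero_points=False
    )
    # w_q: biased integer codes in [0, 15], same shape as w
    # w_s: per-group scales, shape (size_k // group_size, size_n) = (2, 128)
    # w_ref: dequantized reference tensor; w_zp is None when zero_points=False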
sglang/srt/layers/quantization/w4afp8.py
@@ -1,3 +1,5 @@
+ from __future__ import annotations
+
  import logging
  from typing import Any, Dict, List, Optional
 
@@ -5,12 +7,13 @@ import torch
  from torch.nn import Module
  from torch.nn.parameter import Parameter
 
- from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
  from sglang.srt.layers.quantization.base_config import (
+     FusedMoEMethodBase,
      QuantizationConfig,
      QuantizeMethodBase,
  )
  from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
+ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
  from sglang.srt.layers.quantization.utils import is_layer_skipped
  from sglang.srt.utils import set_weight_attrs
 
@@ -62,7 +65,7 @@ class W4AFp8Config(QuantizationConfig):
          return []
 
      @classmethod
-     def from_config(cls, config: Dict[str, Any]) -> "W4AFp8Config":
+     def from_config(cls, config: Dict[str, Any]) -> W4AFp8Config:
          quant_method = cls.get_from_keys(config, ["quant_method"])
          is_checkpoint_fp8_serialized = "fp8" in quant_method
          is_checkpoint_w4afp8_serialized = "w4afp8" in quant_method
@@ -79,7 +82,8 @@ class W4AFp8Config(QuantizationConfig):
 
      def get_quant_method(
          self, layer: torch.nn.Module, prefix: str
-     ) -> Optional["QuantizeMethodBase"]:
+     ) -> Optional[QuantizeMethodBase]:
+         from sglang.srt.layers.linear import LinearBase
          from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
          if isinstance(layer, LinearBase):
@@ -94,7 +98,7 @@ class W4AFp8Config(QuantizationConfig):
          return []
 
 
- class W4AFp8MoEMethod:
+ class W4AFp8MoEMethod(FusedMoEMethodBase):
 
      def __init__(self, quant_config: W4AFp8Config):
          self.quant_config = quant_config
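The w4afp8.py hunks above move the LinearBase import inside get_quant_method, and the w8a8_fp8.py hunks below guard imports with TYPE_CHECKING; both lean on `from __future__ import annotations` so annotations stay strings at runtime, presumably to break import cycles with linear.py. A minimal sketch of the idiom; the module and class names just mirror the diff, this is an illustration rather than sglang code:

    from __future__ import annotations

    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        # Only evaluated by type checkers, so no import cycle at runtime.
        from sglang.srt.layers.quantization.base_config import QuantizeMethodBase


    def get_quant_method(layer, prefix: str) -> Optional[QuantizeMethodBase]:
        # Deferred import: resolved lazily at call time instead of module load.
        from sglang.srt.layers.linear import LinearBase

        if isinstance(layer, LinearBase):
            ...
        return None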
sglang/srt/layers/quantization/w8a8_fp8.py
@@ -1,11 +1,14 @@
- from typing import Any, Callable, Dict, List, Optional
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
  import torch
  from torch.nn.parameter import Parameter
 
- from sglang.srt.layers.linear import LinearMethodBase
  from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
  from sglang.srt.layers.quantization.base_config import (
+     FusedMoEMethodBase,
+     LinearMethodBase,
      QuantizationConfig,
      QuantizeMethodBase,
  )
@@ -22,6 +25,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
  )
  from sglang.srt.utils import set_weight_attrs
 
+ if TYPE_CHECKING:
+     from sglang.srt.layers.moe.topk import TopKOutput
+
  _is_fp8_fnuz = is_fp8_fnuz()
 
 
@@ -64,7 +70,7 @@ class W8A8Fp8Config(QuantizationConfig):
          return []
 
      @classmethod
-     def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
+     def from_config(cls, config: Dict[str, Any]) -> W8A8Fp8Config:
          quant_method = cls.get_from_keys(config, ["quant_method"])
          is_checkpoint_fp8_serialized = (
              "compressed-tensors" in quant_method or "w8a8_fp8" in quant_method
@@ -75,7 +81,7 @@ class W8A8Fp8Config(QuantizationConfig):
          self,
          layer: torch.nn.Module,
          prefix: str,
-     ) -> Optional["QuantizeMethodBase"]:
+     ) -> Optional[QuantizeMethodBase]:
          from sglang.srt.layers.linear import LinearBase
          from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
@@ -183,7 +189,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
          )
 
 
- class W8A8FP8MoEMethod:
+ class W8A8FP8MoEMethod(FusedMoEMethodBase):
      """MoE method for FP8.
      Supports loading FP8 checkpoints with static weight scale and
      dynamic/static activation scale.
@@ -194,25 +200,7 @@ class W8A8FP8MoEMethod:
          quant_config: The quantization config.
      """
 
-     def __new__(cls, *args, **kwargs):
-         from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
-         if not hasattr(cls, "_initialized"):
-             original_init = cls.__init__
-             new_cls = type(
-                 cls.__name__,
-                 (FusedMoEMethodBase,),
-                 {
-                     "__init__": original_init,
-                     **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                 },
-             )
-             obj = super(new_cls, new_cls).__new__(new_cls)
-             obj.__init__(*args, **kwargs)
-             return obj
-         return super().__new__(cls)
-
-     def __init__(self, quant_config):
+     def __init__(self, quant_config: W8A8Fp8Config):
          self.quant_config = quant_config
 
      def create_weights(
@@ -281,45 +269,23 @@ class W8A8FP8MoEMethod:
          self,
          layer: torch.nn.Module,
          x: torch.Tensor,
-         router_logits: torch.Tensor,
-         top_k: int,
-         renormalize: bool,
-         use_grouped_topk: bool,
-         topk_group: Optional[int] = None,
-         num_expert_group: Optional[int] = None,
-         num_fused_shared_experts: int = 0,
-         custom_routing_function: Optional[Callable] = None,
-         correction_bias: Optional[torch.Tensor] = None,
+         topk_output: TopKOutput,
+         *,
          activation: str = "silu",
+         apply_router_weight_on_input: bool = False,
          inplace: bool = True,
          no_combine: bool = False,
          routed_scaling_factor: Optional[float] = None,
      ) -> torch.Tensor:
          from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-         from sglang.srt.layers.moe.topk import select_experts
-
-         # Expert selection
-         topk_weights, topk_ids = select_experts(
-             hidden_states=x,
-             router_logits=router_logits,
-             use_grouped_topk=use_grouped_topk,
-             top_k=top_k,
-             renormalize=renormalize,
-             topk_group=topk_group,
-             num_expert_group=num_expert_group,
-             num_fused_shared_experts=num_fused_shared_experts,
-             custom_routing_function=custom_routing_function,
-             correction_bias=correction_bias,
-             routed_scaling_factor=routed_scaling_factor,
-         )
 
          return fused_experts(
              x,
              layer.w13_weight,
              layer.w2_weight,
-             topk_weights=topk_weights,
-             topk_ids=topk_ids,
+             topk_output=topk_output,
              inplace=inplace,
+             apply_router_weight_on_input=apply_router_weight_on_input,
              activation=activation,
              use_fp8_w8a8=True,
              per_channel_quant=True,
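Finally, the bit-packing helpers added to utils.py earlier in this diff (pack_cols / unpack_cols) are presumably what the new marlin/gptq code paths (marlin_utils.py, gptq.py, test_marlin_utils.py above) build on. A round-trip sketch, again assuming the helpers ship under sglang.srt.layers.quantization.utils as the file list indicates:

    import torch
    from sglang.srt.layers.quantization.utils import get_pack_factor, pack_cols, unpack_cols

    num_bits = 4
    size_k, size_n = 16, 32                            # size_n must divide evenly by 32 // num_bits
    q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)

    packed = pack_cols(q_w, num_bits, size_k, size_n)  # int32, shape (16, 32 // 8) = (16, 4)
    restored = unpack_cols(packed, num_bits, size_k, size_n)
    assert get_pack_factor(num_bits) == 8
    assert torch.equal(restored, q_w)                  # packing is lossless for values < 2**num_bits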