sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/configs/deepseekvl2.py +11 -2
  4. sglang/srt/configs/internvl.py +3 -0
  5. sglang/srt/configs/janus_pro.py +3 -0
  6. sglang/srt/configs/model_config.py +9 -7
  7. sglang/srt/configs/update_config.py +3 -1
  8. sglang/srt/conversation.py +1 -0
  9. sglang/srt/custom_op.py +5 -2
  10. sglang/srt/disaggregation/decode.py +9 -1
  11. sglang/srt/disaggregation/mooncake/conn.py +44 -56
  12. sglang/srt/distributed/parallel_state.py +33 -0
  13. sglang/srt/entrypoints/engine.py +30 -26
  14. sglang/srt/entrypoints/openai/serving_chat.py +21 -2
  15. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/qwen3_detector.py +150 -0
  18. sglang/srt/hf_transformers_utils.py +0 -1
  19. sglang/srt/layers/activation.py +13 -0
  20. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  21. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  22. sglang/srt/layers/linear.py +13 -102
  23. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  24. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  25. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  26. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  34. sglang/srt/layers/moe/topk.py +187 -12
  35. sglang/srt/layers/quantization/__init__.py +20 -134
  36. sglang/srt/layers/quantization/awq.py +578 -11
  37. sglang/srt/layers/quantization/awq_triton.py +339 -0
  38. sglang/srt/layers/quantization/base_config.py +85 -10
  39. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
  42. sglang/srt/layers/quantization/fp8.py +273 -62
  43. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  44. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  45. sglang/srt/layers/quantization/gptq.py +501 -143
  46. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +26 -108
  48. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  49. sglang/srt/layers/quantization/petit.py +252 -0
  50. sglang/srt/layers/quantization/petit_utils.py +104 -0
  51. sglang/srt/layers/quantization/qoq.py +7 -6
  52. sglang/srt/layers/quantization/scalar_type.py +352 -0
  53. sglang/srt/layers/quantization/unquant.py +422 -0
  54. sglang/srt/layers/quantization/utils.py +343 -3
  55. sglang/srt/layers/quantization/w4afp8.py +8 -4
  56. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  57. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  58. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  59. sglang/srt/lora/lora.py +0 -4
  60. sglang/srt/lora/lora_manager.py +87 -53
  61. sglang/srt/lora/mem_pool.py +81 -33
  62. sglang/srt/lora/utils.py +12 -5
  63. sglang/srt/managers/cache_controller.py +241 -0
  64. sglang/srt/managers/io_struct.py +41 -29
  65. sglang/srt/managers/mm_utils.py +7 -8
  66. sglang/srt/managers/schedule_batch.py +150 -110
  67. sglang/srt/managers/schedule_policy.py +68 -27
  68. sglang/srt/managers/scheduler.py +243 -61
  69. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  70. sglang/srt/managers/tokenizer_manager.py +11 -3
  71. sglang/srt/managers/tp_worker.py +14 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  73. sglang/srt/mem_cache/allocator.py +7 -16
  74. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  75. sglang/srt/mem_cache/chunk_cache.py +5 -2
  76. sglang/srt/mem_cache/hicache_storage.py +152 -0
  77. sglang/srt/mem_cache/hiradix_cache.py +179 -4
  78. sglang/srt/mem_cache/memory_pool.py +16 -1
  79. sglang/srt/mem_cache/memory_pool_host.py +41 -2
  80. sglang/srt/mem_cache/radix_cache.py +26 -0
  81. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  82. sglang/srt/metrics/collector.py +9 -0
  83. sglang/srt/model_executor/cuda_graph_runner.py +5 -6
  84. sglang/srt/model_executor/forward_batch_info.py +14 -1
  85. sglang/srt/model_executor/model_runner.py +109 -22
  86. sglang/srt/model_loader/loader.py +7 -1
  87. sglang/srt/model_loader/utils.py +4 -4
  88. sglang/srt/models/clip.py +1 -1
  89. sglang/srt/models/deepseek.py +9 -6
  90. sglang/srt/models/deepseek_janus_pro.py +1 -1
  91. sglang/srt/models/deepseek_v2.py +191 -171
  92. sglang/srt/models/deepseek_vl2.py +5 -5
  93. sglang/srt/models/gemma.py +48 -0
  94. sglang/srt/models/gemma2.py +52 -0
  95. sglang/srt/models/gemma3_causal.py +63 -0
  96. sglang/srt/models/gemma3_mm.py +1 -1
  97. sglang/srt/models/gemma3n_mm.py +2 -4
  98. sglang/srt/models/granitemoe.py +385 -0
  99. sglang/srt/models/grok.py +9 -3
  100. sglang/srt/models/hunyuan.py +63 -16
  101. sglang/srt/models/internvl.py +1 -1
  102. sglang/srt/models/kimi_vl.py +1 -1
  103. sglang/srt/models/llama.py +41 -0
  104. sglang/srt/models/llama4.py +11 -11
  105. sglang/srt/models/llava.py +2 -2
  106. sglang/srt/models/llavavid.py +1 -1
  107. sglang/srt/models/minicpm.py +0 -2
  108. sglang/srt/models/minicpmo.py +3 -7
  109. sglang/srt/models/minicpmv.py +1 -1
  110. sglang/srt/models/mistral.py +1 -1
  111. sglang/srt/models/mixtral.py +9 -2
  112. sglang/srt/models/mllama.py +3 -5
  113. sglang/srt/models/mllama4.py +3 -3
  114. sglang/srt/models/olmoe.py +8 -5
  115. sglang/srt/models/persimmon.py +330 -0
  116. sglang/srt/models/phi.py +321 -0
  117. sglang/srt/models/phi4mm.py +44 -4
  118. sglang/srt/models/phi4mm_audio.py +1260 -0
  119. sglang/srt/models/phi4mm_utils.py +1917 -0
  120. sglang/srt/models/phimoe.py +9 -3
  121. sglang/srt/models/qwen.py +37 -0
  122. sglang/srt/models/qwen2.py +41 -0
  123. sglang/srt/models/qwen2_5_vl.py +4 -4
  124. sglang/srt/models/qwen2_audio.py +1 -1
  125. sglang/srt/models/qwen2_moe.py +53 -5
  126. sglang/srt/models/qwen2_vl.py +4 -4
  127. sglang/srt/models/qwen3.py +65 -1
  128. sglang/srt/models/qwen3_moe.py +56 -18
  129. sglang/srt/models/vila.py +1 -1
  130. sglang/srt/multimodal/processors/base_processor.py +91 -97
  131. sglang/srt/multimodal/processors/clip.py +21 -19
  132. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  133. sglang/srt/multimodal/processors/gemma3.py +13 -17
  134. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  135. sglang/srt/multimodal/processors/internvl.py +9 -10
  136. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  137. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  138. sglang/srt/multimodal/processors/llava.py +4 -2
  139. sglang/srt/multimodal/processors/minicpm.py +35 -44
  140. sglang/srt/multimodal/processors/mlama.py +21 -18
  141. sglang/srt/multimodal/processors/mllama4.py +4 -5
  142. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  143. sglang/srt/multimodal/processors/pixtral.py +14 -35
  144. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  145. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  146. sglang/srt/multimodal/processors/vila.py +14 -14
  147. sglang/srt/sampling/sampling_params.py +8 -1
  148. sglang/srt/server_args.py +393 -230
  149. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
  150. sglang/srt/two_batch_overlap.py +1 -0
  151. sglang/srt/utils.py +27 -1
  152. sglang/test/runners.py +14 -3
  153. sglang/test/test_block_fp8.py +8 -3
  154. sglang/test/test_block_fp8_ep.py +1 -1
  155. sglang/test/test_custom_ops.py +12 -7
  156. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  157. sglang/test/test_fp4_moe.py +1 -3
  158. sglang/test/test_marlin_moe.py +286 -0
  159. sglang/test/test_marlin_utils.py +171 -0
  160. sglang/test/test_utils.py +35 -0
  161. sglang/version.py +1 -1
  162. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
  163. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
  164. sglang/srt/layers/quantization/quant_utils.py +0 -166
  165. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  166. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
  167. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
  168. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/utils.py
@@ -1,19 +1,29 @@
  # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py

+ from __future__ import annotations
+
+ import re
+ from copy import deepcopy
  from types import MappingProxyType
- from typing import List, Mapping, Tuple, Union
+ from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union

+ import numpy
  import torch

  from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
- from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
+ from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
+ from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
+
+ if TYPE_CHECKING:
+     from sglang.srt.layers.quantization.base_config import QuantizationConfig

  _is_cuda = is_cuda()
  _is_npu = is_npu()
  _is_cpu_amx_available = cpu_has_amx_support()
  _is_cpu = is_cpu()
+ _is_hip = is_hip()

- if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+ if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
      from vllm._custom_ops import scaled_fp8_quant


@@ -143,3 +153,333 @@ def replace_parameter(
      if not isinstance(new, torch.nn.Parameter):
          new = torch.nn.Parameter(new, requires_grad=False)
      mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
+
+
+ # Match dynamic rules with module name (prefix) and override quantize
+ # config if module (prefix) matches a rule
+ def override_config(config: QuantizationConfig, prefix: str):
+     weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
+     if isinstance(weight_bits, int):
+         config.weight_bits = weight_bits
+     group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
+     if isinstance(group_size, int):
+         config.group_size = group_size
+     desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
+     if isinstance(desc_act, bool):
+         config.desc_act = desc_act
+
+     config.pack_factor = 32 // config.weight_bits  # packed into int32
+     if config.get_name() == "gptq_marlin":
+         is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
+         if isinstance(is_sym, bool):
+             config.is_sym = is_sym
+
+         if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
+             raise ValueError(
+                 "Unsupported quantization config: "
+                 f"bits={config.weight_bits}, sym={config.is_sym}"
+             )
+
+         config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
+     elif config.get_name() == "gptq":
+         if config.weight_bits not in [2, 3, 4, 8]:
+             raise ValueError(
+                 "Currently, only 2/3/4/8-bit weight quantization is "
+                 f"supported for GPTQ, but got {config.weight_bits} bits."
+             )
+
+
+ def get_dynamic_override(
+     config: QuantizationConfig,
+     layer_name: str,
+     key: Optional[str] = None,
+     default_value: Union[int, bool, None] = None,
+ ) -> Union[Dict, int, bool, None]:
+     for pattern, pattern_dict in config.dynamic.items():
+         # Negative match: matched modules are excluded from quantized init
+         if pattern.startswith("-:"):
+             if re.match(pattern.removeprefix("-:"), layer_name):
+                 return False
+         # Positive match: matched modules have quant properties overrides
+         # base quant config
+         elif re.match(pattern.removeprefix("+:"), layer_name):
+             if key is None:
+                 return pattern_dict
+             else:
+                 return pattern_dict.get(key, default_value)
+     return default_value
+
+
+ def get_linear_quant_method(
+     config: QuantizationConfig,
+     layer: torch.nn.Module,
+     prefix: str,
+     linear_method_cls: type,
+ ):
+     from sglang.srt.layers.linear import LinearBase
+     from sglang.srt.layers.quantization.unquant import (
+         UnquantizedEmbeddingMethod,
+         UnquantizedLinearMethod,
+     )
+     from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+
+     cloned_config = deepcopy(config)
+     parallel_lm_head_quantized = (
+         isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
+     )
+
+     if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
+         # False = skip module, None = no override, else = Positive match
+         if get_dynamic_override(cloned_config, layer_name=prefix) is False:
+             if parallel_lm_head_quantized:
+                 return UnquantizedEmbeddingMethod()
+             return UnquantizedLinearMethod()
+
+         if prefix:
+             # Dynamic per module/layer rules may override base config
+             override_config(cloned_config, prefix=prefix)
+
+         return linear_method_cls(cloned_config)
+     return None
+
+
+ def get_pack_factor(num_bits):
+     assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
+     return 32 // num_bits
+
+
+ def permute_rows(
+     q_w: torch.Tensor,
+     w_ref: torch.Tensor,
+     group_size: int,
+     test_perm: Optional[torch.Tensor] = None,
+ ):
+     assert q_w.shape == w_ref.shape
+
+     orig_device = q_w.device
+     k_size, _ = q_w.shape
+
+     g_idx = torch.zeros((k_size,), dtype=torch.int32)
+     for i in range(k_size):
+         g_idx[i] = i // group_size
+
+     # Simulate act_order by doing a random permutation on K
+     rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)
+
+     g_idx = g_idx[rand_perm].contiguous()
+     q_w = q_w[rand_perm, :].contiguous()
+     w_ref = w_ref[rand_perm, :].contiguous()
+
+     return (
+         w_ref.to(device=orig_device),
+         q_w.to(device=orig_device),
+         g_idx.to(device=orig_device),
+         rand_perm.to(device=orig_device),
+     )
+
+
+ def pack_cols(
+     q_w: torch.Tensor,
+     num_bits: int,
+     size_k: int,
+     size_n: int,
+ ):
+     assert q_w.shape == (size_k, size_n)
+
+     pack_factor = get_pack_factor(num_bits)
+     assert size_n % pack_factor == 0
+
+     orig_device = q_w.device
+
+     q_w = q_w.cpu().numpy().astype(numpy.uint32)
+
+     q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
+
+     for i in range(pack_factor):
+         q_res |= q_w[:, i::pack_factor] << num_bits * i
+
+     q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+     q_res = q_res.contiguous()
+
+     return q_res
+
+
+ def unpack_cols(
+     packed_q_w: torch.Tensor,
+     num_bits: int,
+     size_k: int,
+     size_n: int,
+ ):
+     pack_factor = get_pack_factor(num_bits)
+     assert size_n % pack_factor == 0
+     assert packed_q_w.shape == (
+         size_k,
+         size_n // pack_factor,
+     ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
+         packed_q_w.shape, size_k, size_n, pack_factor
+     )
+
+     orig_device = packed_q_w.device
+
+     packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
+     q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
+
+     mask = (1 << num_bits) - 1
+     for i in range(pack_factor):
+         vals = packed_q_w_cpu & mask
+         packed_q_w_cpu >>= num_bits
+         q_res[:, i::pack_factor] = vals
+
+     q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+     q_res = q_res.contiguous()
+
+     return q_res
+
+
+ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
+ def quantize_weights(
+     w: torch.Tensor,
+     quant_type: ScalarType,
+     group_size: Optional[int],
+     zero_points: bool = False,
+     ref_zero_points_after_scales: bool = False,
+ ):
+     assert (
+         quant_type.is_integer()
+     ), "Floating point quantization may work but has not been tested"
+     assert not zero_points or group_size is not None, (
+         "to have group zero points, group_size must be provided "
+         "(-1 group_size is channelwise)"
+     )
+
+     orig_device = w.device
+     orig_type = w.dtype
+     size_k, size_n = w.shape
+
+     assert w.is_floating_point(), "w must be float"
+
+     if group_size == -1:
+         group_size = size_k
+
+     # Reshape to [groupsize, -1]
+     if group_size is not None and group_size < size_k:
+         w = w.reshape((-1, group_size, size_n))
+         w = w.permute(1, 0, 2)
+         w = w.reshape((group_size, -1))
+
+     # Compute scale for each group
+     max_val = torch.max(w, 0, keepdim=True).values
+     min_val = torch.min(w, 0, keepdim=True).values
+
+     max_q_val = quant_type.max()
+     min_q_val = quant_type.min()
+
+     w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
+     maybe_w_zp = None
+     if group_size is not None:
+         if zero_points:
+             assert not quant_type.is_signed() and quant_type.max() > 0
+             w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
+             maybe_w_zp = (
+                 torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
+             )
+         else:
+             # If the bias is such that there are no possible negative/positive
+             # values, set the max value to inf to avoid divide by 0
+             w_s = torch.max(
+                 abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
+                 abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
+             )
+
+     # Quantize
+     w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
+     w_q = torch.clamp(w_q, min_q_val, max_q_val)
+
+     # Compute ref (dequantized)
+     # For some kernels (namely Machete) the zero-points are applied after the
+     # scales are applied, for this case computing the reference in similar way
+     # allows us to use tighter error tolerances in our unit tests.
+     if ref_zero_points_after_scales and maybe_w_zp is not None:
+         w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
+     else:
+         w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
+
+     if quant_type.has_bias():
+         w_q += quant_type.bias
+
+     # Restore original shapes
+     if group_size is not None and group_size < size_k:
+
+         def reshape_w(w):
+             w = w.reshape((group_size, -1, size_n))
+             w = w.permute(1, 0, 2)
+             w = w.reshape((size_k, size_n)).contiguous()
+             return w
+
+         w_q = reshape_w(w_q)
+         w_ref = reshape_w(w_ref)
+         w_s = w_s.reshape((-1, size_n)).contiguous()
+
+     if maybe_w_zp is not None:
+         maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
+         maybe_w_zp = maybe_w_zp.to(device=orig_device)
+
+     return (
+         w_ref.to(device=orig_device),
+         w_q.to(device=orig_device),
+         w_s if group_size is not None else None,
+         maybe_w_zp,
+     )
+
+
+ SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
+ SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
+
+
+ def gptq_quantize_weights(
+     w: torch.Tensor,
+     quant_type: ScalarType,
+     group_size: int,
+     act_order: bool,
+     test_perm: Optional[torch.Tensor] = None,
+ ):
+     size_k, _ = w.shape
+
+     assert w.is_floating_point(), "w must be float"
+     assert (
+         quant_type in SUPPORTED_GPTQ_QUANT_TYPES
+     ), f"Unsupported gptq type = {quant_type}"
+     assert group_size in SUPPORTED_GROUP_SIZES + [
+         size_k
+     ], f"Unsupported groupsize = {group_size}"
+
+     w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)
+
+     # Apply act_order
+     g_idx = torch.empty(0, dtype=torch.int, device=w.device)
+     rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
+     if act_order:
+         assert (
+             group_size < size_k
+         ), "For act_order, groupsize = {} must be less than size_k = {}".format(
+             group_size, size_k
+         )
+
+         w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)
+
+     return w_ref, w_q, w_s, g_idx, rand_perm
+
+
+ def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
+     orig_device = q_w.device
+
+     sort_indices = torch.argsort(g_idx).to(dtype=torch.int32)  # Sort based on g_idx
+
+     g_idx = g_idx[sort_indices].contiguous()
+     q_w = q_w[sort_indices, :].contiguous()
+
+     return (
+         q_w.to(device=orig_device),
+         g_idx.to(device=orig_device),
+         sort_indices.to(device=orig_device),
+     )
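
The most substantial new logic in utils.py is the per-layer dynamic-rule matching above: a "-:"-prefixed regex excludes matching modules from quantized initialization, while a "+:"-prefixed (or bare) regex overrides individual quantization properties for matching modules. A minimal, illustrative sketch of how get_dynamic_override resolves such rules (assuming sglang 0.4.9.post3 is installed; the rule patterns, layer names, and the SimpleNamespace stand-in for a QuantizationConfig are made up for this example):

# Illustrative only: how the "-:" / "+:" dynamic rules are matched per layer.
from types import SimpleNamespace

from sglang.srt.layers.quantization.utils import get_dynamic_override

# Stand-in for a QuantizationConfig; get_dynamic_override only reads .dynamic.
cfg = SimpleNamespace(
    dynamic={
        r"-:model\.layers\.0\..*": {},                   # skip layer 0 entirely
        r"+:model\.layers\.[1-9]\d*\..*": {"bits": 8},   # override bits elsewhere
    }
)

print(get_dynamic_override(cfg, "model.layers.0.mlp.gate_proj"))             # False -> unquantized
print(get_dynamic_override(cfg, "model.layers.3.mlp.gate_proj", "bits", 4))  # 8 (rule override)
print(get_dynamic_override(cfg, "lm_head", "bits", 4))                       # 4 (no rule, default)

get_linear_quant_method then consults the same rules: a negative match falls back to UnquantizedLinearMethod / UnquantizedEmbeddingMethod, while a positive match lets override_config adjust bits, group_size, desc_act and, for gptq_marlin, sym on a deep copy of the base config.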
sglang/srt/layers/quantization/w4afp8.py
@@ -1,3 +1,5 @@
+ from __future__ import annotations
+
  import logging
  from typing import Any, Dict, List, Optional

@@ -5,12 +7,13 @@ import torch
  from torch.nn import Module
  from torch.nn.parameter import Parameter

- from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
  from sglang.srt.layers.quantization.base_config import (
+     FusedMoEMethodBase,
      QuantizationConfig,
      QuantizeMethodBase,
  )
  from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
+ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
  from sglang.srt.layers.quantization.utils import is_layer_skipped
  from sglang.srt.utils import set_weight_attrs

@@ -62,7 +65,7 @@ class W4AFp8Config(QuantizationConfig):
          return []

      @classmethod
-     def from_config(cls, config: Dict[str, Any]) -> "W4AFp8Config":
+     def from_config(cls, config: Dict[str, Any]) -> W4AFp8Config:
          quant_method = cls.get_from_keys(config, ["quant_method"])
          is_checkpoint_fp8_serialized = "fp8" in quant_method
          is_checkpoint_w4afp8_serialized = "w4afp8" in quant_method
@@ -79,7 +82,8 @@ class W4AFp8Config(QuantizationConfig):

      def get_quant_method(
          self, layer: torch.nn.Module, prefix: str
-     ) -> Optional["QuantizeMethodBase"]:
+     ) -> Optional[QuantizeMethodBase]:
+         from sglang.srt.layers.linear import LinearBase
          from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

          if isinstance(layer, LinearBase):
@@ -94,7 +98,7 @@ class W4AFp8Config(QuantizationConfig):
          return []


- class W4AFp8MoEMethod:
+ class W4AFp8MoEMethod(FusedMoEMethodBase):

      def __init__(self, quant_config: W4AFp8Config):
          self.quant_config = quant_config
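
Also worth a quick check against the utils.py additions earlier: the new column-packing helpers round-trip integer codes losslessly as long as the values fit in num_bits. A minimal sketch (the bit width and shapes are illustrative, not taken from the diff; assumes sglang 0.4.9.post3 and torch are installed):

import torch

from sglang.srt.layers.quantization.utils import get_pack_factor, pack_cols, unpack_cols

num_bits = 4
size_k, size_n = 128, 64
pack_factor = get_pack_factor(num_bits)  # 32 // 4 == 8 values per int32

# Random 4-bit codes, one per element, stored as int32.
q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)

packed = pack_cols(q_w, num_bits, size_k, size_n)        # (size_k, size_n // 8)
restored = unpack_cols(packed, num_bits, size_k, size_n)

assert packed.shape == (size_k, size_n // pack_factor)
assert torch.equal(restored, q_w)  # column packing is lossless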
sglang/srt/layers/quantization/w8a8_fp8.py
@@ -1,11 +1,14 @@
- from typing import Any, Callable, Dict, List, Optional
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional

  import torch
  from torch.nn.parameter import Parameter

- from sglang.srt.layers.linear import LinearMethodBase
  from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
  from sglang.srt.layers.quantization.base_config import (
+     FusedMoEMethodBase,
+     LinearMethodBase,
      QuantizationConfig,
      QuantizeMethodBase,
  )
@@ -22,6 +25,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
  )
  from sglang.srt.utils import set_weight_attrs

+ if TYPE_CHECKING:
+     from sglang.srt.layers.moe.topk import TopKOutput
+
  _is_fp8_fnuz = is_fp8_fnuz()


@@ -64,7 +70,7 @@ class W8A8Fp8Config(QuantizationConfig):
          return []

      @classmethod
-     def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
+     def from_config(cls, config: Dict[str, Any]) -> W8A8Fp8Config:
          quant_method = cls.get_from_keys(config, ["quant_method"])
          is_checkpoint_fp8_serialized = (
              "compressed-tensors" in quant_method or "w8a8_fp8" in quant_method
@@ -75,7 +81,7 @@ class W8A8Fp8Config(QuantizationConfig):
          self,
          layer: torch.nn.Module,
          prefix: str,
-     ) -> Optional["QuantizeMethodBase"]:
+     ) -> Optional[QuantizeMethodBase]:
          from sglang.srt.layers.linear import LinearBase
          from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

@@ -183,7 +189,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
          )


- class W8A8FP8MoEMethod:
+ class W8A8FP8MoEMethod(FusedMoEMethodBase):
      """MoE method for FP8.
      Supports loading FP8 checkpoints with static weight scale and
      dynamic/static activation scale.
@@ -194,25 +200,7 @@ class W8A8FP8MoEMethod:
          quant_config: The quantization config.
      """

-     def __new__(cls, *args, **kwargs):
-         from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
-         if not hasattr(cls, "_initialized"):
-             original_init = cls.__init__
-             new_cls = type(
-                 cls.__name__,
-                 (FusedMoEMethodBase,),
-                 {
-                     "__init__": original_init,
-                     **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                 },
-             )
-             obj = super(new_cls, new_cls).__new__(new_cls)
-             obj.__init__(*args, **kwargs)
-             return obj
-         return super().__new__(cls)
-
-     def __init__(self, quant_config):
+     def __init__(self, quant_config: W8A8Fp8Config):
          self.quant_config = quant_config

      def create_weights(
@@ -281,45 +269,23 @@ class W8A8FP8MoEMethod:
          self,
          layer: torch.nn.Module,
          x: torch.Tensor,
-         router_logits: torch.Tensor,
-         top_k: int,
-         renormalize: bool,
-         use_grouped_topk: bool,
-         topk_group: Optional[int] = None,
-         num_expert_group: Optional[int] = None,
-         num_fused_shared_experts: int = 0,
-         custom_routing_function: Optional[Callable] = None,
-         correction_bias: Optional[torch.Tensor] = None,
+         topk_output: TopKOutput,
+         *,
          activation: str = "silu",
+         apply_router_weight_on_input: bool = False,
          inplace: bool = True,
          no_combine: bool = False,
          routed_scaling_factor: Optional[float] = None,
      ) -> torch.Tensor:
          from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-         from sglang.srt.layers.moe.topk import select_experts
-
-         # Expert selection
-         topk_weights, topk_ids = select_experts(
-             hidden_states=x,
-             router_logits=router_logits,
-             use_grouped_topk=use_grouped_topk,
-             top_k=top_k,
-             renormalize=renormalize,
-             topk_group=topk_group,
-             num_expert_group=num_expert_group,
-             num_fused_shared_experts=num_fused_shared_experts,
-             custom_routing_function=custom_routing_function,
-             correction_bias=correction_bias,
-             routed_scaling_factor=routed_scaling_factor,
-         )

          return fused_experts(
              x,
              layer.w13_weight,
              layer.w2_weight,
-             topk_weights=topk_weights,
-             topk_ids=topk_ids,
+             topk_output=topk_output,
              inplace=inplace,
+             apply_router_weight_on_input=apply_router_weight_on_input,
              activation=activation,
              use_fp8_w8a8=True,
              per_channel_quant=True,