sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff compares two publicly available versions of the package as released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/awq.py
@@ -1,21 +1,65 @@
  # SPDX-License-Identifier: Apache-2.0
+ from __future__ import annotations
+
  import logging
- from typing import Any, Dict, List, Optional
+ import warnings
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

  import torch

- from sglang.srt.layers.linear import (
-     LinearBase,
+ from sglang.srt.layers.linear import LinearBase, set_weight_attrs
+ from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
+ from sglang.srt.layers.quantization.base_config import (
+     FusedMoEMethodBase,
      LinearMethodBase,
-     UnquantizedLinearMethod,
+     QuantizationConfig,
+     QuantizeMethodBase,
  )
- from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
- from sglang.srt.layers.quantization.base_config import QuantizationConfig
- from sglang.srt.utils import is_cuda
+ from sglang.srt.layers.quantization.marlin_utils import (
+     apply_awq_marlin_linear,
+     awq_to_marlin_zero_points,
+     check_marlin_supported,
+     check_marlin_supports_layer,
+     check_moe_marlin_supports_layer,
+     marlin_make_empty_g_idx,
+     marlin_make_workspace,
+     marlin_moe_permute_scales,
+     marlin_permute_scales,
+     moe_awq_to_marlin_zero_points,
+     verify_marlin_supported,
+     verify_marlin_supports_shape,
+ )
+ from sglang.srt.layers.quantization.scalar_type import scalar_types
+ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
+ from sglang.srt.layers.quantization.utils import replace_parameter
+
+ if TYPE_CHECKING:
+     from sglang.srt.layers.moe.topk import TopKOutput
+
+ try:
+     from vllm import _custom_ops as ops
+
+     warnings.warn(
+         f"Using kernels directly from vllm. This might lead to performance degradation or "
+         f"missing functionalities as certain kernels may not be optimized. "
+     )
+ except ImportError:
+     ops = None
+
+ from sglang.srt.utils import is_cuda, is_hip

  _is_cuda = is_cuda()
+ _is_hip = is_hip()

  if _is_cuda:
-     from sgl_kernel import awq_dequantize
+     from sgl_kernel import awq_dequantize, fused_marlin_moe
+ elif _is_hip:
+     from sglang.srt.layers.quantization.awq_triton import (
+         awq_dequantize_triton as awq_dequantize,
+     )
+
+     warnings.warn(f"HIP does not support fused_marlin_moe currently.")
+ else:
+     warnings.warn(f"Only CUDA and HIP support AWQ currently.")

  logger = logging.getLogger(__name__)

@@ -81,7 +125,7 @@ class AWQConfig(QuantizationConfig):
          ]

      @classmethod
-     def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
+     def from_config(cls, config: Dict[str, Any]) -> AWQConfig:
          weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
          group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
          zero_point = cls.get_from_keys(config, ["zero_point"])
@@ -92,7 +136,8 @@ class AWQConfig(QuantizationConfig):

      def get_quant_method(
          self, layer: torch.nn.Module, prefix: str
-     ) -> Optional["LinearMethodBase"]:
+     ) -> Optional[LinearMethodBase]:
+         from sglang.srt.layers.linear import LinearBase

          if isinstance(layer, LinearBase):
              if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
@@ -101,6 +146,176 @@ class AWQConfig(QuantizationConfig):
          return None


+ class AWQMarlinConfig(QuantizationConfig):
+     """Config class for AWQ Marlin"""
+
+     # num_bits -> type
+     TYPE_MAP = {
+         4: scalar_types.uint4,
+         8: scalar_types.uint8,
+     }
+
+     def __init__(
+         self,
+         weight_bits: int,
+         group_size: int,
+         zero_point: bool,
+         lm_head_quantized: bool,
+         modules_to_not_convert: Optional[list[str]],
+         full_config: dict[str, Any],
+     ) -> None:
+         super().__init__()
+         self.pack_factor = 32 // weight_bits  # packed into int32
+         self.group_size = group_size
+         self.zero_point = zero_point
+         self.lm_head_quantized = lm_head_quantized
+         self.weight_bits = weight_bits
+         self.modules_to_not_convert = modules_to_not_convert or []
+         self.full_config = full_config
+
+         if self.weight_bits not in self.TYPE_MAP:
+             raise ValueError(
+                 f"Unsupported num_bits = {self.weight_bits}. "
+                 f"Supported num_bits = {self.TYPE_MAP.keys()}"
+             )
+
+         self.quant_type = self.TYPE_MAP[self.weight_bits]
+
+         verify_marlin_supported(
+             self.quant_type, group_size=self.group_size, has_zp=self.zero_point
+         )
+
+     def __repr__(self) -> str:
+         return (
+             f"AWQMarlinConfig(quant_type={self.quant_type}, "
+             f"group_size={self.group_size}, "
+             f"zero_point={self.zero_point}, "
+             f"lm_head_quantized={self.lm_head_quantized}, "
+             f"modules_to_not_convert={self.modules_to_not_convert})"
+         )
+
+     def get_scaled_act_names(self) -> List[str]:
+         return []
+
+     @classmethod
+     def get_name(cls) -> str:
+         return "awq_marlin"
+
+     @classmethod
+     def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+         return [torch.half, torch.bfloat16]
+
+     @classmethod
+     def get_min_capability(cls) -> int:
+         return 80
+
+     @classmethod
+     def get_config_filenames(cls) -> list[str]:
+         return ["quantize_config.json"]
+
+     @classmethod
+     def from_config(cls, config: dict[str, Any]) -> AWQMarlinConfig:
+         weight_bits = cls.get_from_keys(config, ["bits"])
+         group_size = cls.get_from_keys(config, ["group_size"])
+         zero_point = cls.get_from_keys(config, ["zero_point"])
+         lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
+         modules_to_not_convert = cls.get_from_keys_or(
+             config, ["modules_to_not_convert"], None
+         )
+         return cls(
+             weight_bits,
+             group_size,
+             zero_point,
+             lm_head_quantized,
+             modules_to_not_convert,
+             config,
+         )
+
+     @classmethod
+     def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
+         can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
+         is_valid_user_quant = (
+             user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin"
+         )
+
+         if can_convert and is_valid_user_quant:
+             msg = (
+                 "The model is convertible to {} during runtime."
+                 " Using {} kernel.".format(cls.get_name(), cls.get_name())
+             )
+             logger.info(msg)
+             return cls.get_name()
+
+         if can_convert and user_quant == "awq":
+             logger.info(
+                 "Detected that the model can run with awq_marlin"
+                 ", however you specified quantization=awq explicitly,"
+                 " so forcing awq. Use quantization=awq_marlin for"
+                 " faster inference"
+             )
+         return None
+
+     def get_quant_method(
+         self, layer: torch.nn.Module, prefix: str
+     ) -> Optional[QuantizeMethodBase]:
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+         from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+
+         if isinstance(layer, LinearBase) or (
+             isinstance(layer, ParallelLMHead) and self.lm_head_quantized
+         ):
+             if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                 return UnquantizedLinearMethod()
+             # Check if the layer is supported by AWQMarlin.
+             if not check_marlin_supports_layer(layer, self.group_size):
+                 logger.warning_once(
+                     "Layer '%s' is not supported by AWQMarlin. Falling back to unoptimized AWQ kernels.",  # noqa: E501
+                     prefix,
+                 )
+                 return AWQConfig.from_config(self.full_config).get_quant_method(
+                     layer, prefix
+                 )
+             return AWQMarlinLinearMethod(self)
+         elif isinstance(layer, FusedMoE):
+             from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
+
+             if not check_moe_marlin_supports_layer(layer, self.group_size):
+                 logger.warning_once(
+                     f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
+                     "Falling back to Moe WNA16 kernels."
+                 )
+                 return MoeWNA16Config.from_config(self.full_config).get_quant_method(
+                     layer, prefix
+                 )
+             return AWQMoEMethod(self)
+         return None
+
+     @classmethod
+     def is_awq_marlin_compatible(cls, quant_config: dict[str, Any]):
+         # Extract data from quant config.
+         quant_method = quant_config.get("quant_method", "").lower()
+         num_bits = quant_config.get("bits")
+         group_size = quant_config.get("group_size")
+         zero_point = quant_config.get("zero_point")
+
+         if not _is_cuda:
+             return False
+
+         if quant_method != "awq":
+             return False
+
+         # If we cannot find the info needed in the config, cannot convert.
+         if num_bits is None or group_size is None or zero_point is None:
+             return False
+
+         if num_bits not in cls.TYPE_MAP:
+             return False
+
+         return check_marlin_supported(
+             quant_type=cls.TYPE_MAP[num_bits], group_size=group_size, has_zp=zero_point
+         )
+
+
  class AWQLinearMethod(LinearMethodBase):
      """Linear method for AWQ.

@@ -195,10 +410,362 @@ class AWQLinearMethod(LinearMethodBase):
          pack_factor = self.quant_config.pack_factor
          out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
          reshaped_x = x.reshape(-1, x.shape[-1])
-
          out = awq_dequantize(qweight, scales, qzeros)
          out = torch.matmul(reshaped_x, out)

          if bias is not None:
              out.add_(bias)
          return out.reshape(out_shape)
+
+
+ class AWQMarlinLinearMethod(LinearMethodBase):
+     """Linear method for AWQ Marlin.
+
+     Args:
+         quant_config: The AWQ Marlin quantization config.
+     """
+
+     def __init__(self, quant_config: AWQMarlinConfig) -> None:
+         self.quant_config = quant_config
+
+     def create_weights(
+         self,
+         layer: torch.nn.Module,
+         input_size_per_partition: int,
+         output_partition_sizes: list[int],
+         input_size: int,
+         output_size: int,
+         params_dtype: torch.dtype,
+         **extra_weight_attrs,
+     ) -> None:
+         del output_size
+         output_size_per_partition = sum(output_partition_sizes)
+         weight_loader = extra_weight_attrs.get("weight_loader")
+
+         # Normalize group_size
+         if self.quant_config.group_size != -1:
+             group_size = self.quant_config.group_size
+         else:
+             group_size = input_size
+
+         verify_marlin_supports_shape(
+             output_size_per_partition=output_size_per_partition,
+             input_size_per_partition=input_size_per_partition,
+             input_size=input_size,
+             group_size=group_size,
+         )
+
+         qweight = PackedvLLMParameter(
+             data=torch.empty(
+                 input_size_per_partition,
+                 output_size_per_partition // self.quant_config.pack_factor,
+                 dtype=torch.int32,
+             ),
+             input_dim=0,
+             output_dim=1,
+             packed_dim=1,
+             packed_factor=self.quant_config.pack_factor,
+             weight_loader=weight_loader,
+         )
+
+         num_groups = input_size_per_partition // group_size
+
+         qzeros = PackedvLLMParameter(
+             data=torch.empty(
+                 num_groups,
+                 output_size_per_partition // self.quant_config.pack_factor,
+                 dtype=torch.int32,
+             ),
+             input_dim=0,
+             output_dim=1,
+             packed_dim=1,
+             packed_factor=self.quant_config.pack_factor,
+             weight_loader=weight_loader,
+         )
+
+         scales = GroupQuantScaleParameter(
+             data=torch.empty(
+                 num_groups,
+                 output_size_per_partition,
+                 dtype=params_dtype,
+             ),
+             input_dim=0,
+             output_dim=1,
+             weight_loader=weight_loader,
+         )
+
+         layer.register_parameter("qweight", qweight)
+         layer.register_parameter("qzeros", qzeros)
+         layer.register_parameter("scales", scales)
+
+         layer.input_size_per_partition = input_size_per_partition
+         layer.output_size_per_partition = output_size_per_partition
+         layer.num_groups = num_groups
+
+     # TODO: Update this docs
+     # Checkpoints are serialized in AutoAWQ format, which is different from the
+     # marlin format. This function is called after the weights are loaded.
+     # Here, we handle the repacking
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         device = layer.qweight.device
+         layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
+         layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
+         layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
+
+         # Allocate marlin workspace
+         layer.workspace = marlin_make_workspace(device)
+
+         # Repack weights from AWQ format to marlin format.
+         marlin_qweight = ops.awq_marlin_repack(
+             layer.qweight,
+             size_k=layer.input_size_per_partition,
+             size_n=layer.output_size_per_partition,
+             num_bits=self.quant_config.quant_type.size_bits,
+         )
+         replace_parameter(layer, "qweight", marlin_qweight)
+
+         # Permute scales from AWQ format to marlin format.
+         marlin_scales = marlin_permute_scales(
+             layer.scales,
+             size_k=layer.input_size_per_partition,
+             size_n=layer.output_size_per_partition,
+             group_size=self.quant_config.group_size,
+         )
+         replace_parameter(layer, "scales", marlin_scales)
+
+         # Permute zero-points from AWQ format to marlin format.
+         marlin_zp = awq_to_marlin_zero_points(
+             layer.qzeros,
+             size_k=layer.num_groups,
+             size_n=layer.output_size_per_partition,
+             num_bits=self.quant_config.quant_type.size_bits,
+         )
+         replace_parameter(layer, "qzeros", marlin_zp)
+
+         # Not-used
+         layer.g_idx = marlin_make_empty_g_idx(device)
+         layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         bias: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         return apply_awq_marlin_linear(
+             input=x,
+             weight=layer.qweight,
+             weight_scale=layer.scales,
+             weight_zp=layer.qzeros,
+             g_idx=layer.g_idx,
+             g_idx_sort_indices=layer.g_idx_sort_indices,
+             workspace=layer.workspace,
+             quant_type=self.quant_config.quant_type,
+             output_size_per_partition=layer.output_size_per_partition,
+             input_size_per_partition=layer.input_size_per_partition,
+             bias=bias,
+         )
+
+
+ class AWQMoEMethod(FusedMoEMethodBase):
+
+     def __init__(self, quant_config: AWQMarlinConfig):
+         self.quant_config = quant_config
+         if self.quant_config.weight_bits != 4:
+             raise ValueError("AWQMoEMethod only supports 4bit now.")
+         self.quant_type = scalar_types.uint4
+
+     def create_weights(
+         self,
+         layer: torch.nn.Module,
+         num_experts: int,
+         hidden_size: int,
+         intermediate_size_per_partition: int,
+         params_dtype: torch.dtype,
+         **extra_weight_attrs,
+     ):
+         # Delay the import to avoid circular dependency
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+         extra_weight_attrs.update(
+             {
+                 "is_transposed": True,
+                 "quant_method": FusedMoeWeightScaleSupported.GROUP.value,
+             }
+         )
+
+         w13_qweight = torch.nn.Parameter(
+             torch.empty(
+                 num_experts,
+                 hidden_size,
+                 2 * intermediate_size_per_partition // self.quant_config.pack_factor,
+                 dtype=torch.int32,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_qweight", w13_qweight)
+         set_weight_attrs(w13_qweight, extra_weight_attrs)
+
+         w2_qweight = torch.nn.Parameter(
+             torch.empty(
+                 num_experts,
+                 intermediate_size_per_partition,
+                 hidden_size // self.quant_config.pack_factor,
+                 dtype=torch.int32,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w2_qweight", w2_qweight)
+         set_weight_attrs(w2_qweight, extra_weight_attrs)
+
+         num_groups_w13 = hidden_size // self.quant_config.group_size
+         num_groups_w2 = intermediate_size_per_partition // self.quant_config.group_size
+
+         # WEIGHT_SCALES
+         # Allocate 2 scales for w1 and w3 respectively.
+         w13_scales = torch.nn.Parameter(
+             torch.empty(
+                 num_experts,
+                 num_groups_w13,
+                 intermediate_size_per_partition * 2,
+                 dtype=params_dtype,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_scales", w13_scales)
+         set_weight_attrs(w13_scales, extra_weight_attrs)
+
+         w2_scales = torch.nn.Parameter(
+             torch.empty(num_experts, num_groups_w2, hidden_size, dtype=params_dtype),
+             requires_grad=False,
+         )
+         layer.register_parameter("w2_scales", w2_scales)
+         set_weight_attrs(w2_scales, extra_weight_attrs)
+
+         # WEIGHT_ZERO_POINT
+         # Allocate 2 zero points for w1 and w3 respectively.
+         w13_qzeros = torch.nn.Parameter(
+             torch.empty(
+                 num_experts,
+                 num_groups_w13,
+                 2 * intermediate_size_per_partition // self.quant_config.pack_factor,
+                 dtype=torch.int32,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_qzeros", w13_qzeros)
+         set_weight_attrs(w13_qzeros, extra_weight_attrs)
+
+         w2_qzeros = torch.nn.Parameter(
+             torch.empty(
+                 num_experts,
+                 num_groups_w2,
+                 hidden_size // self.quant_config.pack_factor,
+                 dtype=torch.int32,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w2_qzeros", w2_qzeros)
+         set_weight_attrs(w2_qzeros, extra_weight_attrs)
+
+         device = layer.w13_qweight.device
+         layer.workspace = marlin_make_workspace(device, 4)
+
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         num_experts = layer.w13_qweight.shape[0]
+         device = layer.w13_qweight.device
+
+         layer.w13_g_idx_sort_indices = torch.nn.Parameter(
+             torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+             requires_grad=False,
+         )
+         layer.w2_g_idx_sort_indices = torch.nn.Parameter(
+             torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+             requires_grad=False,
+         )
+
+         marlin_w13_qweight = ops.awq_marlin_moe_repack(
+             layer.w13_qweight,
+             layer.w13_g_idx_sort_indices,
+             size_k=layer.w13_qweight.shape[1],
+             size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
+             num_bits=self.quant_config.weight_bits,
+         )
+         replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
+
+         marlin_w2_qweight = ops.awq_marlin_moe_repack(
+             layer.w2_qweight,
+             layer.w2_g_idx_sort_indices,
+             size_k=layer.w2_qweight.shape[1],
+             size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
+             num_bits=self.quant_config.weight_bits,
+         )
+         replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
+
+         # hidden_size->intermediate_size
+         marlin_w13_scales = marlin_moe_permute_scales(
+             s=layer.w13_scales,
+             size_k=layer.intermediate_size_per_partition,
+             size_n=layer.w13_scales.shape[2],
+             group_size=self.quant_config.group_size,
+         )
+
+         replace_parameter(layer, "w13_scales", marlin_w13_scales)
+
+         marlin_w2_scales = marlin_moe_permute_scales(
+             s=layer.w2_scales,
+             size_k=layer.intermediate_size_per_partition,
+             size_n=layer.w2_scales.shape[2],
+             group_size=self.quant_config.group_size,
+         )
+         replace_parameter(layer, "w2_scales", marlin_w2_scales)
+
+         marlin_w13_zp = moe_awq_to_marlin_zero_points(
+             layer.w13_qzeros,
+             size_k=layer.w13_qzeros.shape[1],
+             size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
+             num_bits=self.quant_config.weight_bits,
+         )
+         replace_parameter(layer, "w13_qzeros", marlin_w13_zp)
+
+         marlin_w2_zp = moe_awq_to_marlin_zero_points(
+             layer.w2_qzeros,
+             size_k=layer.w2_qzeros.shape[1],
+             size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
+             num_bits=self.quant_config.weight_bits,
+         )
+         replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         topk_output: TopKOutput,
+         *,
+         activation: str = "silu",
+         **kwargs,
+     ) -> torch.Tensor:
+
+         assert activation == "silu", "Only SiLU activation is supported."
+
+         # The input must currently be float16
+         orig_dtype = x.dtype
+         x = x.half()
+
+         topk_weights, topk_ids, router_logits = topk_output
+
+         return fused_marlin_moe(
+             x,
+             layer.w13_qweight,
+             layer.w2_qweight,
+             layer.w13_scales,
+             layer.w2_scales,
+             router_logits,
+             topk_weights,
+             topk_ids,
+             sort_indices1=layer.w13_g_idx_sort_indices,
+             sort_indices2=layer.w2_g_idx_sort_indices,
+             w1_zeros=layer.w13_qzeros,
+             w2_zeros=layer.w2_qzeros,
+             num_bits=self.quant_config.weight_bits,
+         ).to(orig_dtype)
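
For readers skimming the awq.py changes above, the short sketch below mirrors (rather than calls) two pieces of the logic this diff introduces: the field checks that AWQMarlinConfig.is_awq_marlin_compatible performs on an AutoAWQ-style quantization config before the CUDA and Marlin capability checks, and the pack_factor = 32 // weight_bits relationship that determines the packed int32 tensor shapes allocated in AWQMoEMethod.create_weights. The config values and MoE dimensions are hypothetical and chosen only for illustration.

# Hypothetical AutoAWQ-style quantization config, as a model might ship in
# quantize_config.json (illustrative values only).
quant_config = {
    "quant_method": "awq",
    "bits": 4,
    "group_size": 128,
    "zero_point": True,
}

# Mirrors the field checks in AWQMarlinConfig.is_awq_marlin_compatible, before
# the CUDA/Marlin capability checks: the method must be "awq", all fields must
# be present, and the bit width must be in TYPE_MAP (4 or 8).
eligible = (
    quant_config.get("quant_method", "").lower() == "awq"
    and quant_config.get("bits") in (4, 8)
    and quant_config.get("group_size") is not None
    and quant_config.get("zero_point") is not None
)
print("awq_marlin eligible (pre-capability checks):", eligible)

# pack_factor = 32 // weight_bits, so 4-bit weights pack 8 values per int32.
pack_factor = 32 // quant_config["bits"]

# Tensor shapes that AWQMoEMethod.create_weights would allocate for
# hypothetical MoE dimensions.
num_experts, hidden_size, intermediate_size_per_partition = 8, 4096, 1408
w13_qweight_shape = (
    num_experts,
    hidden_size,
    2 * intermediate_size_per_partition // pack_factor,
)
w2_qweight_shape = (
    num_experts,
    intermediate_size_per_partition,
    hidden_size // pack_factor,
)
print("w13_qweight:", w13_qweight_shape)  # (8, 4096, 352)
print("w2_qweight:", w2_qweight_shape)    # (8, 1408, 512)

When these checks pass and the user has not forced quantization=awq, override_quantization_method in the diff switches the runtime method to awq_marlin; otherwise the layer falls back to the plain AWQ or MoE WNA16 paths shown above.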