sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/gptq.py
@@ -1,46 +1,62 @@
+from __future__ import annotations
+
 import logging
+from dataclasses import dataclass
 from fractions import Fraction
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
 
 import torch
 
-from sglang.srt.layers.linear import LinearBase, set_weight_attrs
+from sglang.srt.layers.parameter import (
+    BasevLLMParameter,
+    ChannelQuantScaleParameter,
+    GroupQuantScaleParameter,
+    PackedColumnParameter,
+    PackedvLLMParameter,
+    RowvLLMParameter,
+    permute_param_layout_,
+)
 from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
+    LinearMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.utils import replace_parameter
-from sglang.srt.utils import is_cuda
+from sglang.srt.layers.quantization.marlin_utils import (
+    apply_gptq_marlin_linear,
+    check_marlin_supported,
+    check_marlin_supports_shape,
+    marlin_is_k_full,
+    marlin_make_empty_g_idx,
+    marlin_make_workspace,
+    marlin_moe_permute_scales,
+    marlin_permute_scales,
+    marlin_repeat_scales_on_all_ranks,
+    marlin_sort_g_idx,
+    marlin_zero_points,
+    verify_marlin_supported,
+)
+from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
+from sglang.srt.layers.quantization.utils import (
+    get_linear_quant_method,
+    replace_parameter,
+    unpack_cols,
+)
 
-_is_cuda = is_cuda()
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
 
 try:
     from vllm import _custom_ops as ops
-    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
-    from vllm.model_executor.layers.quantization.gptq_marlin import (
-        FusedMoE,
-        FusedMoEMethodBase,
-        FusedMoeWeightScaleSupported,
-        GPTQMarlinLinearMethod,
-        marlin_moe_permute_scales,
-    )
-    from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
-    from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-        check_marlin_supported,
-    )
-    from vllm.scalar_type import scalar_types
-
-    VLLM_AVAILABLE = True
 except ImportError:
-    VLLM_AVAILABLE = False
+    ops = None
 
-    GPTQLinearMethod = MarlinLinearMethod = Any
+from sglang.srt.utils import is_cuda
 
-    FusedMoEMethodBase = QuantizeMethodBase
+_is_cuda = is_cuda()
 
-    class scalar_types:
-        uint4b8 = "uint4b8"
-        uint8b128 = "uint8b128"
+if _is_cuda:
+    from sgl_kernel import fused_marlin_moe
 
 
 logger = logging.getLogger(__name__)
@@ -54,6 +70,38 @@ def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
     )
 
 
+def gptq_marlin_moe_repack(
+    b_q_weight: torch.Tensor,
+    perm: torch.Tensor,
+    size_k: int,
+    size_n: int,
+    num_bits: int,
+) -> torch.Tensor:
+    num_experts = b_q_weight.shape[0]
+    assert size_k % 16 == 0
+    output = torch.empty(
+        (num_experts, size_k // 16, size_n * (num_bits // 2)),
+        device=b_q_weight.device,
+        dtype=b_q_weight.dtype,
+    )
+    for e in range(num_experts):
+        output[e] = torch.ops.sgl_kernel.gptq_marlin_repack(
+            b_q_weight[e], perm[e], size_k, size_n, num_bits
+        )
+    return output
+
+
+@dataclass
+class MarlinLinearLayerConfig:
+    full_weight_shape: tuple[int, int]  # [in, out]
+    partition_weight_shape: tuple[int, int]
+    weight_type: ScalarType
+    act_type: torch.dtype
+    group_size: int
+    zero_points: bool
+    has_g_idx: bool
+
+
 class GPTQConfig(QuantizationConfig):
     """Config class for GPTQ.
 
@@ -139,7 +187,7 @@ class GPTQConfig(QuantizationConfig):
         return ["quantize_config.json"]
 
     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
+    def from_config(cls, config: Dict[str, Any]) -> GPTQConfig:
         dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
         dynamic = {} if dynamic is None else dynamic
 
@@ -151,11 +199,16 @@ class GPTQConfig(QuantizationConfig):
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[GPTQLinearMethod]:
+    ) -> Optional[LinearMethodBase]:
         # Delay the import to avoid circular dependency
-        from sglang.srt.layers.quantization import get_linear_quant_method
+        from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
-        return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
+        if isinstance(layer, LinearBase):
+            return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
+        elif isinstance(layer, FusedMoE):
+            raise TypeError("GPTQ Method does not support MoE, please use gptq_marlin")
+        return None
 
 
 class GPTQMarlinConfig(QuantizationConfig):
@@ -258,7 +311,7 @@ class GPTQMarlinConfig(QuantizationConfig):
         return ["quantize_config.json"]
 
     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
+    def from_config(cls, config: Dict[str, Any]) -> GPTQMarlinConfig:
         dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
         dynamic = {} if dynamic is None else dynamic
 
@@ -309,18 +362,9 @@ class GPTQMarlinConfig(QuantizationConfig):
     ) -> Optional[QuantizeMethodBase]:
         # Delay the import to avoid circular dependency
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
-        from sglang.srt.layers.quantization import get_linear_quant_method
 
         if isinstance(layer, FusedMoE):
             return GPTQMarlinMoEMethod(self)
-        # TODO: re-enable after SGLang syncs with vllm >= 0.7.3
-        # if layer.num_experts > 32:
-        #     # For MoEs with many experts the moe_wna16 kernel is faster
-        #     return MoeWNA16Config.from_config(self.full_config).get_quant_method(
-        #         layer, prefix
-        #     )
-        # else:
-        #     return GPTQMarlinMoEMethod(self)
         return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod)
 
     @classmethod
@@ -344,112 +388,439 @@ class GPTQMarlinConfig(QuantizationConfig):
         if (num_bits, sym) not in cls.TYPE_MAP:
             return False
 
-        assert (
-            VLLM_AVAILABLE
-        ), "vllm is not installed, to use gptq_marlin, please install vllm"
-
         return check_marlin_supported(
             quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size
         )
 
 
-class MarlinConfig(QuantizationConfig):
-    """Config class for Marlin.
+class GPTQLinearMethod(LinearMethodBase):
+    """Linear method for GPTQ.
 
-    Reference: https://github.com/IST-DASLab/marlin/tree/master
+    Args:
+        quant_config: The GPTQ quantization config.
     """
 
-    def __init__(
+    def __init__(self, quant_config: GPTQConfig):
+        self.quant_config = quant_config
+
+    def create_weights(
         self,
-        group_size: int,
-        lm_head_quantized: bool,
-    ) -> None:
-        # Group size for the quantization.
-        self.group_size = group_size
-        self.lm_head_quantized = lm_head_quantized
-        if self.group_size != 128 and self.group_size != -1:
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: list[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        del output_size  # Unused.
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        if input_size_per_partition % self.quant_config.group_size != 0:
             raise ValueError(
-                "Currently, only group size 128 and -1 (channelwise) "
-                "is supported for Marlin, but got group_size of "
-                f"{self.group_size}"
+                "The input size is not aligned with the quantized "
+                "weight shape. This can be caused by too large "
+                "tensor parallel size."
+            )
+        output_size_per_partition = sum(output_partition_sizes)
+        if output_size_per_partition % self.quant_config.pack_factor.numerator != 0:
+            raise ValueError(
+                "The output size is not aligned with the quantized "
+                "weight shape. This can be caused by too large "
+                "tensor parallel size."
+            )
+
+        if self.quant_config.group_size != -1:
+            group_size = self.quant_config.group_size
+        else:
+            group_size = input_size
+
+        self.use_shuffle = True
+        scale_and_zero_size = input_size // group_size
+        scale_and_zero_input_dim = None
+        if (
+            input_size != input_size_per_partition
+            and self.quant_config.group_size != -1
+        ):
+            if self.quant_config.desc_act:
+                self.use_shuffle = False
+            else:
+                # we need to partition qzeros and scales for exllama kernel
+                scale_and_zero_size = input_size_per_partition // group_size
+                scale_and_zero_input_dim = 0
+
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition // self.quant_config.pack_factor,
+                output_size_per_partition,
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=0,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader,
+        )
+
+        g_idx = RowvLLMParameter(
+            data=torch.tensor(
+                [
+                    i // self.quant_config.group_size
+                    for i in range(input_size_per_partition)
+                ],
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            weight_loader=weight_loader,
+        )
+        qzeros_args = {
+            "data": torch.empty(
+                scale_and_zero_size,
+                output_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            "weight_loader": weight_loader,
+        }
+        weight_scale_args = {
+            "data": torch.empty(
+                scale_and_zero_size,
+                output_size_per_partition,
+                dtype=params_dtype,
+            ),
+            "weight_loader": weight_loader,
+        }
+        if scale_and_zero_input_dim is None:
+            scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
+            qzeros = PackedColumnParameter(
+                output_dim=1,
+                packed_dim=1,
+                packed_factor=self.quant_config.pack_factor,
+                **qzeros_args,
             )
 
-        # 4 Bits packed into 32 bit datatype.
-        self.pack_factor = 32 // 4
+        else:
+            scales = GroupQuantScaleParameter(
+                output_dim=1, input_dim=0, **weight_scale_args
+            )
+            qzeros = PackedvLLMParameter(
+                input_dim=0,
+                output_dim=1,
+                packed_dim=1,
+                packed_factor=self.quant_config.pack_factor,
+                **qzeros_args,
+            )
 
-        # Tile size used by marlin kernels.
-        self.tile_size = 16
+        layer.register_parameter("qweight", qweight)
+        layer.register_parameter("g_idx", g_idx)
+        layer.register_parameter("qzeros", qzeros)
+        layer.register_parameter("scales", scales)
 
-        # Min out_features dim
-        self.min_n_threads = 64
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # for torch.compile
+        layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
+        layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
+        layer.g_idx = torch.nn.Parameter(layer.g_idx.data, requires_grad=False)
+        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
+
+        # exllama needs to shuffle the weight after the weight is loaded
+        # here we do the shuffle on first forward pass
+        if self.use_shuffle:
+            if self.quant_config.desc_act:
+                layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
+            else:
+                layer.g_idx.data = torch.empty(
+                    (0,), dtype=torch.int, device=layer.g_idx.device
+                )
+            ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)
 
-        # Min in_features dim
-        self.min_k_threads = 128
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
+        reshaped_x = x.reshape(-1, x.shape[-1])
+
+        output = ops.gptq_gemm(
+            reshaped_x,
+            layer.qweight,
+            layer.qzeros,
+            layer.scales,
+            layer.g_idx,
+            self.use_shuffle,
+            self.quant_config.weight_bits,
+        )
+        if bias is not None:
+            output.add_(bias)
+        return output.reshape(out_shape)
 
-        # Max parallel problems to solve at once (improves large
-        # batch performance)
-        self.max_parallel = 16
 
-        # Permutation length used by the marlin kernels.
-        self.perm_len = 1024
+class GPTQMarlinLinearMethod(LinearMethodBase):
+    """Linear method for GPTQ Marlin.
 
-    def __repr__(self) -> str:
-        return (
-            f"MarlinConfig(group_size={self.group_size}, "
-            f"lm_head_quantized={self.lm_head_quantized})"
+    Args:
+        quant_config: The GPTQ Marlin quantization config.
+    """
+
+    _kernel_backends_being_used: set[str] = set()
+
+    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
+        self.quant_config = quant_config
+
+        # Verify supported on platform.
+        verify_marlin_supported(
+            quant_type=self.quant_config.quant_type,
+            group_size=self.quant_config.group_size,
         )
 
-    @classmethod
-    def get_name(cls) -> str:
-        return "marlin"
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: list[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ) -> None:
+        output_size_per_partition = sum(output_partition_sizes)
+        is_row_parallel = input_size != input_size_per_partition
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        self.kernel_config = MarlinLinearLayerConfig(
+            full_weight_shape=(input_size, output_size),
+            partition_weight_shape=(
+                input_size_per_partition,
+                output_size_per_partition,
+            ),
+            weight_type=self.quant_config.quant_type,
+            act_type=params_dtype,
+            group_size=self.quant_config.group_size,
+            zero_points=False,
+            has_g_idx=self.quant_config.desc_act,
+        )
+        # Normalize group_size
+        if self.quant_config.group_size != -1:
+            group_size = self.quant_config.group_size
+        else:
+            group_size = input_size
 
-    @classmethod
-    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
-        return [torch.half]
+        # Determine sharding
+        if marlin_repeat_scales_on_all_ranks(
+            self.quant_config.desc_act, self.quant_config.group_size, is_row_parallel
+        ):
+            # By setting scale_dim == None, weight_loader will
+            # repeat the scales on each GPU in TP>1 case.
+            scales_and_zp_input_dim = None
+            scales_and_zp_size = input_size // group_size
+        else:
+            # By setting scale_dim == 0, weight_loader will
+            # shard the scales in TP>1 case.
+            scales_and_zp_input_dim = 0
+            scales_and_zp_size = input_size_per_partition // group_size
+
+        # Quantized weights
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition // self.quant_config.pack_factor,
+                output_size_per_partition,
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=0,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader,
+        )
 
-    @classmethod
-    # Need to figure it out
-    def get_min_capability(cls) -> int:
-        return 80
+        # Activation order
+        g_idx = RowvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition,
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            weight_loader=weight_loader,
+        )
 
-    @classmethod
-    def get_config_filenames(cls) -> List[str]:
-        return ["quantize_config.json"]
+        qzeros_args = {
+            "data": torch.empty(
+                scales_and_zp_size,
+                output_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            "weight_loader": weight_loader,
+        }
+        weight_scale_args = {
+            "data": torch.empty(
+                scales_and_zp_size,
+                output_size_per_partition,
+                dtype=params_dtype,
+            ),
+            "weight_loader": weight_loader,
+        }
+
+        if scales_and_zp_input_dim is None:
+            scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
+            qzeros = PackedColumnParameter(
+                output_dim=1,
+                packed_dim=1,
+                packed_factor=self.quant_config.pack_factor,
+                **qzeros_args,
+            )
 
-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
-        group_size = cls.get_from_keys(config, ["group_size"])
-        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
-        return cls(group_size, lm_head_quantized)
+        else:
+            scales = GroupQuantScaleParameter(
+                output_dim=1, input_dim=0, **weight_scale_args
+            )
+            qzeros = PackedvLLMParameter(
+                input_dim=0,
+                output_dim=1,
+                packed_dim=1,
+                packed_factor=self.quant_config.pack_factor,
+                **qzeros_args,
+            )
 
-    @classmethod
-    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
-        is_marlin_format = check_marlin_format(hf_quant_cfg)
+        layer.register_parameter("qweight", qweight)
+        layer.register_parameter("g_idx", g_idx)
+        layer.register_parameter("scales", scales)
+        layer.register_parameter("qzeros", qzeros)
 
-        is_valid_user_quant = (
-            user_quant is None or user_quant == "gptq" or user_quant == "marlin"
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        device = getattr(layer, "qweight").device
+        c = self.kernel_config
+
+        check_marlin_supports_shape(
+            c.partition_weight_shape[1],  # out_features
+            c.partition_weight_shape[0],  # in_features
+            c.full_weight_shape[0],  # in_features
+            c.group_size,
         )
 
-        if is_marlin_format and is_valid_user_quant:
-            msg = "The model is serialized in {} format. Using {} kernel.".format(
-                cls.get_name(), cls.get_name()
+        row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
+        self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
+
+        # Allocate marlin workspace.
+        self.workspace = marlin_make_workspace(device)
+
+        # Default names since marlin requires empty parameters for these,
+        # TODO: remove this requirement from marlin (allow optional tensors)
+        self.w_q_name = "qweight"
+        self.w_s_name = "scales"
+        self.w_zp_name = "qzeros"
+        self.w_gidx_name = "g_idx"
+
+        def _transform_param(
+            layer: torch.nn.Module, name: Optional[str], fn: Callable
+        ) -> None:
+            if name is not None and getattr(layer, name, None) is not None:
+
+                old_param = getattr(layer, name)
+                new_param = fn(old_param)
+                # replace the parameter with torch.nn.Parameter for TorchDynamo
+                # compatibility
+                replace_parameter(
+                    layer, name, torch.nn.Parameter(new_param.data, requires_grad=False)
+                )
+
+        def transform_w_q(x):
+            assert isinstance(x, BasevLLMParameter)
+            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
+            x.data = torch.ops.sgl_kernel.gptq_marlin_repack(
+                x.data.contiguous(),
+                perm=layer.g_idx_sort_indices,
+                size_k=c.partition_weight_shape[0],
+                size_n=c.partition_weight_shape[1],
+                num_bits=c.weight_type.size_bits,
             )
-            logger.info(msg)
-            return cls.get_name()
+            return x
+
+        def transform_w_s(x):
+            assert isinstance(x, BasevLLMParameter)
+            permute_param_layout_(x, input_dim=0, output_dim=1)
+            x.data = marlin_permute_scales(
+                x.data.contiguous(),
+                size_k=c.partition_weight_shape[0],
+                size_n=c.partition_weight_shape[1],
+                group_size=c.group_size,
+            )
+            return x
 
-        return None
+        if c.has_g_idx:
+            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
+                getattr(layer, self.w_gidx_name)
+            )
+            _transform_param(layer, self.w_gidx_name, lambda _: g_idx)
+            layer.g_idx_sort_indices = g_idx_sort_indices
+        else:
+            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
+            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
 
-    def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[MarlinLinearMethod]:
-        # Delay the import to avoid circular dependency
-        from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+        if c.zero_points:
+            grouped_k = (
+                c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
+            )
+            _transform_param(
+                layer,
+                self.w_zp_name,
+                lambda x: marlin_zero_points(
+                    unpack_cols(
+                        x.t(),
+                        c.weight_type.size_bits,
+                        grouped_k,
+                        c.partition_weight_shape[1],
+                    ),
+                    size_k=grouped_k,
+                    size_n=c.partition_weight_shape[1],
+                    num_bits=c.weight_type.size_bits,
+                ),
+            )
+        else:
+            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
+        _transform_param(layer, self.w_q_name, transform_w_q)
+        _transform_param(layer, self.w_s_name, transform_w_s)
 
-        if isinstance(layer, LinearBase) or (
-            isinstance(layer, ParallelLMHead) and self.lm_head_quantized
-        ):
-            return MarlinLinearMethod(self)
-        return None
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        c = self.kernel_config
+
+        def _get_weight_params(
+            layer: torch.nn.Module,
+        ) -> tuple[
+            torch.Tensor,  # w_q
+            torch.Tensor,  # w_s
+            Optional[torch.Tensor],  # w_zp,
+            Optional[torch.Tensor],  # w_gidx
+        ]:
+            return (
+                getattr(layer, self.w_q_name),
+                getattr(layer, self.w_s_name),
+                getattr(layer, self.w_zp_name or "", None),
+                getattr(layer, self.w_gidx_name or "", None),
+            )
+
+        w_q, w_s, w_zp, w_gidx = _get_weight_params(layer)
+
+        # `process_weights_after_loading` will ensure w_zp and w_gidx are not
+        # None for marlin
+        return apply_gptq_marlin_linear(
+            input=x,
+            weight=w_q,
+            weight_scale=w_s,
+            weight_zp=w_zp,  # type: ignore
+            g_idx=w_gidx,  # type: ignore
+            g_idx_sort_indices=layer.g_idx_sort_indices,
+            workspace=self.workspace,
+            wtype=c.weight_type,
+            input_size_per_partition=c.partition_weight_shape[0],
+            output_size_per_partition=c.partition_weight_shape[1],
+            is_k_full=self.is_k_full,
+            bias=bias,
+        )
 
 
 class GPTQMarlinMoEMethod(FusedMoEMethodBase):
@@ -467,6 +838,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        # Delay the import to avoid circular dependency
+        from sglang.srt.layers.linear import set_weight_attrs
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
         intermediate_size = extra_weight_attrs.pop("intermediate_size")
 
         self.is_k_full = (not self.quant_config.desc_act) or (
@@ -644,20 +1019,20 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             requires_grad=False,
         )
         # Repack weights
-        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
+        marlin_w13_qweight = gptq_marlin_moe_repack(
            layer.w13_qweight,
            layer.w13_g_idx_sort_indices,
            layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
            layer.w13_qweight.shape[2],
-            self.quant_config.quant_type.size_bits,
+            self.quant_config.weight_bits,
         )
         replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
-        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
+        marlin_w2_qweight = gptq_marlin_moe_repack(
            layer.w2_qweight,
            layer.w2_g_idx_sort_indices,
            layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
            layer.w2_qweight.shape[2],
-            self.quant_config.quant_type.size_bits,
+            self.quant_config.weight_bits,
         )
         replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
         # Repack scales
@@ -685,39 +1060,22 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        e_score_correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
+        **kwargs,
     ) -> torch.Tensor:
+        # Delay the import to avoid circular dependency
+
         assert activation == "silu", "Only SiLU activation is supported."
 
         # The input must currently be float16
         orig_dtype = x.dtype
         x = x.half()
 
-        topk_weights, topk_ids = FusedMoE.select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias,
-        )
+        topk_weights, topk_ids, router_logits = topk_output
 
-        return torch.ops.vllm.fused_marlin_moe(
+        return fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
@@ -730,6 +1088,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             g_idx2=layer.w2_g_idx,
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
-            quant_type_id=self.quant_config.quant_type.id,
+            num_bits=self.quant_config.weight_bits,
             is_k_full=self.is_k_full,
         ).to(orig_dtype)
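
A note on the repack arithmetic introduced by the new gptq_marlin_moe_repack helper earlier in this diff: the Marlin-format buffer it allocates, (num_experts, size_k // 16, size_n * (num_bits // 2)), holds exactly as many int32 words per expert as the GPTQ-packed input, which stores size_k // (32 // num_bits) words along the K dimension. A minimal sketch of that bookkeeping, using hypothetical layer sizes that do not come from this diff:

# Hypothetical per-expert sizes; only the shape arithmetic mirrors
# gptq_marlin_moe_repack from the diff above (4-bit GPTQ example).
size_k, size_n, num_bits = 4096, 11008, 4
pack_factor = 32 // num_bits  # how many quantized values fit in one int32 word

# GPTQ layout: qweight stores size_k // pack_factor rows of size_n int32 words.
gptq_words = (size_k // pack_factor) * size_n

# Marlin layout allocated by the helper: (size_k // 16) x (size_n * (num_bits // 2)).
marlin_words = (size_k // 16) * (size_n * (num_bits // 2))

assert gptq_words == marlin_words  # same storage, retiled for the Marlin kernel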