sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/marlin_utils.py (new file)
@@ -0,0 +1,790 @@
+# SPDX-License-Identifier: Apache-2.0
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/marlin_utils.py
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any, Optional
+
+import numpy
+import torch
+
+from sglang.srt.layers.parameter import (
+    BasevLLMParameter,
+    ChannelQuantScaleParameter,
+    GroupQuantScaleParameter,
+    PackedvLLMParameter,
+)
+from sglang.srt.layers.quantization.base_config import (
+    LinearMethodBase,
+    QuantizationConfig,
+)
+from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
+from sglang.srt.layers.quantization.utils import pack_cols, unpack_cols
+from sglang.srt.utils import get_device_capability
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.linear import LinearBase
+
+try:
+    from vllm import _custom_ops as ops
+except ImportError:
+    ops = None
+
+logger = logging.getLogger(__name__)
+
+GPTQ_MARLIN_TILE = 16
+GPTQ_MARLIN_MIN_THREAD_N = 64
+GPTQ_MARLIN_MIN_THREAD_K = 128
+GPTQ_MARLIN_MAX_PARALLEL = 16
+
+MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
+
+# In case there is a performance issue with Marlin, the variable below can be
+# changed to False, which allows Marlin to perform global reductions in fp16
+# precision (instead of fp32), and therefore, save on some memory movements.
+USE_FP32_REDUCE_DEFAULT = True
+
+
+# For binary size and compile time, we don't support the same types for with and
+# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
+# TODO: we may want to move this into the C++ so its closer to the actual impl
+def query_marlin_supported_quant_types(
+    has_zp: Optional[bool] = None,
+    include_fp_type: bool = True,
+    device_capability: Optional[int] = None,
+):
+    if device_capability is None:
+        major, minor = get_device_capability()
+        capability = major * 10 + minor
+        device_capability = -1 if capability is None else capability
+
+    if device_capability < 80:
+        return []
+
+    # - has_zp is True: return quant_types that has zero points
+    # - has_zp is False: return quant_types that has not zero points
+    # - has_zp is None: both
+    if has_zp is None:
+        types0 = query_marlin_supported_quant_types(
+            False, include_fp_type, device_capability
+        )
+        types1 = query_marlin_supported_quant_types(
+            True, include_fp_type, device_capability
+        )
+        return types0 + types1
+
+    if has_zp:
+        # AWQ style, unsigned + runtime zero-point
+        return [scalar_types.uint4]
+    else:
+        # GPTQ style, unsigned + symmetric bias
+        res = [scalar_types.uint4b8, scalar_types.uint8b128]
+        if include_fp_type:
+            res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f]
+        return res
+
+
+def _check_marlin_supported(
+    quant_type: ScalarType,
+    group_size: Optional[int],
+    has_zp: bool,
+    device_capability: Optional[int] = None,
+) -> tuple[bool, Optional[str]]:
+
+    if device_capability is None:
+        major, minor = get_device_capability()
+        capability = major * 10 + minor
+        device_capability = -1 if capability is None else capability
+
+    supported_types = query_marlin_supported_quant_types(
+        has_zp, True, device_capability
+    )
+
+    if quant_type not in supported_types:
+        return (
+            False,
+            f"Marlin does not support weight_bits = {quant_type}. "
+            f"Only types = {supported_types} "
+            f"are supported (for group_size = {group_size}, "
+            f"device_capability = {device_capability}, zp = {has_zp}).",
+        )
+    if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
+        return (
+            False,
+            f"Marlin does not support group_size = {group_size}. "
+            f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
+            "are supported.",
+        )
+
+    return True, None
+
+
+def check_marlin_supported(
+    quant_type: ScalarType,
+    group_size: int,
+    has_zp: bool = False,
+    device_capability: Optional[int] = None,
+) -> bool:
+    cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability)
+    return cond
+
+
+def verify_marlin_supported(
+    quant_type: ScalarType, group_size: int, has_zp: bool = False
+) -> None:
+    cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
+    if not cond:
+        assert err_msg is not None
+        raise ValueError(err_msg)
+
+
+def verify_marlin_supports_shape(
+    output_size_per_partition: int,
+    input_size_per_partition: int,
+    input_size: int,
+    group_size: int,
+) -> None:
+
+    # Validate output_size_per_partition
+    if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
+        raise ValueError(
+            f"Weight output_size_per_partition = "
+            f"{output_size_per_partition} is not divisible by "
+            f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
+            "Consider reducing tensor_parallel_size or running "
+            "with --quantization gptq."
+        )
+
+    # Validate input_size_per_partition
+    if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
+        raise ValueError(
+            f"Weight input_size_per_partition = "
+            f"{input_size_per_partition} is not divisible "
+            f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
+            "Consider reducing tensor_parallel_size or running "
+            "with --quantization gptq."
+        )
+
+    if group_size < input_size and input_size_per_partition % group_size != 0:
+        raise ValueError(
+            f"Weight input_size_per_partition = {input_size_per_partition}"
+            f" is not divisible by group_size = {group_size}. "
+            "Consider reducing tensor_parallel_size or running "
+            "with --quantization gptq."
+        )
+
+
+def check_marlin_supports_shape(
+    output_size_per_partition: int,
+    input_size_per_partition: int,
+    input_size: int,
+    group_size: int,
+) -> tuple[bool, Optional[str]]:
+    try:
+        verify_marlin_supports_shape(
+            output_size_per_partition, input_size_per_partition, input_size, group_size
+        )
+    except ValueError as e:
+        return False, e.__str__()
+    return True, None
+
+
+def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
+    output_size_per_partition = (
+        getattr(layer, "output_size_per_partition", None) or layer.output_size
+    )
+    input_size_per_partition = (
+        getattr(layer, "input_size_per_partition", None) or layer.input_size
+    )
+
+    return check_marlin_supports_shape(
+        output_size_per_partition=output_size_per_partition,
+        input_size_per_partition=input_size_per_partition,
+        input_size=layer.input_size,
+        group_size=group_size,
+    )[0]
+
+
+def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
+    hidden_size = layer.hidden_size
+    intermediate_size_per_partition = layer.intermediate_size_per_partition
+    # apply_router_weight_on_input is not supported for moe marlin
+    supports_router_weight = not layer.apply_router_weight_on_input
+    # moe marlin requires the activation to be silu
+    supports_activation = layer.activation == "silu"
+
+    # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
+    # down: (n, k) = (hidden_size, intermediate_size_per_partition)
+    # moe marlin requires n % 128 == 0 and k % 64 == 0
+    supports_shape = (
+        hidden_size % 128 == 0
+        and intermediate_size_per_partition % max(64, group_size) == 0
+    )
+    supports_group_size = group_size in [-1, 32, 64, 128]
+    return (
+        supports_shape
+        and supports_group_size
+        and supports_router_weight
+        and supports_activation
+    )
+
+
+def marlin_make_workspace(
+    device: torch.device, max_blocks_per_sm: int = 1
+) -> torch.Tensor:
+    # In the new marlin kernel, we use the num of threadblocks as workspace
+    # size. The num of threadblocks is is sms_count * max_blocks_per_sm.
+    sms = torch.cuda.get_device_properties(device).multi_processor_count
+    return torch.zeros(
+        sms * max_blocks_per_sm, dtype=torch.int, device=device, requires_grad=False
+    )
+
+
+def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
+    return (not act_order) or (act_order and not is_row_parallel)
+
+
+def marlin_repeat_scales_on_all_ranks(
+    act_order: bool, group_size: int, is_row_parallel: bool
+) -> bool:
+    # Need to repeat scales on every rank if act_ordering or
+    # channelwise and RowParallelLinear
+    is_channelwise = group_size == -1
+    return act_order or (is_channelwise and is_row_parallel)
+
+
+def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
+    return torch.nn.Parameter(
+        torch.empty(0, dtype=torch.int, device=device), requires_grad=False
+    )
+
+
+def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
+    return torch.nn.Parameter(
+        torch.empty(0, dtype=torch.int, device=device), requires_grad=False
+    )
+
+
+def marlin_sort_g_idx(g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
+    return g_idx[g_idx_sort_indices], g_idx_sort_indices
+
+
+def get_scale_perms():
+    scale_perm: list[int] = []
+    for i in range(8):
+        scale_perm.extend([i + 8 * j for j in range(8)])
+    scale_perm_single: list[int] = []
+    for i in range(4):
+        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
+    return scale_perm, scale_perm_single
+
+
+def marlin_permute_scales(
+    s: torch.Tensor, size_k: int, size_n: int, group_size: int
+) -> torch.Tensor:
+
+    scale_perm, scale_perm_single = get_scale_perms()
+    if group_size < size_k and group_size != -1:
+        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
+    else:
+        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
+    s = s.reshape((-1, size_n)).contiguous()
+
+    return s
+
+
+def marlin_moe_permute_scales(
+    s: torch.Tensor,
+    size_k: int,
+    size_n: int,
+    group_size: int,
+):
+    num_experts = s.shape[0]
+    output = torch.empty(
+        (num_experts, s.shape[1], s.shape[2]),
+        device=s.device,
+        dtype=s.dtype,
+    )
+
+    for e in range(num_experts):
+        output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
+    return output
+
+
+def marlin_zero_points(
+    zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
+    # Permute zero-points in a similar way to scales, but do not use the
+    # "single" permutation, since zero-points are applied on every MMA
+    scale_perm, _ = get_scale_perms()
+    zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]
+
+    # Interleave column dim (for the dequantize code) and pack it to int32
+    if num_bits == 4:
+        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
+    elif num_bits == 8:
+        interleave = numpy.array([0, 2, 1, 3])
+    else:
+        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
+
+    zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
+    zp = zp.reshape((-1, size_n)).contiguous()
+    zp = pack_cols(zp, num_bits, size_k, size_n)
+
+    return zp
+
+
+def awq_to_marlin_zero_points(
+    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
+    # AWQ zero-points are quantized and packed on the column dim.
+    # In addition, the values are permuted based on dequantizer.
+    # Here we undo both of these, and then apply marlin permutation
+    # and pack it back.
+    q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
+
+    # Undo interleaving (use argsort(..) to get inverse perm)
+    if num_bits == 4:
+        undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
+    elif num_bits == 8:
+        undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
+    else:
+        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
+
+    q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
+    q_zp = q_zp.reshape((-1, size_n)).contiguous()
+
+    marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
+    return marlin_zp
+
+
+def moe_awq_to_marlin_zero_points(
+    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
+):
+    num_experts = q_zp_packed.shape[0]
+    output = torch.empty(
+        (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
+        device=q_zp_packed.device,
+        dtype=q_zp_packed.dtype,
+    )
+    for e in range(num_experts):
+        output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
+    return output
+
+
+def maybe_warn_marlin_atomic_add(device, dtype):
+    if torch.compiler.is_dynamo_compiling():
+        return
+    device_capability = torch.cuda.get_device_capability(device)
+    if device_capability[0] < 9 and dtype == torch.bfloat16:
+        logger.info_once(
+            "You are running Marlin kernel with bf16 on GPUs before SM90. "
+            "You can consider change to fp16 to achieve better performance "
+            "if possible."
+        )
+
+
+def maybe_warn_marlin_atomic_add_env():
+    if torch.compiler.is_dynamo_compiling():
+        return
+    # TODO(yiyun): Need to add sglang's MARLIN_USE_ATOMIC_ADD: bool = False
+    if True:
+        return
+    # if envs.VLLM_MARLIN_USE_ATOMIC_ADD:
+    #     return
+    logger.info_once(
+        "Marlin kernel can achieve better performance for small size_n "
+        "with experimental use_atomic_add feature. "
+        "You can consider set environment variable "
+        "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible."
+    )
+
+
+def should_use_atomic_add_reduce(
+    m: int, n: int, k: int, device: torch.device, dtype: torch.dtype
+) -> bool:
+
+    # the performance of atomicAdd is better than global reduce
+    # only when m*n is small and k is large
+    if n >= 2048 or k < 2048 or device.type != "cuda":
+        return False
+
+    # disable atomicAdd reduce by default,
+    # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
+    # TODO: Need to add sglang's MARLIN_USE_ATOMIC_ADD: bool = False
+    if not True:
+        maybe_warn_marlin_atomic_add_env()
+        return False
+
+    # sm8x doesn't support atomicAdd + bfloat16 natively
+    device_capability = torch.cuda.get_device_capability(device)
+    if device_capability[0] < 9 and dtype == torch.bfloat16:
+        maybe_warn_marlin_atomic_add(device, dtype)
+        return False
+
+    return True
+
+
+def apply_gptq_marlin_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    weight_zp: torch.Tensor,
+    g_idx: torch.Tensor,
+    g_idx_sort_indices: torch.Tensor,
+    workspace: torch.Tensor,
+    wtype: ScalarType,
+    output_size_per_partition: int,
+    input_size_per_partition: int,
+    is_k_full: bool,
+    bias: Optional[torch.Tensor] = None,
+    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
+) -> torch.Tensor:
+    reshaped_x = input.reshape(-1, input.shape[-1])
+    out_shape = input.shape[:-1] + (output_size_per_partition,)
+
+    use_atomic_add = should_use_atomic_add_reduce(
+        m=reshaped_x.size(0),
+        n=output_size_per_partition,
+        k=reshaped_x.size(1),
+        device=input.device,
+        dtype=input.dtype,
+    )
+
+    output = ops.gptq_marlin_gemm(
+        reshaped_x,
+        None,
+        weight,
+        weight_scale,
+        None,
+        weight_zp,
+        g_idx,
+        g_idx_sort_indices,
+        workspace,
+        wtype,
+        size_m=reshaped_x.shape[0],
+        size_n=output_size_per_partition,
+        size_k=input_size_per_partition,
+        is_k_full=is_k_full,
+        use_atomic_add=use_atomic_add,
+        use_fp32_reduce=use_fp32_reduce,
+        is_zp_float=False,
+    )
+
+    if bias is not None:
+        output.add_(bias)  # In-place add
+
+    return output.reshape(out_shape)
+
+
+def apply_awq_marlin_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    weight_zp: torch.Tensor,
+    g_idx: torch.Tensor,
+    g_idx_sort_indices: torch.Tensor,
+    workspace: torch.Tensor,
+    quant_type: ScalarType,
+    output_size_per_partition: int,
+    input_size_per_partition: int,
+    bias: Optional[torch.Tensor] = None,
+    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
+) -> torch.Tensor:
+    reshaped_x = input.reshape(-1, input.shape[-1])
+    out_shape = input.shape[:-1] + (output_size_per_partition,)
+
+    use_atomic_add = should_use_atomic_add_reduce(
+        m=reshaped_x.size(0),
+        n=output_size_per_partition,
+        k=reshaped_x.size(1),
+        device=input.device,
+        dtype=input.dtype,
+    )
+
+    output = ops.gptq_marlin_gemm(
+        reshaped_x,
+        None,
+        weight,
+        weight_scale,
+        None,
+        weight_zp,
+        g_idx,
+        g_idx_sort_indices,
+        workspace,
+        quant_type,
+        size_m=reshaped_x.shape[0],
+        size_n=output_size_per_partition,
+        size_k=input_size_per_partition,
+        use_atomic_add=use_atomic_add,
+        use_fp32_reduce=use_fp32_reduce,
+        is_zp_float=False,
+    )
+
+    if bias is not None:
+        output.add_(bias)  # In-place add
+
+    return output.reshape(out_shape)
+
+
+class MarlinConfig(QuantizationConfig):
+    """Config class for Marlin.
+
+    Reference: https://github.com/IST-DASLab/marlin/tree/master
+    """
+
+    def __init__(
+        self,
+        group_size: int,
+        lm_head_quantized: bool,
+    ) -> None:
+        super().__init__()
+
+        # Group size for the quantization.
+        self.group_size = group_size
+        self.lm_head_quantized = lm_head_quantized
+        if self.group_size != 128 and self.group_size != -1:
+            raise ValueError(
+                "Currently, only group size 128 and -1 (channelwise) "
+                "is supported for Marlin, but got group_size of "
+                f"{self.group_size}"
+            )
+
+        # 4 Bits packed into 32 bit datatype.
+        self.pack_factor = 32 // 4
+
+        # Tile size used by marlin kernels.
+        self.tile_size = 16
+
+        # Min out_features dim
+        self.min_n_threads = 64
+
+        # Min in_features dim
+        self.min_k_threads = 128
+
+        # Max parallel problems to solve at once (improves large
+        # batch performance)
+        self.max_parallel = 16
+
+        # Permutation length used by the marlin kernels.
+        self.perm_len = 1024
+
+    def __repr__(self) -> str:
+        return (
+            f"MarlinConfig(group_size={self.group_size}, "
+            f"lm_head_quantized={self.lm_head_quantized})"
+        )
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "marlin"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    # Need to figure it out
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def get_config_filenames(cls) -> list[str]:
+        return ["quantize_config.json"]
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "MarlinConfig":
+        group_size = cls.get_from_keys(config, ["group_size"])
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
+        return cls(group_size, lm_head_quantized)
+
+    @classmethod
+    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
+        # compat: autogptq >=0.8.0 use checkpoint_format: str
+        # compat: autogptq <=0.7.1 is_marlin_format: bool
+        is_marlin_format = hf_quant_cfg.get(
+            "checkpoint_format"
+        ) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)
+
+        is_valid_user_quant = (
+            user_quant is None or user_quant == "gptq" or user_quant == "marlin"
+        )
+
+        if is_marlin_format and is_valid_user_quant:
+            msg = "The model is serialized in {} format. Using {} kernel.".format(
+                cls.get_name(), cls.get_name()
+            )
+            logger.info(msg)
+            return cls.get_name()
+
+        return None
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[MarlinLinearMethod]:
+        from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+
+        if isinstance(layer, LinearBase) or (
+            isinstance(layer, ParallelLMHead) and self.lm_head_quantized
+        ):
+            return MarlinLinearMethod(self)
+        return None
+
+
+class MarlinLinearMethod(LinearMethodBase):
+    """Linear method for Marlin.
+
+    Args:
+        quant_config: The Marlin quantization config.
+    """
+
+    def __init__(self, quant_config: MarlinConfig):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: list[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        del output_size  # Unused.
+        weight_loader = extra_weight_attrs["weight_loader"]
+
+        if params_dtype != torch.float16:
+            raise ValueError(
+                f"The params dtype must be float16, but got {params_dtype}"
+            )
+
+        # Validate output_size_per_partition
+        output_size_per_partition = sum(output_partition_sizes)
+        if output_size_per_partition % self.quant_config.min_n_threads != 0:
+            raise ValueError(
+                f"Weight output_size_per_partition = "
+                f"{output_size_per_partition} is not divisible by "
+                f"min_n_threads = {self.quant_config.min_n_threads}."
+            )
+        if output_size_per_partition % self.quant_config.pack_factor != 0:
+            raise ValueError(
+                f"Weight output_size_per_partition = "
+                f"{output_size_per_partition} is not divisible by "
+                f"pack_factor = {self.quant_config.pack_factor}."
+            )
+
+        # Validate input_size_per_partition
+        if input_size_per_partition % self.quant_config.min_k_threads != 0:
+            raise ValueError(
+                f"Weight input_size_per_partition = "
+                f"{input_size_per_partition} is not divisible by "
+                f"min_k_threads = {self.quant_config.min_k_threads}."
+            )
+        if (
+            self.quant_config.group_size != -1
+            and input_size_per_partition % self.quant_config.group_size != 0
+        ):
+            raise ValueError(
+                f"Weight input_size_per_partition = "
+                f"{input_size_per_partition} is not divisible by "
+                f"group_size = {self.quant_config.group_size}."
+            )
+
+        # Check that we have at least 4 tiles horizontally in the shard
+        num_tiles_per_perm = self.quant_config.perm_len // (
+            self.quant_config.tile_size**2
+        )
+        if output_size_per_partition % num_tiles_per_perm != 0:
+            raise ValueError("Each permutation group must reside on the same gpu")
+
+        # Quantized 4Bit weights packed into Int32.
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition // self.quant_config.tile_size,
+                output_size_per_partition
+                * self.quant_config.tile_size
+                // self.quant_config.pack_factor,
+                device="cuda",
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            marlin_tile_size=self.quant_config.tile_size,
+            weight_loader=weight_loader,
+        )
+
+        # Determine if channelwise or not
+        input_groups = (
+            1
+            if self.quant_config.group_size == -1
+            else input_size_per_partition // self.quant_config.group_size
+        )
+
+        weight_scale_args = {
+            "data": torch.empty(
+                input_groups,
+                output_size_per_partition,
+                device="cuda",
+                dtype=params_dtype,
+            ),
+            "weight_loader": weight_loader,
+        }
+        if input_groups == 1:
+            scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
+        else:
+            scales = GroupQuantScaleParameter(
+                output_dim=1, input_dim=0, **weight_scale_args
+            )
+
+        # Allocate workspace (Used for internal locking mechanism)
+        max_workspace_size = (
+            output_size_per_partition // self.quant_config.min_n_threads
+        ) * self.quant_config.max_parallel
+
+        workspace = BasevLLMParameter(
+            data=torch.zeros(max_workspace_size, device="cuda", dtype=torch.int),
+            weight_loader=weight_loader,
+        )
+
+        layer.register_parameter("B", qweight)
+        layer.register_parameter("s", scales)
+        layer.register_parameter("workspace", workspace)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # required by torch.compile
+        layer.B = torch.nn.Parameter(layer.B.data, requires_grad=False)
+        layer.s = torch.nn.Parameter(layer.s.data, requires_grad=False)
+        layer.workspace = torch.nn.Parameter(layer.workspace.data, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        qweight = layer.B
+        scales = layer.s
+        workspace = layer.workspace
+
+        x_2d = x.view(-1, x.shape[-1])
+
+        size_m = x_2d.shape[0]
+        size_k = x_2d.shape[1]
+        size_n = scales.shape[1]
+
+        output_2d = ops.marlin_gemm(
+            x_2d, qweight, scales, workspace, size_m, size_n, size_k
+        )
+
+        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],))
+
+        if bias is not None:
+            output.add_(bias)  # In-place add
+
+        return output
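
For orientation, the snippet below is a minimal sketch (not part of the diff) of how the new marlin_utils helpers can be exercised. It assumes sglang 0.4.9.post4 is installed on a machine with a CUDA GPU of compute capability 8.0 or newer; all imported names come from the module added above, while the tensor sizes passed to marlin_permute_scales are illustrative only.

import torch

from sglang.srt.layers.quantization.marlin_utils import (
    check_marlin_supported,
    marlin_permute_scales,
    query_marlin_supported_quant_types,
)
from sglang.srt.layers.quantization.scalar_type import scalar_types

# Quant types Marlin can serve on this device (empty list below SM80).
print(query_marlin_supported_quant_types(has_zp=None))

# GPTQ-style uint4 with a symmetric bias ("uint4b8"), group size 128.
print(check_marlin_supported(scalar_types.uint4b8, group_size=128, has_zp=False))

# Reorder a (size_k // group_size, size_n) scale tensor into the layout the
# Marlin GEMM kernels expect; sizes here are hypothetical.
size_k, size_n, group_size = 4096, 4096, 128
scales = torch.rand(size_k // group_size, size_n, dtype=torch.float16, device="cuda")
print(marlin_permute_scales(scales, size_k, size_n, group_size).shape)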