sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +208 -295
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  64. sglang/srt/layers/moe/topk.py +13 -4
  65. sglang/srt/layers/quantization/__init__.py +111 -7
  66. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  69. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  71. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  72. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  73. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  75. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  80. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  82. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  83. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  86. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  87. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  89. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  90. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  91. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  93. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  94. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  95. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  96. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  97. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  98. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  99. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  100. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  101. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  102. sglang/srt/layers/quantization/fp8.py +69 -28
  103. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  104. sglang/srt/layers/quantization/gptq.py +416 -0
  105. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  106. sglang/srt/layers/quantization/int8_utils.py +73 -0
  107. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  108. sglang/srt/layers/radix_attention.py +1 -0
  109. sglang/srt/layers/rotary_embedding.py +0 -1
  110. sglang/srt/layers/sampler.py +76 -31
  111. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  112. sglang/srt/lora/lora.py +17 -1
  113. sglang/srt/lora/lora_config.py +5 -0
  114. sglang/srt/lora/lora_manager.py +1 -3
  115. sglang/srt/managers/cache_controller.py +193 -62
  116. sglang/srt/managers/configure_logging.py +2 -1
  117. sglang/srt/managers/data_parallel_controller.py +6 -2
  118. sglang/srt/managers/detokenizer_manager.py +124 -102
  119. sglang/srt/managers/image_processor.py +2 -1
  120. sglang/srt/managers/io_struct.py +143 -6
  121. sglang/srt/managers/schedule_batch.py +238 -197
  122. sglang/srt/managers/schedule_policy.py +29 -29
  123. sglang/srt/managers/scheduler.py +681 -259
  124. sglang/srt/managers/session_controller.py +6 -2
  125. sglang/srt/managers/tokenizer_manager.py +224 -68
  126. sglang/srt/managers/tp_worker.py +15 -4
  127. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  128. sglang/srt/mem_cache/chunk_cache.py +18 -11
  129. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  130. sglang/srt/mem_cache/memory_pool.py +44 -18
  131. sglang/srt/mem_cache/radix_cache.py +58 -47
  132. sglang/srt/metrics/collector.py +94 -36
  133. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  134. sglang/srt/model_executor/forward_batch_info.py +49 -16
  135. sglang/srt/model_executor/model_runner.py +209 -28
  136. sglang/srt/model_loader/loader.py +3 -3
  137. sglang/srt/model_loader/weight_utils.py +36 -14
  138. sglang/srt/models/baichuan.py +31 -6
  139. sglang/srt/models/chatglm.py +39 -7
  140. sglang/srt/models/commandr.py +29 -5
  141. sglang/srt/models/dbrx.py +31 -5
  142. sglang/srt/models/deepseek.py +43 -6
  143. sglang/srt/models/deepseek_nextn.py +32 -19
  144. sglang/srt/models/deepseek_v2.py +265 -29
  145. sglang/srt/models/exaone.py +19 -9
  146. sglang/srt/models/gemma.py +22 -8
  147. sglang/srt/models/gemma2.py +25 -12
  148. sglang/srt/models/gemma2_reward.py +5 -1
  149. sglang/srt/models/gpt2.py +28 -13
  150. sglang/srt/models/gpt_bigcode.py +27 -5
  151. sglang/srt/models/granite.py +21 -9
  152. sglang/srt/models/grok.py +21 -4
  153. sglang/srt/models/internlm2.py +36 -6
  154. sglang/srt/models/internlm2_reward.py +5 -1
  155. sglang/srt/models/llama.py +26 -9
  156. sglang/srt/models/llama_classification.py +5 -1
  157. sglang/srt/models/llama_eagle.py +17 -4
  158. sglang/srt/models/llama_embedding.py +5 -1
  159. sglang/srt/models/llama_reward.py +7 -2
  160. sglang/srt/models/llava.py +19 -3
  161. sglang/srt/models/llavavid.py +10 -1
  162. sglang/srt/models/minicpm.py +26 -2
  163. sglang/srt/models/minicpm3.py +39 -3
  164. sglang/srt/models/minicpmv.py +45 -14
  165. sglang/srt/models/mixtral.py +20 -9
  166. sglang/srt/models/mixtral_quant.py +50 -8
  167. sglang/srt/models/mllama.py +57 -11
  168. sglang/srt/models/olmo.py +34 -6
  169. sglang/srt/models/olmo2.py +34 -13
  170. sglang/srt/models/olmoe.py +26 -4
  171. sglang/srt/models/phi3_small.py +29 -10
  172. sglang/srt/models/qwen.py +26 -3
  173. sglang/srt/models/qwen2.py +26 -4
  174. sglang/srt/models/qwen2_5_vl.py +46 -8
  175. sglang/srt/models/qwen2_eagle.py +17 -5
  176. sglang/srt/models/qwen2_moe.py +44 -6
  177. sglang/srt/models/qwen2_rm.py +78 -0
  178. sglang/srt/models/qwen2_vl.py +39 -8
  179. sglang/srt/models/stablelm.py +32 -5
  180. sglang/srt/models/torch_native_llama.py +5 -2
  181. sglang/srt/models/xverse.py +21 -9
  182. sglang/srt/models/xverse_moe.py +45 -7
  183. sglang/srt/models/yivl.py +2 -1
  184. sglang/srt/openai_api/adapter.py +109 -24
  185. sglang/srt/openai_api/protocol.py +17 -1
  186. sglang/srt/reasoning_parser.py +154 -0
  187. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  188. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  189. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  190. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  191. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  192. sglang/srt/sampling/sampling_batch_info.py +79 -157
  193. sglang/srt/sampling/sampling_params.py +16 -13
  194. sglang/srt/server_args.py +136 -52
  195. sglang/srt/speculative/build_eagle_tree.py +2 -8
  196. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  197. sglang/srt/speculative/eagle_utils.py +92 -58
  198. sglang/srt/speculative/eagle_worker.py +186 -94
  199. sglang/srt/speculative/spec_info.py +1 -13
  200. sglang/srt/utils.py +43 -17
  201. sglang/srt/warmup.py +47 -0
  202. sglang/test/few_shot_gsm8k.py +4 -1
  203. sglang/test/runners.py +389 -126
  204. sglang/test/send_one.py +88 -0
  205. sglang/test/test_block_fp8_ep.py +361 -0
  206. sglang/test/test_programs.py +1 -1
  207. sglang/test/test_utils.py +138 -84
  208. sglang/utils.py +50 -60
  209. sglang/version.py +1 -1
  210. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  211. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
  212. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  213. sglang/bench_latency.py +0 -1
  214. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  215. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  216. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  217. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  218. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  219. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/gptq.py (new file)
@@ -0,0 +1,416 @@
+import logging
+from fractions import Fraction
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+from vllm.scalar_type import scalar_types
+
+from sglang.srt.layers.linear import LinearBase
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+
+logger = logging.getLogger(__name__)
+
+
+class GPTQConfig(QuantizationConfig):
+    """Config class for GPTQ.
+
+    Reference: https://arxiv.org/abs/2210.17323
+    """
+
+    def __init__(
+        self,
+        weight_bits: int,
+        group_size: int,
+        desc_act: bool,
+        lm_head_quantized: bool,
+        dynamic: Dict[str, Dict[str, Union[int, bool]]],
+    ) -> None:
+        # GPTQModel use `dynamic` config property to allow per module
+        # quantization config so each module can be individually optimized.
+        # Format is Dict[str, Dict] where key is a regex string that can
+        # perform both positive ("+:" prefixed) or negative ("-:" prefixed)
+        # matching of a module.
+        # Default to positive match, override base quant config mode, if no
+        # prefix is used. Value is in dict format of field key and override
+        # value.
+        # Negative matching will skip quantization init for this module
+        # entirely:
+        # non-quantized inference. More details and quantization examples can be
+        # found at: https://github.com/ModelCloud/GPTQModel
+        # Example:
+        # # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
+        # # last 1/4 of the layers 16-21 has 8bit and group_size 64
+        # dynamic = {
+        # #`.*\.` matches the layers_node prefix
+        # # positive match layer 10-15
+        # r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
+        # # positive match layer 16-21
+        # r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
+        # r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
+        # }
+        super().__init__()
+        self.dynamic = dynamic
+
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+        self.desc_act = desc_act
+        self.lm_head_quantized = lm_head_quantized
+        self.pack_factor = Fraction(32, self.weight_bits)
+        if self.weight_bits not in [2, 3, 4, 8]:
+            raise ValueError(
+                "Currently, only 2/3/4/8-bit weight quantization is "
+                f"supported for GPTQ, but got {self.weight_bits} bits."
+            )
+
+    def __repr__(self) -> str:
+        return (
+            f"GPTQConfig(weight_bits={self.weight_bits}, "
+            f"group_size={self.group_size}, "
+            f"desc_act={self.desc_act}, "
+            f"lm_head_quantized={self.lm_head_quantized}, "
+            f"dynamic={self.dynamic})"
+        )
+
+    def get_scaled_act_names(self) -> List[str]:
+        """Returns the activation function names that should be post-scaled.
+
+        For now, this is only used by AWQ.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "gptq"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    # Need to figure it out
+    def get_min_capability(cls) -> int:
+        return 60
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["quantize_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
+        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
+        dynamic = {} if dynamic is None else dynamic
+
+        weight_bits = cls.get_from_keys(config, ["bits"])
+        group_size = cls.get_from_keys(config, ["group_size"])
+        desc_act = cls.get_from_keys(config, ["desc_act"])
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
+        return cls(weight_bits, group_size, desc_act, lm_head_quantized, dynamic)
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["GPTQLinearMethod"]:
+        from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+
+        from sglang.srt.layers.quantization import get_linear_quant_method
+
+        return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
+
+
+class GPTQMarlinConfig(QuantizationConfig):
+    """Config class for GPTQ Marlin"""
+
+    # (num_bits, is_sym) -> quant_type
+    TYPE_MAP = {
+        (4, True): scalar_types.uint4b8,
+        (8, True): scalar_types.uint8b128,
+    }
+
+    def __init__(
+        self,
+        weight_bits: int,
+        group_size: int,
+        desc_act: bool,
+        is_sym: bool,
+        lm_head_quantized: bool,
+        dynamic: Dict[str, Dict[str, Union[int, bool]]],
+        full_config: Dict[str, Any],
+    ) -> None:
+        super().__init__()
+        if desc_act and group_size == -1:
+            # In this case, act_order == True is the same as act_order == False
+            # (since we have only one group per output channel)
+            desc_act = False
+
+        # GPTQModel use `dynamic` config property to allow per module
+        # quantization config so each module can be individually optimized.
+        # Format is Dict[str, Dict] where key is a regex string that can
+        # perform both positive ("+:" prefixed) or negative ("-:" prefixed)
+        # matching of a module.
+        # Default to positive match, override base quant config mode, if no
+        # prefix is used. Value is in dict format of field key and override
+        # value.
+        # Negative matching will skip quantization init for this module
+        # entirely:
+        # non-quantized inference. More details and quantization examples can be
+        # found at: https://github.com/ModelCloud/GPTQModel
+        # Example:
+        # # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
+        # # last 1/4 of the layers 16-21 has 8bit and group_size 64
+        # dynamic = {
+        # #`.*\.` matches the layers_node prefix
+        # # positive match layer 10-15
+        # r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
+        # # positive match layer 16-21
+        # r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
+        # r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
+        # }
+        self.dynamic = dynamic
+
+        self.weight_bits = weight_bits
+        self.is_sym = is_sym
+
+        self.pack_factor = 32 // weight_bits  # packed into int32
+        self.group_size = group_size
+        self.desc_act = desc_act
+        self.lm_head_quantized = lm_head_quantized
+        self.full_config = full_config
+
+        if (weight_bits, is_sym) not in self.TYPE_MAP:
+            raise ValueError(
+                "Unsupported quantization config: " f"bits={weight_bits}, sym={is_sym}"
+            )
+
+        self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
+
+    def __repr__(self) -> str:
+        return (
+            f"GPTQMarlinConfig(quant_type={self.quant_type}, "
+            f"group_size={self.group_size}, "
+            f"desc_act={self.desc_act}, "
+            f"lm_head_quantized={self.lm_head_quantized}, "
+            f"dynamic={self.dynamic})"
+        )
+
+    def get_scaled_act_names(self) -> List[str]:
+        """Returns the activation function names that should be post-scaled.
+
+        For now, this is only used by AWQ.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "gptq_marlin"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.half, torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["quantize_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
+        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
+        dynamic = {} if dynamic is None else dynamic
+
+        weight_bits = cls.get_from_keys(config, ["bits"])
+        group_size = cls.get_from_keys(config, ["group_size"])
+        desc_act = cls.get_from_keys(config, ["desc_act"])
+        is_sym = cls.get_from_keys(config, ["sym"])
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
+        return cls(
+            weight_bits,
+            group_size,
+            desc_act,
+            is_sym,
+            lm_head_quantized,
+            dynamic,
+            config,
+        )
+
+    @classmethod
+    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
+        can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
+
+        is_valid_user_quant = (
+            user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin"
+        )
+
+        if can_convert and is_valid_user_quant:
+            msg = (
+                "The model is convertible to {} during runtime."
+                " Using {} kernel.".format(cls.get_name(), cls.get_name())
+            )
+            logger.info(msg)
+            return cls.get_name()
+
+        if can_convert and user_quant == "gptq":
+            logger.info(
+                "Detected that the model can run with gptq_marlin"
+                ", however you specified quantization=gptq explicitly,"
+                " so forcing gptq. Use quantization=gptq_marlin for"
+                " faster inference"
+            )
+        return None
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        from vllm.model_executor.layers.quantization.gptq_marlin import (
+            GPTQMarlinLinearMethod,
+            GPTQMarlinMoEMethod,
+        )
+
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.quantization import get_linear_quant_method
+
+        if isinstance(layer, FusedMoE):
+            return GPTQMarlinMoEMethod(self)
+        # TODO: re-enable after SGLang syncs with vllm >= 0.7.3
+        # if layer.num_experts > 32:
+        #     # For MoEs with many experts the moe_wna16 kernel is faster
+        #     return MoeWNA16Config.from_config(self.full_config).get_quant_method(
+        #         layer, prefix
+        #     )
+        # else:
+        #     return GPTQMarlinMoEMethod(self)
+        return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod)
+
+    @classmethod
+    def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
+        quant_method = quant_config.get("quant_method", "").lower()
+        num_bits = quant_config.get("bits")
+        group_size = quant_config.get("group_size")
+        sym = quant_config.get("sym")
+        desc_act = quant_config.get("desc_act")
+
+        from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+            check_marlin_supported,
+        )
+        from vllm.platforms import current_platform
+
+        if not current_platform.is_cuda():
+            return False
+
+        if quant_method != "gptq":
+            return False
+
+        # Marlin conversion is only valid if required properties are found
+        if num_bits is None or group_size is None or sym is None or desc_act is None:
+            return False
+
+        if (num_bits, sym) not in cls.TYPE_MAP:
+            return False
+
+        return check_marlin_supported(
+            quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size
+        )
+
+
+class MarlinConfig(QuantizationConfig):
+    """Config class for Marlin.
+
+    Reference: https://github.com/IST-DASLab/marlin/tree/master
+    """
+
+    def __init__(
+        self,
+        group_size: int,
+        lm_head_quantized: bool,
+    ) -> None:
+        # Group size for the quantization.
+        self.group_size = group_size
+        self.lm_head_quantized = lm_head_quantized
+        if self.group_size != 128 and self.group_size != -1:
+            raise ValueError(
+                "Currently, only group size 128 and -1 (channelwise) "
+                "is supported for Marlin, but got group_size of "
+                f"{self.group_size}"
+            )
+
+        # 4 Bits packed into 32 bit datatype.
+        self.pack_factor = 32 // 4
+
+        # Tile size used by marlin kernels.
+        self.tile_size = 16
+
+        # Min out_features dim
+        self.min_n_threads = 64
+
+        # Min in_features dim
+        self.min_k_threads = 128
+
+        # Max parallel problems to solve at once (improves large
+        # batch performance)
+        self.max_parallel = 16
+
+        # Permutation length used by the marlin kernels.
+        self.perm_len = 1024
+
+    def __repr__(self) -> str:
+        return (
+            f"MarlinConfig(group_size={self.group_size}, "
+            f"lm_head_quantized={self.lm_head_quantized})"
+        )
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "marlin"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    # Need to figure it out
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["quantize_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
+        group_size = cls.get_from_keys(config, ["group_size"])
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
+        return cls(group_size, lm_head_quantized)
+
+    @classmethod
+    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
+        # compat: autogptq >=0.8.0 use checkpoint_format: str
+        # compat: autogptq <=0.7.1 is_marlin_format: bool
+        is_marlin_format = hf_quant_cfg.get(
+            "checkpoint_format"
+        ) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)
+
+        is_valid_user_quant = (
+            user_quant is None or user_quant == "gptq" or user_quant == "marlin"
+        )
+
+        if is_marlin_format and is_valid_user_quant:
+            msg = "The model is serialized in {} format. Using {} kernel.".format(
+                cls.get_name(), cls.get_name()
+            )
+            logger.info(msg)
+            return cls.get_name()
+
+        return None
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["MarlinLinearMethod"]:
+        from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
+
+        if isinstance(layer, LinearBase) or (
+            isinstance(layer, ParallelLMHead) and self.lm_head_quantized
+        ):
+            return MarlinLinearMethod(self)
+        return None
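
Usage sketch (separate from the diff above): the snippet below illustrates how a checkpoint's quantize_config.json, the filename returned by get_config_filenames(), feeds these config classes. The field values are illustrative assumptions rather than values from any real checkpoint; only the keys ("bits", "group_size", "desc_act", "sym", "lm_head") and the methods called are taken from the code above, and the Marlin compatibility check additionally requires vllm and a CUDA device at runtime.

# Hypothetical quantize_config.json contents for a 4-bit, symmetric,
# group-size-128 GPTQ checkpoint (illustrative values only).
quant_cfg = {
    "quant_method": "gptq",
    "bits": 4,
    "group_size": 128,
    "desc_act": False,
    "sym": True,
    "lm_head": False,
}

# Plain GPTQ path: from_config reads "bits", "group_size", "desc_act",
# "lm_head", and the optional per-module "dynamic" override map.
gptq_config = GPTQConfig.from_config(quant_cfg)

# On CUDA with a supported (bits, sym) pair, override_quantization_method
# may upgrade the kernel choice to the gptq_marlin backend.
chosen = GPTQMarlinConfig.override_quantization_method(quant_cfg, user_quant=None)
if chosen == "gptq_marlin":
    quant_config = GPTQMarlinConfig.from_config(quant_cfg)
else:
    quant_config = gptq_config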