sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +208 -295
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  64. sglang/srt/layers/moe/topk.py +13 -4
  65. sglang/srt/layers/quantization/__init__.py +111 -7
  66. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  69. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  71. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  72. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  73. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  75. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  80. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  82. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  83. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  86. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  87. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  89. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  90. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  91. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  93. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  94. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  95. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  96. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  97. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  98. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  99. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  100. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  101. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  102. sglang/srt/layers/quantization/fp8.py +69 -28
  103. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  104. sglang/srt/layers/quantization/gptq.py +416 -0
  105. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  106. sglang/srt/layers/quantization/int8_utils.py +73 -0
  107. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  108. sglang/srt/layers/radix_attention.py +1 -0
  109. sglang/srt/layers/rotary_embedding.py +0 -1
  110. sglang/srt/layers/sampler.py +76 -31
  111. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  112. sglang/srt/lora/lora.py +17 -1
  113. sglang/srt/lora/lora_config.py +5 -0
  114. sglang/srt/lora/lora_manager.py +1 -3
  115. sglang/srt/managers/cache_controller.py +193 -62
  116. sglang/srt/managers/configure_logging.py +2 -1
  117. sglang/srt/managers/data_parallel_controller.py +6 -2
  118. sglang/srt/managers/detokenizer_manager.py +124 -102
  119. sglang/srt/managers/image_processor.py +2 -1
  120. sglang/srt/managers/io_struct.py +143 -6
  121. sglang/srt/managers/schedule_batch.py +238 -197
  122. sglang/srt/managers/schedule_policy.py +29 -29
  123. sglang/srt/managers/scheduler.py +681 -259
  124. sglang/srt/managers/session_controller.py +6 -2
  125. sglang/srt/managers/tokenizer_manager.py +224 -68
  126. sglang/srt/managers/tp_worker.py +15 -4
  127. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  128. sglang/srt/mem_cache/chunk_cache.py +18 -11
  129. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  130. sglang/srt/mem_cache/memory_pool.py +44 -18
  131. sglang/srt/mem_cache/radix_cache.py +58 -47
  132. sglang/srt/metrics/collector.py +94 -36
  133. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  134. sglang/srt/model_executor/forward_batch_info.py +49 -16
  135. sglang/srt/model_executor/model_runner.py +209 -28
  136. sglang/srt/model_loader/loader.py +3 -3
  137. sglang/srt/model_loader/weight_utils.py +36 -14
  138. sglang/srt/models/baichuan.py +31 -6
  139. sglang/srt/models/chatglm.py +39 -7
  140. sglang/srt/models/commandr.py +29 -5
  141. sglang/srt/models/dbrx.py +31 -5
  142. sglang/srt/models/deepseek.py +43 -6
  143. sglang/srt/models/deepseek_nextn.py +32 -19
  144. sglang/srt/models/deepseek_v2.py +265 -29
  145. sglang/srt/models/exaone.py +19 -9
  146. sglang/srt/models/gemma.py +22 -8
  147. sglang/srt/models/gemma2.py +25 -12
  148. sglang/srt/models/gemma2_reward.py +5 -1
  149. sglang/srt/models/gpt2.py +28 -13
  150. sglang/srt/models/gpt_bigcode.py +27 -5
  151. sglang/srt/models/granite.py +21 -9
  152. sglang/srt/models/grok.py +21 -4
  153. sglang/srt/models/internlm2.py +36 -6
  154. sglang/srt/models/internlm2_reward.py +5 -1
  155. sglang/srt/models/llama.py +26 -9
  156. sglang/srt/models/llama_classification.py +5 -1
  157. sglang/srt/models/llama_eagle.py +17 -4
  158. sglang/srt/models/llama_embedding.py +5 -1
  159. sglang/srt/models/llama_reward.py +7 -2
  160. sglang/srt/models/llava.py +19 -3
  161. sglang/srt/models/llavavid.py +10 -1
  162. sglang/srt/models/minicpm.py +26 -2
  163. sglang/srt/models/minicpm3.py +39 -3
  164. sglang/srt/models/minicpmv.py +45 -14
  165. sglang/srt/models/mixtral.py +20 -9
  166. sglang/srt/models/mixtral_quant.py +50 -8
  167. sglang/srt/models/mllama.py +57 -11
  168. sglang/srt/models/olmo.py +34 -6
  169. sglang/srt/models/olmo2.py +34 -13
  170. sglang/srt/models/olmoe.py +26 -4
  171. sglang/srt/models/phi3_small.py +29 -10
  172. sglang/srt/models/qwen.py +26 -3
  173. sglang/srt/models/qwen2.py +26 -4
  174. sglang/srt/models/qwen2_5_vl.py +46 -8
  175. sglang/srt/models/qwen2_eagle.py +17 -5
  176. sglang/srt/models/qwen2_moe.py +44 -6
  177. sglang/srt/models/qwen2_rm.py +78 -0
  178. sglang/srt/models/qwen2_vl.py +39 -8
  179. sglang/srt/models/stablelm.py +32 -5
  180. sglang/srt/models/torch_native_llama.py +5 -2
  181. sglang/srt/models/xverse.py +21 -9
  182. sglang/srt/models/xverse_moe.py +45 -7
  183. sglang/srt/models/yivl.py +2 -1
  184. sglang/srt/openai_api/adapter.py +109 -24
  185. sglang/srt/openai_api/protocol.py +17 -1
  186. sglang/srt/reasoning_parser.py +154 -0
  187. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  188. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  189. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  190. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  191. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  192. sglang/srt/sampling/sampling_batch_info.py +79 -157
  193. sglang/srt/sampling/sampling_params.py +16 -13
  194. sglang/srt/server_args.py +136 -52
  195. sglang/srt/speculative/build_eagle_tree.py +2 -8
  196. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  197. sglang/srt/speculative/eagle_utils.py +92 -58
  198. sglang/srt/speculative/eagle_worker.py +186 -94
  199. sglang/srt/speculative/spec_info.py +1 -13
  200. sglang/srt/utils.py +43 -17
  201. sglang/srt/warmup.py +47 -0
  202. sglang/test/few_shot_gsm8k.py +4 -1
  203. sglang/test/runners.py +389 -126
  204. sglang/test/send_one.py +88 -0
  205. sglang/test/test_block_fp8_ep.py +361 -0
  206. sglang/test/test_programs.py +1 -1
  207. sglang/test/test_utils.py +138 -84
  208. sglang/utils.py +50 -60
  209. sglang/version.py +1 -1
  210. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  211. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
  212. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  213. sglang/bench_latency.py +0 -1
  214. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  215. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  216. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  217. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  218. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  219. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/__init__.py

@@ -1,5 +1,7 @@
  # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
- from typing import Callable, Dict, Optional, Type
+ import re
+ from copy import deepcopy
+ from typing import Callable, Dict, Optional, Type, Union

  import torch
  from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
@@ -16,15 +18,15 @@ from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfi
  from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
  from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
  from vllm.model_executor.layers.quantization.gguf import GGUFConfig
- from vllm.model_executor.layers.quantization.gptq import GPTQConfig
- from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
  from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
  from vllm.model_executor.layers.quantization.marlin import MarlinConfig
  from vllm.model_executor.layers.quantization.qqq import QQQConfig
  from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

  from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
  from sglang.srt.layers.quantization.fp8 import Fp8Config
+ from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
  from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
  from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config

@@ -34,6 +36,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
      "deepspeedfp": DeepSpeedFPConfig,
      "tpu_int8": Int8TpuConfig,
      "fp8": Fp8Config,
+     "blockwise_int8": BlockInt8Config,
      "fbgemm_fp8": FBGEMMFp8Config,
      "marlin": MarlinConfig,
      "modelopt": ModelOptFp8Config,
@@ -59,19 +62,119 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
      return QUANTIZATION_METHODS[quantization]


+ # Match dynamic rules with module name (prefix) and override quantize
+ # config if module (prefix) matches a rule
+ def override_config(config: QuantizationConfig, prefix: str):
+     weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
+     if isinstance(weight_bits, int):
+         config.weight_bits = weight_bits
+     group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
+     if isinstance(group_size, int):
+         config.group_size = group_size
+     desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
+     if isinstance(desc_act, bool):
+         config.desc_act = desc_act
+
+     config.pack_factor = 32 // config.weight_bits # packed into int32
+     if config.get_name() == "gptq_marlin":
+         is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
+         if isinstance(is_sym, bool):
+             config.is_sym = is_sym
+
+         if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
+             raise ValueError(
+                 "Unsupported quantization config: "
+                 f"bits={config.weight_bits}, sym={config.is_sym}"
+             )
+
+         config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
+     elif config.get_name() == "gptq":
+         if config.weight_bits not in [2, 3, 4, 8]:
+             raise ValueError(
+                 "Currently, only 2/3/4/8-bit weight quantization is "
+                 f"supported for GPTQ, but got {config.weight_bits} bits."
+             )
+
+
+ def get_dynamic_override(
+     config: QuantizationConfig,
+     layer_name: str,
+     key: Optional[str] = None,
+     default_value: Union[int, bool, None] = None,
+ ) -> Union[Dict, int, bool, None]:
+     for pattern, pattern_dict in config.dynamic.items():
+         # Negative match: matched modules are excluded from quantized init
+         if pattern.startswith("-:"):
+             if re.match(pattern.removeprefix("-:"), layer_name):
+                 return False
+         # Positive match: matched modules have quant properties overrides
+         # base quant config
+         elif re.match(pattern.removeprefix("+:"), layer_name):
+             if key is None:
+                 return pattern_dict
+             else:
+                 return pattern_dict.get(key, default_value)
+     return default_value
+
+
+ def get_linear_quant_method(
+     config: QuantizationConfig,
+     layer: torch.nn.Module,
+     prefix: str,
+     linear_method_cls: type,
+ ):
+
+     from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+     from sglang.srt.layers.vocab_parallel_embedding import (
+         ParallelLMHead,
+         UnquantizedEmbeddingMethod,
+     )
+
+     cloned_config = deepcopy(config)
+     parallel_lm_head_quantized = (
+         isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
+     )
+
+     if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
+         # False = skip module, None = no override, else = Positive match
+         if (
+             get_dynamic_override( # noqa: E712
+                 cloned_config, layer_name=prefix # noqa: E712
+             )
+             == False
+         ): # noqa: E712
+             if parallel_lm_head_quantized:
+                 return UnquantizedEmbeddingMethod()
+             return UnquantizedLinearMethod()
+
+         if prefix:
+             # Dynamic per module/layer rules may override base config
+             override_config(cloned_config, prefix=prefix)
+
+         return linear_method_cls(cloned_config)
+     return None
+
+
  def gptq_get_quant_method(self, layer, prefix):
+     from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
      from vllm.model_executor.layers.quantization.gptq_marlin import (
          GPTQMarlinLinearMethod,
          GPTQMarlinMoEMethod,
      )

-     from sglang.srt.layers.linear import LinearBase
      from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

-     if isinstance(layer, LinearBase):
-         return GPTQMarlinLinearMethod(self)
-     elif isinstance(layer, FusedMoE):
+     if isinstance(layer, FusedMoE):
          return GPTQMarlinMoEMethod(self)
+
+     if isinstance(self, GPTQConfig):
+         return get_linear_quant_method(
+             self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
+         )
+     elif isinstance(self, GPTQMarlinConfig):
+         return get_linear_quant_method(
+             self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
+         )
      return None


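The `dynamic` rules consumed by `get_dynamic_override` above are regex patterns matched against module prefixes: keys starting with `-:` exclude the matching modules from quantized initialization, while keys starting with `+:` (or with no prefix) override per-layer quantization properties such as `bits`, `group_size`, `desc_act`, or `sym`. A minimal, self-contained sketch of that matching behavior (the `dynamic` dict and `lookup` helper below are illustrative examples, not part of the package):

```python
import re

# Hypothetical GPTQ-style dynamic override rules (illustrative only).
dynamic = {
    "+:model\\.layers\\.0\\..*": {"bits": 8, "group_size": 32},  # override these layers
    "-:lm_head": {},  # never quantize the LM head
}


def lookup(layer_name, key=None, default=None):
    """Standalone mirror of get_dynamic_override, without the config object."""
    for pattern, overrides in dynamic.items():
        if pattern.startswith("-:"):
            if re.match(pattern.removeprefix("-:"), layer_name):
                return False  # negative match: skip quantized init
        elif re.match(pattern.removeprefix("+:"), layer_name):
            return overrides if key is None else overrides.get(key, default)
    return default


print(lookup("model.layers.0.self_attn.qkv_proj", "bits", 4))  # 8 (overridden)
print(lookup("lm_head"))                                       # False (excluded)
print(lookup("model.layers.5.mlp.gate_proj", "bits", 4))       # 4 (base config default)
```

`get_linear_quant_method` then deep-copies the base config, applies `override_config` on a positive match, and falls back to an unquantized linear/embedding method on a negative match.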
@@ -153,6 +256,7 @@ def apply_monkey_patches():
      from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod

      setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
+     setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
      setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
      setattr(AWQMoEMethod, "apply", awq_moe_method_apply)

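Taken together, the `__init__.py` changes register the new block-wise INT8 backend under the name `blockwise_int8` and monkey-patch both vLLM's `GPTQConfig` and `GPTQMarlinConfig` so that `get_quant_method` goes through the prefix-aware resolution above. A hedged usage sketch of the registry lookup (assumes sglang 0.4.3.post3 and its vLLM dependency are importable in the environment):

```python
# Resolve the newly registered quantization config class by name.
from sglang.srt.layers.quantization import QUANTIZATION_METHODS, get_quantization_config

cfg_cls = get_quantization_config("blockwise_int8")
assert cfg_cls is QUANTIZATION_METHODS["blockwise_int8"]
print(cfg_cls.get_name())            # "blockwise_int8"
print(cfg_cls.get_min_capability())  # 80, per BlockInt8Config below
```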
sglang/srt/layers/quantization/blockwise_int8.py (new file)

@@ -0,0 +1,409 @@
+ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
+
+ import logging
+ from typing import Any, Callable, Dict, List, Optional
+
+ import torch
+ from torch.nn import Module
+ from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
+
+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
+ from sglang.srt.layers.linear import (
+     LinearBase,
+     LinearMethodBase,
+     UnquantizedLinearMethod,
+ )
+ from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
+ from sglang.srt.layers.quantization.base_config import (
+     QuantizationConfig,
+     QuantizeMethodBase,
+ )
+ from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
+ from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
+ from sglang.srt.utils import set_weight_attrs
+
+ ACTIVATION_SCHEMES = ["static", "dynamic"]
+
+ logger = logging.getLogger(__name__)
+
+
+ class BlockInt8Config(QuantizationConfig):
+     """Config class for INT8."""
+
+     def __init__(
+         self,
+         is_checkpoint_int8_serialized: bool = False,
+         activation_scheme: str = "dynamic",
+         ignored_layers: Optional[List[str]] = None,
+         weight_block_size: List[int] = None,
+     ) -> None:
+         self.is_checkpoint_int8_serialized = is_checkpoint_int8_serialized
+         if is_checkpoint_int8_serialized:
+             logger.warning(
+                 "Detected int8 checkpoint. Please note that the "
+                 "format is experimental and subject to change."
+             )
+         if activation_scheme not in ACTIVATION_SCHEMES:
+             raise ValueError(f"Unsupported activation scheme {activation_scheme}")
+         self.activation_scheme = activation_scheme
+         self.ignored_layers = ignored_layers or []
+         if weight_block_size is not None:
+             if not is_checkpoint_int8_serialized:
+                 raise ValueError(
+                     f"The block-wise quantization only supports int8-serialized checkpoint for now."
+                 )
+             if len(weight_block_size) != 2:
+                 raise ValueError(
+                     f"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions."
+                 )
+             if activation_scheme != "dynamic":
+                 raise ValueError(
+                     f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme."
+                 )
+         self.weight_block_size = weight_block_size
+
+     @classmethod
+     def get_name(cls) -> str:
+         return "blockwise_int8"
+
+     @classmethod
+     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+         return [torch.bfloat16, torch.half]
+
+     @classmethod
+     def get_min_capability(cls) -> int:
+         return 80
+
+     @classmethod
+     def get_config_filenames(cls) -> List[str]:
+         return []
+
+     @classmethod
+     def from_config(cls, config: Dict[str, Any]) -> "BlockInt8Config":
+         quant_method = cls.get_from_keys(config, ["quant_method"])
+         is_checkpoint_int8_serialized = "int8" in quant_method
+         activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
+         ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
+         weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
+         return cls(
+             is_checkpoint_int8_serialized=is_checkpoint_int8_serialized,
+             activation_scheme=activation_scheme,
+             ignored_layers=ignored_layers,
+             weight_block_size=weight_block_size,
+         )
+
+     def get_quant_method(
+         self, layer: torch.nn.Module, prefix: str
+     ) -> Optional["QuantizeMethodBase"]:
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+         if isinstance(layer, LinearBase):
+             if is_layer_skipped(prefix, self.ignored_layers):
+                 return UnquantizedLinearMethod()
+             return BlockInt8LinearMethod(self)
+         elif isinstance(layer, FusedMoE):
+             return BlockInt8MoEMethod(self)
+         return None
+
+     def get_scaled_act_names(self) -> List[str]:
+         return []
+
+
+ class BlockInt8LinearMethod(LinearMethodBase):
+     """Linear method for INT8.
+     Supports loading INT8 checkpoints with static weight scale and
+     dynamic activation scale.
+
+     Limitations:
+     Only support block-wise int8 quantization and int8 checkpoint
+
+     Args:
+         quant_config: The quantization config.
+     """
+
+     def __init__(self, quant_config: BlockInt8Config):
+         self.quant_config = quant_config
+         assert self.quant_config.weight_block_size is not None
+         assert self.quant_config.is_checkpoint_int8_serialized
+
+     def create_weights(
+         self,
+         layer: torch.nn.Module,
+         input_size_per_partition: int,
+         output_partition_sizes: List[int],
+         input_size: int,
+         output_size: int,
+         params_dtype: torch.dtype,
+         **extra_weight_attrs,
+     ):
+         output_size_per_partition = sum(output_partition_sizes)
+         weight_loader = extra_weight_attrs.get("weight_loader")
+
+         tp_size = get_tensor_model_parallel_world_size()
+
+         block_n, block_k = (
+             self.quant_config.weight_block_size[0],
+             self.quant_config.weight_block_size[1],
+         )
+         # Required by row parallel
+         if tp_size > 1 and input_size // input_size_per_partition == tp_size:
+             if input_size_per_partition % block_k != 0:
+                 raise ValueError(
+                     f"Weight input_size_per_partition = "
+                     f"{input_size_per_partition} is not divisible by "
+                     f"weight quantization block_k = {block_k}."
+                 )
+         # Required by collum parallel or enabling merged weights
+         if (tp_size > 1 and output_size // output_size_per_partition == tp_size) or len(
+             output_partition_sizes
+         ) > 1:
+             for output_partition_size in output_partition_sizes:
+                 if output_partition_size % block_n != 0:
+                     raise ValueError(
+                         f"Weight output_partition_size = "
+                         f"{output_partition_size} is not divisible by "
+                         f"weight quantization block_n = {block_n}."
+                     )
+
+         layer.logical_widths = output_partition_sizes
+
+         layer.input_size_per_partition = input_size_per_partition
+         layer.output_size_per_partition = output_size_per_partition
+         layer.orig_dtype = params_dtype
+
+         # WEIGHT
+         weight_dtype = (
+             torch.int8
+             if self.quant_config.is_checkpoint_int8_serialized
+             else params_dtype
+         )
+
+         weight = ModelWeightParameter(
+             data=torch.empty(
+                 output_size_per_partition, input_size_per_partition, dtype=weight_dtype
+             ),
+             input_dim=1,
+             output_dim=0,
+             weight_loader=weight_loader,
+         )
+         layer.register_parameter("weight", weight)
+
+         # WEIGHT SCALE
+
+         scale = BlockQuantScaleParameter(
+             data=torch.empty(
+                 (output_size_per_partition + block_n - 1) // block_n,
+                 (input_size_per_partition + block_k - 1) // block_k,
+                 dtype=torch.float32,
+             ),
+             input_dim=1,
+             output_dim=0,
+             weight_loader=weight_loader,
+         )
+         scale[:] = torch.finfo(torch.float32).min
+         layer.register_parameter("weight_scale_inv", scale)
+
+         # INPUT ACTIVATION SCALE
+         assert self.quant_config.activation_scheme == "dynamic"
+         layer.register_parameter("input_scale", None)
+
+     def process_weights_after_loading(self, layer: Module) -> None:
+         # Block quant doesn't need to process weights after loading
+         # Use torch Parameter to avoid cuda graph capturing issue
+         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
+         layer.weight_scale_inv = torch.nn.Parameter(
+             layer.weight_scale_inv.data, requires_grad=False
+         )
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         bias: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         return apply_w8a8_block_int8_linear(
+             input=x,
+             weight=layer.weight,
+             block_size=self.quant_config.weight_block_size,
+             weight_scale=layer.weight_scale_inv,
+             input_scale=None,
+             bias=bias,
+         )
+
+
+ class BlockInt8MoEMethod:
+     """MoE method for INT8.
+     Supports loading INT8 checkpoints with static weight scale and
+     dynamic activation scale.
+
+     Limitations:
+     Only support block-wise int8 quantization and int8 checkpoint
+
+     Args:
+         quant_config: The quantization config.
+     """
+
+     def __new__(cls, *args, **kwargs):
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+         if not hasattr(cls, "_initialized"):
+             original_init = cls.__init__
+             new_cls = type(
+                 cls.__name__,
+                 (FusedMoEMethodBase,),
+                 {
+                     "__init__": original_init,
+                     **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                 },
+             )
+             obj = super(new_cls, new_cls).__new__(new_cls)
+             obj.__init__(*args, **kwargs)
+             return obj
+         return super().__new__(cls)
+
+     def __init__(self, quant_config):
+         self.quant_config = quant_config
+         assert self.quant_config.weight_block_size is not None
+         assert self.quant_config.is_checkpoint_int8_serialized
+
+     def create_weights(
+         self,
+         layer: Module,
+         num_experts: int,
+         hidden_size: int,
+         intermediate_size: int,
+         params_dtype: torch.dtype,
+         **extra_weight_attrs,
+     ):
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+         if self.quant_config.is_checkpoint_int8_serialized:
+             params_dtype = torch.int8
+         tp_size = get_tensor_model_parallel_world_size()
+
+         block_n, block_k = (
+             self.quant_config.weight_block_size[0],
+             self.quant_config.weight_block_size[1],
+         )
+         # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
+         # Required by collum parallel or enabling merged weights
+         if intermediate_size % block_n != 0:
+             raise ValueError(
+                 f"The output_size of gate's and up's weight = "
+                 f"{intermediate_size} is not divisible by "
+                 f"weight quantization block_n = {block_n}."
+             )
+         if tp_size > 1:
+             # Required by row parallel
+             if intermediate_size % block_k != 0:
+                 raise ValueError(
+                     f"The input_size of down's weight = "
+                     f"{intermediate_size} is not divisible by "
+                     f"weight quantization block_k = {block_k}."
+                 )
+
+         # WEIGHTS
+         w13_weight = torch.nn.Parameter(
+             torch.empty(
+                 num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_weight", w13_weight)
+         set_weight_attrs(w13_weight, extra_weight_attrs)
+
+         w2_weight = torch.nn.Parameter(
+             torch.empty(
+                 num_experts, hidden_size, intermediate_size, dtype=params_dtype
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w2_weight", w2_weight)
+         set_weight_attrs(w2_weight, extra_weight_attrs)
+
+         # WEIGHT_SCALES
+         w13_weight_scale = torch.nn.Parameter(
+             torch.ones(
+                 num_experts,
+                 2 * ((intermediate_size + block_n - 1) // block_n),
+                 (hidden_size + block_k - 1) // block_k,
+                 dtype=torch.float32,
+             ),
+             requires_grad=False,
+         )
+         w2_weight_scale = torch.nn.Parameter(
+             torch.ones(
+                 num_experts,
+                 (hidden_size + block_n - 1) // block_n,
+                 (intermediate_size + block_k - 1) // block_k,
+                 dtype=torch.float32,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
+         layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
+
+         extra_weight_attrs.update(
+             {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
+         )
+         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+         # INPUT_SCALES
+         assert self.quant_config.activation_scheme == "dynamic"
+         layer.w13_input_scale = None
+         layer.w2_input_scale = None
+
+     def process_weights_after_loading(self, layer: Module) -> None:
+         # Block quant doesn't need to process weights after loading
+         return
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         router_logits: torch.Tensor,
+         top_k: int,
+         renormalize: bool,
+         use_grouped_topk: bool,
+         topk_group: Optional[int] = None,
+         num_expert_group: Optional[int] = None,
+         custom_routing_function: Optional[Callable] = None,
+         correction_bias: Optional[torch.Tensor] = None,
+         activation: str = "silu",
+         inplace: bool = True,
+         no_combine: bool = False,
+     ) -> torch.Tensor:
+         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+         from sglang.srt.layers.moe.topk import select_experts
+
+         # Expert selection
+         topk_weights, topk_ids = select_experts(
+             hidden_states=x,
+             router_logits=router_logits,
+             use_grouped_topk=use_grouped_topk,
+             top_k=top_k,
+             renormalize=renormalize,
+             topk_group=topk_group,
+             num_expert_group=num_expert_group,
+             custom_routing_function=custom_routing_function,
+             correction_bias=correction_bias,
+         )
+
+         # Expert fusion with INT8 quantization
+         return fused_experts(
+             x,
+             layer.w13_weight,
+             layer.w2_weight,
+             topk_weights=topk_weights,
+             topk_ids=topk_ids,
+             inplace=inplace,
+             activation=activation,
+             use_int8_w8a8=True,
+             w1_scale=(layer.w13_weight_scale_inv),
+             w2_scale=(layer.w2_weight_scale_inv),
+             a1_scale=layer.w13_input_scale,
+             a2_scale=layer.w2_input_scale,
+             block_shape=self.quant_config.weight_block_size,
+             no_combine=no_combine,
+         )
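For reference, the scale layout created in `BlockInt8LinearMethod.create_weights` stores one float32 scale per `block_n x block_k` tile of the int8 weight, i.e. `weight_scale_inv` has shape `(ceil(N / block_n), ceil(K / block_k))` for an `(N, K)` weight. A self-contained illustration of that layout (the `blockwise_int8_quant` helper is a hypothetical plain-PyTorch reference, not the fused Triton path behind `apply_w8a8_block_int8_linear`):

```python
import torch


def blockwise_int8_quant(w: torch.Tensor, block_n: int = 128, block_k: int = 128):
    """Per-block symmetric int8 quantization of an (N, K) weight (reference only)."""
    n, k = w.shape
    n_blocks = (n + block_n - 1) // block_n  # same ceil-division as create_weights
    k_blocks = (k + block_k - 1) // block_k
    q = torch.empty_like(w, dtype=torch.int8)
    scale = torch.empty(n_blocks, k_blocks, dtype=torch.float32)
    for i in range(n_blocks):
        for j in range(k_blocks):
            rows = slice(i * block_n, (i + 1) * block_n)
            cols = slice(j * block_k, (j + 1) * block_k)
            s = w[rows, cols].abs().amax().clamp(min=1e-8) / 127.0
            scale[i, j] = s
            q[rows, cols] = (w[rows, cols] / s).round().clamp(-128, 127).to(torch.int8)
    return q, scale


w = torch.randn(576, 7168)   # one of the tuned (N, K) shapes listed in the configs above
q, scale = blockwise_int8_quant(w)
print(q.shape, scale.shape)  # torch.Size([576, 7168]) torch.Size([5, 56])
```

The tuned JSON configs added in this release (the `N=...,K=...,dtype=int8_w8a8,block_shape=[128, 128]` files) supply Triton launch parameters for exactly these block shapes.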