sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +302 -414
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +13 -8
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +144 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +773 -334
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +225 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +68 -37
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +102 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +56 -31
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +280 -81
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +135 -60
  181. sglang/srt/speculative/build_eagle_tree.py +8 -9
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
  183. sglang/srt/speculative/eagle_utils.py +92 -57
  184. sglang/srt/speculative/eagle_worker.py +238 -111
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/blockwise_int8.py
@@ -0,0 +1,409 @@
+ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
+
+ import logging
+ from typing import Any, Callable, Dict, List, Optional
+
+ import torch
+ from torch.nn import Module
+ from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
+
+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
+ from sglang.srt.layers.linear import (
+     LinearBase,
+     LinearMethodBase,
+     UnquantizedLinearMethod,
+ )
+ from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
+ from sglang.srt.layers.quantization.base_config import (
+     QuantizationConfig,
+     QuantizeMethodBase,
+ )
+ from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
+ from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
+ from sglang.srt.utils import set_weight_attrs
+
+ ACTIVATION_SCHEMES = ["static", "dynamic"]
+
+ logger = logging.getLogger(__name__)
+
+
+ class BlockInt8Config(QuantizationConfig):
+     """Config class for INT8."""
+
+     def __init__(
+         self,
+         is_checkpoint_int8_serialized: bool = False,
+         activation_scheme: str = "dynamic",
+         ignored_layers: Optional[List[str]] = None,
+         weight_block_size: List[int] = None,
+     ) -> None:
+         self.is_checkpoint_int8_serialized = is_checkpoint_int8_serialized
+         if is_checkpoint_int8_serialized:
+             logger.warning(
+                 "Detected int8 checkpoint. Please note that the "
+                 "format is experimental and subject to change."
+             )
+         if activation_scheme not in ACTIVATION_SCHEMES:
+             raise ValueError(f"Unsupported activation scheme {activation_scheme}")
+         self.activation_scheme = activation_scheme
+         self.ignored_layers = ignored_layers or []
+         if weight_block_size is not None:
+             if not is_checkpoint_int8_serialized:
+                 raise ValueError(
+                     f"The block-wise quantization only supports int8-serialized checkpoint for now."
+                 )
+             if len(weight_block_size) != 2:
+                 raise ValueError(
+                     f"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions."
+                 )
+             if activation_scheme != "dynamic":
+                 raise ValueError(
+                     f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme."
+                 )
+         self.weight_block_size = weight_block_size
+
+     @classmethod
+     def get_name(cls) -> str:
+         return "blockwise_int8"
+
+     @classmethod
+     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+         return [torch.bfloat16, torch.half]
+
+     @classmethod
+     def get_min_capability(cls) -> int:
+         return 80
+
+     @classmethod
+     def get_config_filenames(cls) -> List[str]:
+         return []
+
+     @classmethod
+     def from_config(cls, config: Dict[str, Any]) -> "BlockInt8Config":
+         quant_method = cls.get_from_keys(config, ["quant_method"])
+         is_checkpoint_int8_serialized = "int8" in quant_method
+         activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
+         ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
+         weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
+         return cls(
+             is_checkpoint_int8_serialized=is_checkpoint_int8_serialized,
+             activation_scheme=activation_scheme,
+             ignored_layers=ignored_layers,
+             weight_block_size=weight_block_size,
+         )
+
+     def get_quant_method(
+         self, layer: torch.nn.Module, prefix: str
+     ) -> Optional["QuantizeMethodBase"]:
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+         if isinstance(layer, LinearBase):
+             if is_layer_skipped(prefix, self.ignored_layers):
+                 return UnquantizedLinearMethod()
+             return BlockInt8LinearMethod(self)
+         elif isinstance(layer, FusedMoE):
+             return BlockInt8MoEMethod(self)
+         return None
+
+     def get_scaled_act_names(self) -> List[str]:
+         return []
+
+
+ class BlockInt8LinearMethod(LinearMethodBase):
+     """Linear method for INT8.
+     Supports loading INT8 checkpoints with static weight scale and
+     dynamic activation scale.
+
+     Limitations:
+     Only support block-wise int8 quantization and int8 checkpoint
+
+     Args:
+         quant_config: The quantization config.
+     """
+
+     def __init__(self, quant_config: BlockInt8Config):
+         self.quant_config = quant_config
+         assert self.quant_config.weight_block_size is not None
+         assert self.quant_config.is_checkpoint_int8_serialized
+
+     def create_weights(
+         self,
+         layer: torch.nn.Module,
+         input_size_per_partition: int,
+         output_partition_sizes: List[int],
+         input_size: int,
+         output_size: int,
+         params_dtype: torch.dtype,
+         **extra_weight_attrs,
+     ):
+         output_size_per_partition = sum(output_partition_sizes)
+         weight_loader = extra_weight_attrs.get("weight_loader")
+
+         tp_size = get_tensor_model_parallel_world_size()
+
+         block_n, block_k = (
+             self.quant_config.weight_block_size[0],
+             self.quant_config.weight_block_size[1],
+         )
+         # Required by row parallel
+         if tp_size > 1 and input_size // input_size_per_partition == tp_size:
+             if input_size_per_partition % block_k != 0:
+                 raise ValueError(
+                     f"Weight input_size_per_partition = "
+                     f"{input_size_per_partition} is not divisible by "
+                     f"weight quantization block_k = {block_k}."
+                 )
+         # Required by column parallel or enabling merged weights
+         if (tp_size > 1 and output_size // output_size_per_partition == tp_size) or len(
+             output_partition_sizes
+         ) > 1:
+             for output_partition_size in output_partition_sizes:
+                 if output_partition_size % block_n != 0:
+                     raise ValueError(
+                         f"Weight output_partition_size = "
+                         f"{output_partition_size} is not divisible by "
+                         f"weight quantization block_n = {block_n}."
+                     )
+
+         layer.logical_widths = output_partition_sizes
+
+         layer.input_size_per_partition = input_size_per_partition
+         layer.output_size_per_partition = output_size_per_partition
+         layer.orig_dtype = params_dtype
+
+         # WEIGHT
+         weight_dtype = (
+             torch.int8
+             if self.quant_config.is_checkpoint_int8_serialized
+             else params_dtype
+         )
+
+         weight = ModelWeightParameter(
+             data=torch.empty(
+                 output_size_per_partition, input_size_per_partition, dtype=weight_dtype
+             ),
+             input_dim=1,
+             output_dim=0,
+             weight_loader=weight_loader,
+         )
+         layer.register_parameter("weight", weight)
+
+         # WEIGHT SCALE
+
+         scale = BlockQuantScaleParameter(
+             data=torch.empty(
+                 (output_size_per_partition + block_n - 1) // block_n,
+                 (input_size_per_partition + block_k - 1) // block_k,
+                 dtype=torch.float32,
+             ),
+             input_dim=1,
+             output_dim=0,
+             weight_loader=weight_loader,
+         )
+         scale[:] = torch.finfo(torch.float32).min
+         layer.register_parameter("weight_scale_inv", scale)
+
+         # INPUT ACTIVATION SCALE
+         assert self.quant_config.activation_scheme == "dynamic"
+         layer.register_parameter("input_scale", None)
+
+     def process_weights_after_loading(self, layer: Module) -> None:
+         # Block quant doesn't need to process weights after loading
+         # Use torch Parameter to avoid cuda graph capturing issue
+         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
+         layer.weight_scale_inv = torch.nn.Parameter(
+             layer.weight_scale_inv.data, requires_grad=False
+         )
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         bias: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         return apply_w8a8_block_int8_linear(
+             input=x,
+             weight=layer.weight,
+             block_size=self.quant_config.weight_block_size,
+             weight_scale=layer.weight_scale_inv,
+             input_scale=None,
+             bias=bias,
+         )
+
+
+ class BlockInt8MoEMethod:
+     """MoE method for INT8.
+     Supports loading INT8 checkpoints with static weight scale and
+     dynamic activation scale.
+
+     Limitations:
+     Only support block-wise int8 quantization and int8 checkpoint
+
+     Args:
+         quant_config: The quantization config.
+     """
+
+     def __new__(cls, *args, **kwargs):
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+         if not hasattr(cls, "_initialized"):
+             original_init = cls.__init__
+             new_cls = type(
+                 cls.__name__,
+                 (FusedMoEMethodBase,),
+                 {
+                     "__init__": original_init,
+                     **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                 },
+             )
+             obj = super(new_cls, new_cls).__new__(new_cls)
+             obj.__init__(*args, **kwargs)
+             return obj
+         return super().__new__(cls)
+
+     def __init__(self, quant_config):
+         self.quant_config = quant_config
+         assert self.quant_config.weight_block_size is not None
+         assert self.quant_config.is_checkpoint_int8_serialized
+
+     def create_weights(
+         self,
+         layer: Module,
+         num_experts: int,
+         hidden_size: int,
+         intermediate_size: int,
+         params_dtype: torch.dtype,
+         **extra_weight_attrs,
+     ):
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+         if self.quant_config.is_checkpoint_int8_serialized:
+             params_dtype = torch.int8
+         tp_size = get_tensor_model_parallel_world_size()
+
+         block_n, block_k = (
+             self.quant_config.weight_block_size[0],
+             self.quant_config.weight_block_size[1],
+         )
+         # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
+         # Required by column parallel or enabling merged weights
+         if intermediate_size % block_n != 0:
+             raise ValueError(
+                 f"The output_size of gate's and up's weight = "
+                 f"{intermediate_size} is not divisible by "
+                 f"weight quantization block_n = {block_n}."
+             )
+         if tp_size > 1:
+             # Required by row parallel
+             if intermediate_size % block_k != 0:
+                 raise ValueError(
+                     f"The input_size of down's weight = "
+                     f"{intermediate_size} is not divisible by "
+                     f"weight quantization block_k = {block_k}."
+                 )
+
+         # WEIGHTS
+         w13_weight = torch.nn.Parameter(
+             torch.empty(
+                 num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_weight", w13_weight)
+         set_weight_attrs(w13_weight, extra_weight_attrs)
+
+         w2_weight = torch.nn.Parameter(
+             torch.empty(
+                 num_experts, hidden_size, intermediate_size, dtype=params_dtype
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w2_weight", w2_weight)
+         set_weight_attrs(w2_weight, extra_weight_attrs)
+
+         # WEIGHT_SCALES
+         w13_weight_scale = torch.nn.Parameter(
+             torch.ones(
+                 num_experts,
+                 2 * ((intermediate_size + block_n - 1) // block_n),
+                 (hidden_size + block_k - 1) // block_k,
+                 dtype=torch.float32,
+             ),
+             requires_grad=False,
+         )
+         w2_weight_scale = torch.nn.Parameter(
+             torch.ones(
+                 num_experts,
+                 (hidden_size + block_n - 1) // block_n,
+                 (intermediate_size + block_k - 1) // block_k,
+                 dtype=torch.float32,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
+         layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
+
+         extra_weight_attrs.update(
+             {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
+         )
+         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+         # INPUT_SCALES
+         assert self.quant_config.activation_scheme == "dynamic"
+         layer.w13_input_scale = None
+         layer.w2_input_scale = None
+
+     def process_weights_after_loading(self, layer: Module) -> None:
+         # Block quant doesn't need to process weights after loading
+         return
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         router_logits: torch.Tensor,
+         top_k: int,
+         renormalize: bool,
+         use_grouped_topk: bool,
+         topk_group: Optional[int] = None,
+         num_expert_group: Optional[int] = None,
+         custom_routing_function: Optional[Callable] = None,
+         correction_bias: Optional[torch.Tensor] = None,
+         activation: str = "silu",
+         inplace: bool = True,
+         no_combine: bool = False,
+     ) -> torch.Tensor:
+         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+         from sglang.srt.layers.moe.topk import select_experts
+
+         # Expert selection
+         topk_weights, topk_ids = select_experts(
+             hidden_states=x,
+             router_logits=router_logits,
+             use_grouped_topk=use_grouped_topk,
+             top_k=top_k,
+             renormalize=renormalize,
+             topk_group=topk_group,
+             num_expert_group=num_expert_group,
+             custom_routing_function=custom_routing_function,
+             correction_bias=correction_bias,
+         )
+
+         # Expert fusion with INT8 quantization
+         return fused_experts(
+             x,
+             layer.w13_weight,
+             layer.w2_weight,
+             topk_weights=topk_weights,
+             topk_ids=topk_ids,
+             inplace=inplace,
+             activation=activation,
+             use_int8_w8a8=True,
+             w1_scale=(layer.w13_weight_scale_inv),
+             w2_scale=(layer.w2_weight_scale_inv),
+             a1_scale=layer.w13_input_scale,
+             a2_scale=layer.w2_input_scale,
+             block_shape=self.quant_config.weight_block_size,
+             no_combine=no_combine,
+         )
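
Note on the new BlockInt8LinearMethod above: the weight is kept as plain int8 with one float32 scale per (block_n x block_k) tile, so the `weight_scale_inv` parameter has shape (ceil(N / block_n), ceil(K / block_k)). The sketch below (helper names are made up for illustration) shows that layout and the reference dequantization it implies; it does not reproduce `apply_w8a8_block_int8_linear`, which dispatches to the fused kernels added in int8_kernel.py / int8_utils.py.

import torch


def block_quant_int8(w: torch.Tensor, block_n: int = 128, block_k: int = 128):
    # Toy block-wise int8 quantization (illustrative only, not the sglang kernel).
    # w: float weight of shape (N, K). Returns (w_int8, scale) where scale has
    # shape (ceil(N / block_n), ceil(K / block_k)), the same layout as the
    # weight_scale_inv parameter created by BlockInt8LinearMethod above.
    n, k = w.shape
    n_blocks = (n + block_n - 1) // block_n
    k_blocks = (k + block_k - 1) // block_k
    w_int8 = torch.empty_like(w, dtype=torch.int8)
    scale = torch.empty(n_blocks, k_blocks, dtype=torch.float32)
    for i in range(n_blocks):
        for j in range(k_blocks):
            rows = slice(i * block_n, min((i + 1) * block_n, n))
            cols = slice(j * block_k, min((j + 1) * block_k, k))
            tile = w[rows, cols]
            s = tile.abs().max().clamp(min=1e-8) / 127.0
            scale[i, j] = s
            w_int8[rows, cols] = (tile / s).round().clamp(-127, 127).to(torch.int8)
    return w_int8, scale


def block_dequant(w_int8: torch.Tensor, scale: torch.Tensor, block_n: int = 128, block_k: int = 128):
    # Reference dequantization: broadcast each per-block scale over its tile.
    n, k = w_int8.shape
    scale_full = (
        scale.repeat_interleave(block_n, dim=0)[:n]
        .repeat_interleave(block_k, dim=1)[:, :k]
    )
    return w_int8.float() * scale_full


# Round-trip sanity check on a random weight.
w = torch.randn(512, 384)
w_q, s = block_quant_int8(w)
print((w - block_dequant(w_q, s)).abs().max())  # small quantization error
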
@@ -0,0 +1,146 @@
+ {
+     "1": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "2": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "4": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "8": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "16": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "24": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "32": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "48": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "64": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "96": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 8,
+         "num_stages": 3
+     },
+     "128": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "256": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 8,
+         "num_stages": 3
+     },
+     "512": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 8,
+         "num_stages": 5
+     },
+     "1024": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "1536": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "2048": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 8,
+         "num_stages": 4
+     },
+     "3072": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "4096": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     }
+ }
@@ -0,0 +1,146 @@
+ {
+     "1": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "2": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "4": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "8": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 8,
+         "num_stages": 5
+     },
+     "16": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "24": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "32": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "48": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "64": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "96": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "128": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "256": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "512": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "1024": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "1536": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "2048": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "3072": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "4096": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     }
+ }
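
Note on the JSON files above: they are tuning tables for the fused MoE Triton kernel. Each top-level key is a token-batch size M, and its value is the Triton launch configuration (tile sizes, group size, warp count, pipeline stages) that benchmarked fastest for that M on the device named in the filename. Below is a minimal sketch of how such a table can be consumed, assuming a nearest-key lookup; the actual selection logic lives in fused_moe_triton/fused_moe.py and may differ.

import json
from typing import Any, Dict


def load_moe_config(path: str) -> Dict[int, Dict[str, Any]]:
    # Read a tuning table like the ones above; top-level keys are batch sizes M.
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}


def pick_config(configs: Dict[int, Dict[str, Any]], m: int) -> Dict[str, Any]:
    # Pick the tuned entry whose key is closest to the runtime M.
    # (A common heuristic; the exact rule used by the kernel may differ.)
    return configs[min(configs, key=lambda key: abs(key - m))]


# Hypothetical usage with one of the shipped tables:
# cfg = pick_config(load_moe_config("E=256,N=128,...,block_shape=[128, 128].json"), m=200)
# cfg["BLOCK_SIZE_M"], cfg["num_warps"], ...
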