sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,394 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import re
+from fractions import Fraction
+from typing import Any, Optional, Union
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+from sglang.srt.layers.quantization.utils import get_scalar_types
+
+ScalarType, scalar_types = get_scalar_types()
+
+
+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+
+
+class AutoRoundConfig(QuantizationConfig):
+    """Config class for AutoRound.
+    Reference: https://arxiv.org/pdf/2309.05516
+    """
+
+    SUPPORTED_BITS = {2, 3, 4, 8}
+    SUPPORTED_DTYPES = {"int"}
+    SUPPORTED_FORMATS = {"auto_round:auto_gptq", "auto_round:auto_awq"}
+    SUPPORTED_BACKENDS = {"auto", "gptq", "gptq:marlin", "awq", "awq:marlin", "marlin"}
+
+    def __init__(
+        self,
+        weight_bits: int,
+        group_size: int,
+        sym: bool = True,
+        packing_format: str = "auto_round:auto_gptq",
+        block_name_to_quantize: Optional[Union[str, list[str]]] = None,
+        extra_config: Optional[dict[str, Any]] = None,
+        data_type: str = "int",
+        backend: str = "auto",
+    ) -> None:
+        super().__init__()
+        if weight_bits not in self.SUPPORTED_BITS:
+            raise ValueError(
+                f"Unsupported weight_bits: {weight_bits}, "
+                f"currently only support {self.SUPPORTED_BITS}"
+            )
+        if data_type not in self.SUPPORTED_DTYPES:
+            raise ValueError(
+                f"Unsupported data_type: {data_type},"
+                f" currently only support {self.SUPPORTED_DTYPES}"
+            )
+        if packing_format not in self.SUPPORTED_FORMATS:
+            raise ValueError(
+                f"Unsupported packing_format: {packing_format}, "
+                f"currently only support {self.SUPPORTED_FORMATS}"
+            )
+        if backend not in self.SUPPORTED_BACKENDS:
+            raise ValueError(
+                f"Unsupported backend: {backend}, "
+                f"currently only support {self.SUPPORTED_BACKENDS}"
+            )
+
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+        self.sym = sym
+        self.packing_format = packing_format
+        self.block_name_to_quantize = (
+            block_name_to_quantize.split(",")
+            if isinstance(block_name_to_quantize, str)
+            else block_name_to_quantize
+        )
+        self.extra_config = extra_config
+        self.data_type = data_type
+        self.backend = backend
+        self.pack_factor = Fraction(32, weight_bits)
+
+    def __repr__(self) -> str:
+        return (
+            f"AutoRoundConfig(weight_bits={self.weight_bits}, "
+            f"group_size={self.group_size}, sym={self.sym})"
+        )
+
+    @classmethod
+    def get_name(cls):
+        return "auto-round"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+        return [torch.half, torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 60
+
+    @classmethod
+    def get_config_filenames(cls) -> list[str]:
+        return ["quantization_config.json"]
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig":
+        return cls(
+            weight_bits=cls.get_from_keys(config, ["bits"]),
+            group_size=cls.get_from_keys(config, ["group_size"]),
+            sym=cls.get_from_keys(config, ["sym"]),
+            packing_format=cls.get_from_keys_or(
+                config,
+                ["packing_format"],
+                "auto_round:auto_gptq",
+            ),
+            block_name_to_quantize=cls.get_from_keys_or(
+                config, ["block_name_to_quantize", "to_quant_block_names"], None
+            ),
+            extra_config=cls.get_from_keys_or(config, ["extra_config"], None),
+            data_type=cls.get_from_keys_or(config, ["data_type"], "int"),
+            backend=cls.get_from_keys_or(
+                config, ["backend", "vllm_backend", "sglang_backend"], "auto"
+            ),
+        )
+
+    def get_scaled_act_names(self) -> list[str]:
+        """Returns the activation function names that should be post-scaled.
+
+        For now, this is only used by AWQ.
+        """
+        raise NotImplementedError
+
+    def get_layer_config(self, layer, layer_name: str):
+        from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+
+        def get_config(name: str, quantized: bool = True):
+            if not self.extra_config:
+                return (
+                    self.weight_bits if quantized else 16,
+                    self.group_size if quantized else -1,
+                    self.sym if quantized else True,
+                )
+
+            # Exact match first
+            if name in self.extra_config:
+                cfg = self.extra_config[name]
+                return (
+                    cfg.get("bits", self.weight_bits if quantized else 16),
+                    cfg.get("group_size", self.group_size if quantized else -1),
+                    cfg.get("sym", self.sym if quantized else True),
+                )
+
+            REGEX_SPECIAL_CHARS = set(r"*+?^$()[]{}|\\")
+            for pattern, cfg in self.extra_config.items():
+                if not isinstance(pattern, str) or not any(
+                    c in REGEX_SPECIAL_CHARS for c in pattern
+                ):
+                    continue
+
+                try:
+                    if re.fullmatch(pattern, name):
+                        return (
+                            cfg.get("bits", self.weight_bits if quantized else 16),
+                            cfg.get("group_size", self.group_size if quantized else -1),
+                            cfg.get("sym", self.sym if quantized else True),
+                        )
+                except re.error:
+                    # Invalid regex, ignore.
+                    continue
+
+            return (
+                self.weight_bits if quantized else 16,
+                self.group_size if quantized else -1,
+                self.sym if quantized else True,
+            )
+
+        # 1. Exact match from config
+        if self.extra_config and layer_name in self.extra_config:
+            return get_config(layer_name)
+
+        # 2. Determine whether layer should be quantized
+        quantized = not isinstance(layer, ParallelLMHead)
+        if self.block_name_to_quantize:
+            quantized = any(
+                layer_name.startswith(name) for name in self.block_name_to_quantize
+            )
+
+        # 3. Handle fused MoE
+        if self.extra_config and "fusedmoe" in layer.__class__.__name__.lower():
+            moe_configs = [
+                get_config(name, quantized)
+                for name in self.extra_config
+                if name.startswith(layer_name)
+            ]
+            if moe_configs:
+                if len(set(moe_configs)) == 1:
+                    return moe_configs[0]
+                raise ValueError(
+                    f"Fused MoE layer '{layer_name}' requires "
+                    f"consistent quant config for all sub-layers"
+                )
+
+        # 4. Handle fused QKV or other patterns
+        if self.extra_config:
+            for fusion_key, sub_keys in self.packed_modules_mapping.items():
+                if fusion_key in layer_name and layer_name.count(fusion_key) == 1:
+                    sub_names = [
+                        layer_name.replace(fusion_key, sub_key) for sub_key in sub_keys
+                    ]
+                    sub_configs = [get_config(name, quantized) for name in sub_names]
+                    if len(set(sub_configs)) == 1:
+                        return sub_configs[0]
+                    raise ValueError(
+                        f"Fused module '{layer_name}' requires "
+                        f"consistent quant config for {sub_names}"
+                    )
+
+        # 5. Fallback or try a regular expression match
+        return get_config(layer_name, quantized)
+
+    def check_quantized(self, weight_bits: int) -> bool:
+        return weight_bits < 16
+
+    def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.quantization.marlin_utils import (
+            check_marlin_supported,
+            check_moe_marlin_supports_layer,
+        )
+        from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+
+        weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
+        if not self.check_quantized(weight_bits):
+            if isinstance(layer, (LinearBase, ParallelLMHead)):
+                return UnquantizedLinearMethod()
+            else:
+                return None
+        logger.debug(
+            "[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
+            prefix,
+            layer.__class__.__name__,
+            weight_bits,
+            group_size,
+            sym,
+        )
+        if backend == "auto" or "marlin" in backend:
+            AWQ_TYPE_MAP = {
+                4: scalar_types.uint4,
+                8: scalar_types.uint8,
+            }
+            use_marlin = (weight_bits in AWQ_TYPE_MAP) and check_marlin_supported(
+                AWQ_TYPE_MAP[weight_bits], group_size, not sym
+            )
+            if isinstance(layer, FusedMoE):
+                use_marlin = use_marlin and check_moe_marlin_supports_layer(
+                    layer, group_size
+                )
+
+        else:
+            use_marlin = False
+        if use_marlin:
+            from sglang.srt.layers.quantization.awq import (
+                AWQMarlinConfig,
+                AWQMarlinLinearMethod,
+                AWQMoEMethod,
+            )
+
+            quant_args_marlin = AWQMarlinConfig(
+                weight_bits=weight_bits,
+                group_size=group_size,
+                zero_point=not sym,
+                lm_head_quantized=False,
+                full_config={},
+                modules_to_not_convert=[],
+            )
+        else:
+            from sglang.srt.layers.quantization.awq import AWQConfig, AWQLinearMethod
+
+            quant_args = AWQConfig(
+                weight_bits=weight_bits,
+                group_size=group_size,
+                zero_point=not sym,
+            )
+
+        if isinstance(layer, FusedMoE):
+            if use_marlin:
+                return AWQMoEMethod(quant_args_marlin)
+            from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
+
+            config = {
+                "quant_method": "awq",
+                "bits": weight_bits,
+                "group_size": group_size,
+                "zero_point": not sym,
+                "lm_head": False,
+            }
+            return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix)
+
+        if isinstance(layer, (LinearBase, ParallelLMHead)):
+            if use_marlin:
+                return AWQMarlinLinearMethod(quant_args_marlin)
+            else:
+                return AWQLinearMethod(quant_args)
+        return None
+
+    def apply_gptq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.quantization.marlin_utils import (
+            check_marlin_supported,
+            check_moe_marlin_supports_layer,
+        )
+        from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+
+        weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
+        if not self.check_quantized(weight_bits):
+            if isinstance(layer, (LinearBase, ParallelLMHead)):
+                return UnquantizedLinearMethod()
+            else:
+                return None
+
+        logger.debug(
+            "[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
+            prefix,
+            layer.__class__.__name__,
+            weight_bits,
+            group_size,
+            sym,
+        )
+        if backend == "auto" or "marlin" in backend:
+            GPTQ_TYPE_MAP = {
+                (4, True): scalar_types.uint4b8,
+                (8, True): scalar_types.uint8b128,
+            }
+            use_marlin = (weight_bits, sym) in GPTQ_TYPE_MAP and check_marlin_supported(
+                GPTQ_TYPE_MAP[(weight_bits, sym)], group_size, has_zp=not sym
+            )
+            if isinstance(layer, FusedMoE):
+                use_marlin = use_marlin and check_moe_marlin_supports_layer(
+                    layer, group_size
+                )
+        else:
+            use_marlin = False
+        if use_marlin:
+            from sglang.srt.layers.quantization.gptq import (
+                GPTQMarlinConfig,
+                GPTQMarlinLinearMethod,
+                GPTQMarlinMoEMethod,
+            )
+
+            quant_args_marlin = GPTQMarlinConfig(
+                weight_bits=weight_bits,
+                group_size=group_size,
+                is_sym=sym,
+                lm_head_quantized=False,
+                desc_act=False,
+                dynamic={},
+                full_config={},
+            )
+        else:
+            from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQLinearMethod
+
+            quant_args = GPTQConfig(
+                weight_bits=weight_bits,
+                group_size=group_size,
+                lm_head_quantized=False,
+                desc_act=False,
+                dynamic={},
+            )
+
+        if isinstance(layer, FusedMoE):
+            if use_marlin:
+                from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
+
+                config = {
+                    "quant_method": "gptq",
+                    "bits": weight_bits,
+                    "group_size": group_size,
+                    "sym": sym,
+                    "lm_head": False,
+                }
+                return MoeWNA16Config.from_config(config).get_quant_method(
+                    layer, prefix
+                )
+            return GPTQMarlinMoEMethod(quant_args_marlin)
+
+        if isinstance(layer, (LinearBase, ParallelLMHead)):
+            if use_marlin:
+                return GPTQMarlinLinearMethod(quant_args_marlin)
+            else:
+                return GPTQLinearMethod(quant_args)
+
+        return None
+
+    def get_quant_method(self, layer: torch.nn.Module, prefix: str):
+        # TODO enable CPU quant method later
+        if "gptq" in self.packing_format or "gptq" in self.backend:
+            return self.apply_gptq_quant_layer(layer, prefix)
+        if "awq" in self.packing_format or "awq" in self.backend:
+            return self.apply_awq_quant_layer(layer, prefix)
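For reference, the new AutoRoundConfig above is normally constructed from a checkpoint's quantization_config through from_config, and get_quant_method then routes each layer to a GPTQ- or AWQ-style kernel (Marlin where supported). A minimal usage sketch; the dictionary values below are illustrative, not taken from any particular checkpoint:

    from sglang.srt.layers.quantization.auto_round import AutoRoundConfig

    # Hypothetical quantization_config, as it might appear in a model's config.json.
    hf_quant_config = {
        "bits": 4,
        "group_size": 128,
        "sym": True,
        "packing_format": "auto_round:auto_gptq",
        "backend": "auto",
    }

    quant_config = AutoRoundConfig.from_config(hf_quant_config)
    print(quant_config)  # AutoRoundConfig(weight_bits=4, group_size=128, sym=True)
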
@@ -840,12 +840,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
             self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
-        # The input must currently be float16
         x = dispatch_output.hidden_states
         topk_output = dispatch_output.topk_output
-
         orig_dtype = x.dtype
-        x = x.half()
 
         topk_weights, topk_ids, router_logits = topk_output
 
@@ -179,6 +179,13 @@ class QuantizationConfig(ABC):
         elif "NVFP4" in quant_algo or "FP4" in quant_algo:
             return "modelopt_fp4"
 
+        # The hf_quant_config may be a parsed quant config, so we need to check the
+        # quant_method.
+        if hf_quant_config.get("quant_method", "") == "modelopt_fp8":
+            return "modelopt_fp8"
+        elif hf_quant_config.get("quant_method", "") == "modelopt_fp4":
+            return "modelopt_fp4"
+
         return None
 
     @staticmethod
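The added branch covers the case where hf_quant_config is already a parsed quant config rather than a raw export whose quant_algo string is inspected by the earlier branches; in that shape the method name is stated directly. Hypothetical inputs, grounded only in the check shown above:

    hf_quant_config = {"quant_method": "modelopt_fp8"}  # now resolved to "modelopt_fp8"
    hf_quant_config = {"quant_method": "modelopt_fp4"}  # now resolved to "modelopt_fp4"
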
@@ -33,6 +33,7 @@ from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
 from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo
 from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
+from sglang.srt.layers.moe.utils import get_moe_runner_backend
 from sglang.srt.layers.parameter import (
     BlockQuantScaleParameter,
     ModelWeightParameter,
@@ -525,12 +526,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None
         self.cutlass_fp8_supported = cutlass_fp8_supported()
-        self.use_cutlass_fused_experts_fp8 = (
-            get_bool_env_var("SGLANG_CUTLASS_MOE")
-            and self.cutlass_fp8_supported
-            and self.block_quant
-            and (is_sm100_supported() or is_sm90_supported())
-        )
 
     def create_weights(
         self,
@@ -638,58 +633,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
             layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
             assert self.quant_config.activation_scheme == "dynamic"
-            if self.use_cutlass_fused_experts_fp8:
-                self.ab_strides1 = torch.full(
-                    (num_experts,),
-                    hidden_size,
-                    device=w13_weight.device,
-                    dtype=torch.int64,
-                )
-                self.c_strides1 = torch.full(
-                    (num_experts,),
-                    2 * intermediate_size_per_partition,
-                    device=w13_weight.device,
-                    dtype=torch.int64,
-                )
-                self.ab_strides2 = torch.full(
-                    (num_experts,),
-                    intermediate_size_per_partition,
-                    device=w2_weight.device,
-                    dtype=torch.int64,
-                )
-                self.c_strides2 = torch.full(
-                    (num_experts,),
-                    hidden_size,
-                    device=w2_weight.device,
-                    dtype=torch.int64,
-                )
-                self.workspace = torch.empty(
-                    90000, device=w13_weight.device, dtype=torch.uint8
-                )
-                self.a_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.b_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.out_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.a_scales_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.b_scales_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.expert_offsets = torch.empty(
-                    num_experts + 1, device=w13_weight.device, dtype=torch.int32
-                )
-                self.problem_sizes1 = torch.empty(
-                    num_experts, 3, device=w13_weight.device, dtype=torch.int32
-                )
-                self.problem_sizes2 = torch.empty(
-                    num_experts, 3, device=w13_weight.device, dtype=torch.int32
-                )
+            if self._should_use_cutlass_fused_experts():
+                self._ensure_cutlass_buffers_initialized(layer)
 
         else:
             # Allocate 2 scales for w1 and w3 respectively.
@@ -1039,13 +984,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
 
         x = dispatch_output.hidden_states
-        topk_output = dispatch_output.topk_output
         moe_runner_config = self.moe_runner_config
 
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
 
-            topk_weights, topk_ids, _ = topk_output
+            topk_weights, topk_ids, _ = dispatch_output.topk_output
             x, topk_weights = apply_topk_weights_cpu(
                 moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
@@ -1072,17 +1016,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         ret = self.maybe_apply_hip_fused_experts(
             layer,
             x,
-            topk_output,
+            dispatch_output.topk_output,
             moe_runner_config.activation,
             moe_runner_config.no_combine,
         )
         if ret is not None:
             return StandardCombineInput(hidden_states=ret)
 
-        if self.use_cutlass_fused_experts_fp8:
+        if self._should_use_cutlass_fused_experts():
             from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
 
-            topk_weights, topk_ids, _ = topk_output
+            topk_weights, topk_ids, _ = dispatch_output.topk_output
             output = cutlass_fused_experts_fp8(
                 x,
                 layer.w13_weight.transpose(1, 2),
@@ -1171,6 +1115,67 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
         return self.runner.run(dispatch_output, quant_info)
 
+    def _should_use_cutlass_fused_experts(self) -> bool:
+        """Decide whether to use Cutlass FP8 fused-experts path based on moe runner backend,
+        with env var override via `SGLANG_CUTLASS_MOE`.
+        """
+        backend = get_moe_runner_backend()
+        env_force = get_bool_env_var("SGLANG_CUTLASS_MOE")
+        # TODO: remove env var in the future, it should be handled by moe runner backend
+        if env_force:
+            return True
+        return (
+            backend.is_flashinfer_cutlass()
+            and self.cutlass_fp8_supported
+            and self.block_quant
+            and (is_sm100_supported() or is_sm90_supported())
+        )
+
+    def _ensure_cutlass_buffers_initialized(self, layer: Module) -> None:
+        if getattr(self, "_cutlass_buffers_ready", False):
+            return
+
+        device = layer.w13_weight.device
+        num_experts = layer.w13_weight.shape[0]
+        hidden_size = layer.w2_weight.shape[1]
+        intermediate_size_per_partition = layer.intermediate_size_per_partition
+
+        self.ab_strides1 = torch.full(
+            (num_experts,), hidden_size, device=device, dtype=torch.int64
+        )
+        self.c_strides1 = torch.full(
+            (num_experts,),
+            2 * intermediate_size_per_partition,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.ab_strides2 = torch.full(
+            (num_experts,),
+            intermediate_size_per_partition,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.c_strides2 = torch.full(
+            (num_experts,), hidden_size, device=device, dtype=torch.int64
+        )
+        self.workspace = torch.empty(90000, device=device, dtype=torch.uint8)
+        self.a_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.b_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.out_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.a_scales_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.b_scales_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.expert_offsets = torch.empty(
+            num_experts + 1, device=device, dtype=torch.int32
+        )
+        self.problem_sizes1 = torch.empty(
+            num_experts, 3, device=device, dtype=torch.int32
+        )
+        self.problem_sizes2 = torch.empty(
+            num_experts, 3, device=device, dtype=torch.int32
+        )
+
+        self._cutlass_buffers_ready = True
+
     def apply_with_router_logits(
         self,
         layer: torch.nn.Module,
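The two helpers above replace the eager allocation that used to live in create_weights: the decision moves to call time (_should_use_cutlass_fused_experts, with SGLANG_CUTLASS_MOE kept as a temporary override), and buffer allocation becomes lazy and idempotent behind a flag. A generic sketch of that guard pattern, using illustrative names rather than the sglang class itself:

    import torch


    class LazyBuffers:
        def _ensure_buffers(self, num_experts: int, device: torch.device) -> None:
            # Allocate once; later calls return immediately.
            if getattr(self, "_buffers_ready", False):
                return
            self.expert_offsets = torch.empty(
                num_experts + 1, device=device, dtype=torch.int32
            )
            self._buffers_ready = True
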
@@ -459,7 +459,7 @@ def create_per_token_group_quant_fp8_output_scale(
             x_shape[:-2] + (x_shape[-1] // group_size, aligned_size),
             device=device,
             dtype=torch.float32,
-        ).permute(-1, -2)[: x_shape[-2], :]
+        ).transpose(-1, -2)[: x_shape[-2], :]
     else:
         return torch.empty(
             (x_shape[-1] // group_size,) + x_shape[:-1],
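The permute to transpose change matters for inputs with more than two dimensions: Tensor.permute expects a full permutation of all dims, so permute(-1, -2) fails on a 3-D scale tensor, while transpose(-1, -2) swaps just the last two dims. A standalone check in plain PyTorch, unrelated to sglang internals:

    import torch

    t = torch.empty(2, 3, 4)
    print(t.transpose(-1, -2).shape)  # torch.Size([2, 4, 3])
    # t.permute(-1, -2)               # raises RuntimeError (dims mismatch on a 3-D tensor)
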
@@ -5,7 +5,7 @@ import torch
 from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
-from sglang.srt.utils import ceil_div, is_sm100_supported, offloader
+from sglang.srt.utils import ceil_div, is_blackwell_supported, offloader
 
 try:
     from vllm import _custom_ops as ops
@@ -129,7 +129,7 @@ def cutlass_block_fp8_supported() -> bool:
 CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
 ENABLE_FLASHINFER_GEMM = (
     get_bool_env_var("SGLANG_ENABLE_FLASHINFER_GEMM")
-    and is_sm100_supported()
+    and is_blackwell_supported()
     and is_flashinfer_available()
 )
 if ENABLE_FLASHINFER_GEMM: