sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +376 -48
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0

sglang/srt/layers/quantization/quark/quark.py (new file)
@@ -0,0 +1,390 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import fnmatch
+import logging
+from typing import Any, List, Optional, cast
+
+import torch
+
+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+from sglang.srt.layers.quantization.base_config import (  # noqa: E501
+    LinearMethodBase,
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+from sglang.srt.layers.quantization.quark.quark_moe import QuarkMoEMethod
+from sglang.srt.layers.quantization.quark.schemes import QuarkScheme, QuarkW4A4MXFP4
+from sglang.srt.layers.quantization.quark.utils import deep_compare, should_ignore_layer
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.utils import get_device_capability
+
+__all__ = ["QuarkLinearMethod"]
+
+logger = logging.getLogger(__name__)
+
+
+class QuarkConfig(QuantizationConfig):
+
+    def __init__(
+        self,
+        quant_config: dict[str, Any],
+        kv_cache_group: Optional[list[str]] = None,
+        kv_cache_config: Optional[dict[str, Any]] = None,
+        pack_method: str = "reorder",
+    ):
+        super().__init__()
+        if kv_cache_group is None:
+            kv_cache_group = []
+        self.quant_config = quant_config
+        self.kv_cache_group = kv_cache_group
+        self.kv_cache_config = kv_cache_config
+        self.pack_method = pack_method
+
+        self.packed_modules_mapping = self.quant_config["packed_modules_mapping"]
+
+    def get_linear_method(self) -> "QuarkLinearMethod":
+        return QuarkLinearMethod(self)
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 70
+
+    def get_name(self) -> str:
+        return "quark"
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        # Check if the layer is skipped for quantization.
+        exclude_layers = cast(list[str], self.quant_config.get("exclude"))
+        if should_ignore_layer(
+            prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
+        ):
+            return UnquantizedLinearMethod()
+
+        if isinstance(layer, LinearBase):
+            scheme = self.get_scheme(layer=layer, layer_name=prefix)
+            layer.scheme = scheme
+            return QuarkLinearMethod(self)
+
+        if isinstance(layer, RadixAttention):
+            return QuarkKVCacheMethod(self)
+
+        from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+
+        if isinstance(layer, FusedMoE):
+            return QuarkMoEMethod.get_moe_method(self, module=layer, layer_name=prefix)
+
+        return None
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
+        export_config = config.get("export")
+        if export_config is None:
+            raise ValueError(
+                "The export key should be included in "
+                "the configurations of Quark quantized model"
+            )
+
+        kv_cache_group = cast(list[str], export_config.get("kv_cache_group"))
+        pack_method = cast(str, export_config.get("pack_method"))
+
+        # In the export model of quark, the quantization configuration
+        # of kv_cache is stored in layer_quant_config. First, it is
+        # judged whether kv_cache_group exists, and then it is judged
+        # whether layer_quant_config has a quantization configuration
+        # that matches kv_cache.
+        if len(kv_cache_group) == 0:
+            kv_cache_config = None
+        else:
+            kv_cache_set = set(kv_cache_group)
+            layer_quant_config = cast(dict[str, Any], config.get("layer_quant_config"))
+            layer_quant_names = list(layer_quant_config.keys())
+            layer_quant_set = set(layer_quant_names)
+
+            if not kv_cache_set.issubset(layer_quant_set):
+                raise ValueError(
+                    "The Quark quantized model has the "
+                    "kv_cache_group parameter setting, "
+                    "but no kv_cache quantization settings "
+                    "were found in the quantization "
+                    "configuration."
+                )
+
+            q_configs = [
+                cast(dict[str, Any], layer_quant_config.get(name))
+                for name in kv_cache_group
+            ]
+            if not all(deep_compare(q_config, q_configs[0]) for q_config in q_configs):
+                raise ValueError(
+                    "The quantization method used for kv_cache should "
+                    "be the same, but the quantization method for the "
+                    "kv_cache layer in the config is different."
+                )
+            kv_cache_config = q_configs[0].get("output_tensors")
+            if kv_cache_config is None:
+                raise ValueError("The kv_cache quantization configuration is empty.")
+
+            # Since we have already set kv_cache quantization configurations,
+            # we will remove the quantization configuration for the
+            # output_tensors corresponding to the kv_cache layer.
+            for q_config in q_configs:
+                q_config["output_tensors"] = None
+
+            # In case q_proj output is also quantized, remove the configuration
+            # to keep qkv consistency.
+            q_proj_q_config = cast(dict[str, Any], layer_quant_config.get("*q_proj"))
+            if q_proj_q_config is not None:
+                q_proj_q_config["output_tensors"] = None
+
+        return cls(
+            quant_config=config,
+            kv_cache_group=kv_cache_group,
+            kv_cache_config=kv_cache_config,
+            pack_method=pack_method,
+        )
+
+    @classmethod
+    def get_config_filenames(cls) -> list[str]:
+        return []
+
+    def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool:
+        capability_tuple = get_device_capability()
+
+        if capability_tuple is not None:
+            assert 0 <= capability_tuple[1] < 10
+            capability = capability_tuple[0] * 10 + capability_tuple[1]
+
+            supported = capability >= min_capability
+            if error and not supported:
+                raise RuntimeError(
+                    "Quantization scheme is not supported for ",
+                    f"the current GPU. Min capability: {min_capability}. ",
+                    f"Current capability: {capability}.",
+                )
+            return supported
+        else:
+            return False
+
+    def _is_mx_fp4(
+        self,
+        weight_quant: Optional[dict[str, Any]],
+        input_quant: Optional[dict[str, Any]],
+    ) -> bool:
+        # Confirm weights and input quantized.
+        if weight_quant is None or input_quant is None:
+            logger.debug(
+                "Quark model is not in MX-FP4 format: "
+                "weight_quant or input_quant not set"
+            )
+            return False
+
+        # Input and weight dtype needs to be fp4.
+        if weight_quant.get("dtype") != "fp4" or input_quant.get("dtype") != "fp4":
+            logger.debug("Quark model is not in MX-FP4 format: dtype not fp4")
+            return False
+
+        # Input and weight qscheme needs to be per group.
+        if (
+            weight_quant.get("qscheme") != "per_group"
+            or input_quant.get("qscheme") != "per_group"
+        ):
+            logger.debug("Quark model is not in MX-FP4 format: not per_group")
+            return False
+
+        # Input and weight group size needs to be 32.
+        if weight_quant.get("group_size") != 32 or input_quant.get("group_size") != 32:
+            logger.debug("Quark model is not in MX-FP4 format: not group_size=32")
+            return False
+
+        # Weights need to use static quantization.
+        if weight_quant.get("is_dynamic") is True:
+            logger.debug("Quark model is not in MX-FP4 format: not weight static")
+            return False
+
+        # Activations need to use dynamic quantization.
+        if input_quant.get("is_dynamic") is False:
+            logger.debug("Quark model is not in MX-FP4 format: not activation dynamic")
+            return False
+
+        # Activations and weight scales need to be in e8m0 format.
+        if (
+            weight_quant.get("scale_format") != "e8m0"
+            or input_quant.get("scale_format") != "e8m0"
+        ):
+            logger.debug("Quark model is not in MX-FP4 format: not scale_format e8m0")
+            return False
+
+        return True
+
+    def _find_matched_config(
+        self, layer_name: str, module: torch.nn.Module
+    ) -> dict[str, Any]:
+
+        proj_name = layer_name.split(".")[-1]
+        if proj_name in self.packed_modules_mapping:
+            shard_proj_names = self.packed_modules_mapping[proj_name]
+
+            # Convert fused_name --> [shard_names]
+            shard_names = [
+                layer_name.replace(proj_name, shard_proj_name)
+                for shard_proj_name in shard_proj_names
+            ]
+            shard_configs = [
+                self._find_matched_config(shard_name, module)
+                for shard_name in shard_names
+            ]
+            if not all(
+                deep_compare(q_config, shard_configs[0]) for q_config in shard_configs
+            ):
+                raise ValueError(
+                    f"Found a different quantization configuration for "
+                    f"{shard_proj_names} in {layer_name}. vLLM "
+                    "requires all to use the same scheme."
+                )
+            return shard_configs[0]
+        else:
+            layer_quant_config = cast(
+                dict[str, Any], self.quant_config.get("layer_quant_config")
+            )
+            for name_pattern in layer_quant_config:
+                if fnmatch.fnmatch(layer_name, name_pattern):
+                    return layer_quant_config[name_pattern]
+
+            layer_type = type(module).__name__
+            layer_type_quant_config = cast(
+                dict[str, Any], self.quant_config.get("layer_type_quant_config")
+            )
+            if layer_type in layer_type_quant_config:
+                return layer_type_quant_config[layer_type]
+
+            global_quant_config = cast(
+                dict[str, Any], self.quant_config.get("global_quant_config")
+            )
+            return global_quant_config
+
+    def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
+        if config.get("output_tensors") or config.get("bias"):
+            raise NotImplementedError(
+                "Currently, Quark models with output_tensors "
+                "and bias quantized are not supported"
+            )
+        weight_config = cast(dict[str, Any], config.get("weight"))
+        input_config = cast(dict[str, Any], config.get("input_tensors"))
+
+        if self._is_mx_fp4(weight_config, input_config):
+            return QuarkW4A4MXFP4(weight_config, input_config)
+
+        raise NotImplementedError(
+            "No quark compatible scheme was found. "
+            f"Weight config: {weight_config}, "
+            f"Input config: {input_config}"
+        )
+
+    def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme":
+
+        layer_quant_config = self._find_matched_config(layer_name, layer)
+
+        # Find the quant_scheme
+        scheme = self._get_scheme_from_config(layer_quant_config)
+
+        # Raise error if device does not support the scheme
+        # (e.g. fp8 needs ada lovelace)
+        self._check_scheme_supported(scheme.get_min_capability())
+
+        return scheme
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class QuarkLinearMethod(LinearMethodBase):
+
+    def __init__(self, quantization_config: QuarkConfig):
+        self.quantization_config = quantization_config
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.scheme.process_weights_after_loading(layer)
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: list[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        """
+        Use the CompressedTensorsScheme associated with each layer to create
+        the necessary parameters for the layer. See LinearMethodBase for param
+        details
+        """
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        layer.scheme.create_weights(
+            layer=layer,
+            input_size=input_size,
+            input_size_per_partition=input_size_per_partition,
+            output_partition_sizes=output_partition_sizes,
+            output_size=output_size,
+            params_dtype=params_dtype,
+            weight_loader=weight_loader,
+        )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ):
+        """
+        Use the output of create_weights and the CompressedTensorsScheme
+        associated with the layer to apply the forward pass with the
+        layer input. See LinearMethodBase for param details
+
+        """
+        scheme = layer.scheme
+        if scheme is None:
+            raise ValueError("A scheme must be defined for each layer")
+        return scheme.apply_weights(layer, x, bias=bias)
+
+
+class QuarkKVCacheMethod(BaseKVCacheMethod):
+    """
+    Supports loading kv-cache scaling factors from quark checkpoints.
+    """
+
+    def __init__(self, quant_config: QuarkConfig):
+        self.validate_kv_cache_config(quant_config.kv_cache_config)
+        super().__init__(quant_config)
+
+    @staticmethod
+    def validate_kv_cache_config(kv_cache_config: Optional[dict[str, Any]]):
+        """
+        Validator for the kv cache configuration. Useful for controlling the
+        kv cache quantization schemes, that are being supported in vLLM
+        :param kv_cache_config: the quark kv cache scheme
+        """
+        if kv_cache_config is None:
+            return
+
+        dtype = kv_cache_config.get("dtype")
+        if dtype != "fp8_e4m3":
+            raise NotImplementedError(
+                "Currently supported kv cache quantization is "
+                f"dtype=fp8_e4m3, however received {dtype}"
+            )
+
+        qscheme = kv_cache_config.get("qscheme")
+        if qscheme != "per_tensor":
+            raise NotImplementedError(
+                "Only support per-tensor scaling factor "
+                "for quark KV cache. "
+                f"Expected qscheme: per_tensor, found qscheme: {qscheme}"
+            )

sglang/srt/layers/quantization/quark/quark_moe.py (new file)
@@ -0,0 +1,197 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any, Callable, Optional
+
+import torch
+from aiter import ActivationType, QuantType, biased_grouped_topk
+from aiter.fused_moe import fused_moe
+from aiter.utility.fp4_utils import e8m0_shuffle
+
+from sglang.srt.utils import get_bool_env_var, mxfp_supported, set_weight_attrs
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
+
+OCP_MX_BLOCK_SIZE = 32
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
+
+class QuarkMoEMethod:
+    def __new__(cls, *args, **kwargs):
+        from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    @staticmethod
+    def get_moe_method(
+        quant_config: "QuarkConfig",  # type: ignore # noqa E501 # noqa F821
+        module: torch.nn.Module,
+        layer_name: str,
+    ) -> "QuarkMoEMethod":
+        layer_quant_config = quant_config._find_matched_config(layer_name, module)
+
+        if layer_quant_config.get("output_tensors") or layer_quant_config.get("bias"):
+            raise NotImplementedError(
+                "Currently, Quark models with "
+                "output_tensors and bias "
+                "quantized are not supported"
+            )
+        weight_config = layer_quant_config.get("weight")
+        input_config = layer_quant_config.get("input_tensors")
+
+        if quant_config._is_mx_fp4(weight_config, input_config):
+            return QuarkW4A4MXFp4MoEMethod(weight_config, input_config)
+        else:
+            raise RuntimeError("Unsupported FusedMoe scheme")
+
+
+class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
+
+    def __init__(self, weight_config: dict[str, Any], input_config: dict[str, Any]):
+        self.weight_quant = weight_config
+        self.input_quant = input_config
+
+        weight_qscheme = self.weight_quant.get("qscheme")
+        input_qscheme = self.input_quant.get("qscheme")
+        if not (weight_qscheme == "per_group" and input_qscheme == "per_group"):
+            raise ValueError(
+                "For MX(FP4) Fused MoE layers, only per-group scales "
+                "for weights and activations are supported. Found "
+                f"{weight_qscheme}, {input_qscheme}"
+            )  # noqa E501
+
+        self.static_input_scales = not self.input_quant.get("is_dynamic")
+        self.with_bias = False
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        # Add the quantization method used (per tensor/grouped/channel)
+        # to ensure the weight scales are loaded in properly
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
+        )
+
+        params_dtype = torch.uint8
+
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size // 2,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition // 2,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # WEIGHT_SCALES
+        w13_weight_scale = torch.nn.Parameter(
+            torch.ones(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size // OCP_MX_BLOCK_SIZE,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        w2_weight_scale = torch.nn.Parameter(
+            torch.ones(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition // OCP_MX_BLOCK_SIZE,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        float_dtype = torch.get_default_dtype()
+
+        # Pre-shuffle weight scales
+        s0, s1, _ = layer.w13_weight_scale.shape
+        w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1)
+        w13_weight_scale = e8m0_shuffle(w13_weight_scale)
+        # layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale, requires_grad=False)
+        layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1)
+
+        s0, s1, _ = layer.w2_weight_scale.shape
+        w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
+        w2_weight_scale = e8m0_shuffle(w2_weight_scale)
+        # layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, requires_grad=False)
+        layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        topk_output: TopKOutput,
+        moe_runner_config: MoeRunnerConfig,
+    ) -> torch.Tensor:
+        topk_weights, topk_ids, _ = topk_output
+
+        return fused_moe(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights,
+            topk_ids,
+            quant_type=QuantType.per_1x32,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            activation=(
+                ActivationType.Silu
+                if moe_runner_config.activation == "silu"
+                else ActivationType.Gelu
+            ),
+            doweight_stage1=False,
+        )
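
For reference, a minimal sketch (not part of the released diff) of the only kv-cache scheme that QuarkKVCacheMethod.validate_kv_cache_config in quark.py above accepts; the dict keys and values come from the checks in that method, while the variable name and example calls are illustrative only.

    # Hypothetical usage sketch: per-tensor fp8_e4m3 is the only accepted kv-cache scheme.
    valid_kv_cache_config = {"dtype": "fp8_e4m3", "qscheme": "per_tensor"}
    QuarkKVCacheMethod.validate_kv_cache_config(valid_kv_cache_config)  # returns without error
    QuarkKVCacheMethod.validate_kv_cache_config(
        {"dtype": "fp8_e4m3", "qscheme": "per_channel"}
    )  # raises NotImplementedError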