sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. sglang/bench_one_batch.py +0 -7
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +25 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -2
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +29 -4
  24. sglang/srt/entrypoints/http_server.py +76 -0
  25. sglang/srt/entrypoints/openai/protocol.py +4 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +23 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +10 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +14 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
  37. sglang/srt/layers/attention/triton_backend.py +109 -73
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +58 -10
  46. sglang/srt/layers/dp_attention.py +137 -27
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +16 -18
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  71. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  72. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  73. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  75. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  76. sglang/srt/layers/moe/router.py +15 -9
  77. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  78. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  79. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  80. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  81. sglang/srt/layers/moe/topk.py +167 -83
  82. sglang/srt/layers/moe/utils.py +159 -18
  83. sglang/srt/layers/multimodal.py +156 -40
  84. sglang/srt/layers/quantization/__init__.py +18 -46
  85. sglang/srt/layers/quantization/awq.py +22 -23
  86. sglang/srt/layers/quantization/base_config.py +2 -6
  87. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  88. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
  89. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  90. sglang/srt/layers/quantization/fp8.py +127 -119
  91. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  92. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  93. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  94. sglang/srt/layers/quantization/gptq.py +17 -21
  95. sglang/srt/layers/quantization/marlin_utils.py +26 -8
  96. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  97. sglang/srt/layers/quantization/modelopt_quant.py +217 -98
  98. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  99. sglang/srt/layers/quantization/mxfp4.py +222 -39
  100. sglang/srt/layers/quantization/quark/quark.py +390 -0
  101. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  102. sglang/srt/layers/quantization/unquant.py +34 -70
  103. sglang/srt/layers/quantization/utils.py +77 -2
  104. sglang/srt/layers/quantization/w4afp8.py +7 -8
  105. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  106. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  107. sglang/srt/layers/radix_attention.py +6 -0
  108. sglang/srt/layers/rotary_embedding.py +1 -0
  109. sglang/srt/layers/sampler.py +5 -2
  110. sglang/srt/lora/layers.py +6 -2
  111. sglang/srt/lora/lora_manager.py +21 -22
  112. sglang/srt/lora/lora_registry.py +3 -3
  113. sglang/srt/lora/mem_pool.py +26 -24
  114. sglang/srt/lora/utils.py +10 -12
  115. sglang/srt/managers/cache_controller.py +80 -19
  116. sglang/srt/managers/detokenizer_manager.py +10 -2
  117. sglang/srt/managers/io_struct.py +23 -0
  118. sglang/srt/managers/mm_utils.py +1 -1
  119. sglang/srt/managers/schedule_batch.py +22 -48
  120. sglang/srt/managers/scheduler.py +28 -20
  121. sglang/srt/managers/session_controller.py +1 -1
  122. sglang/srt/managers/template_manager.py +7 -5
  123. sglang/srt/managers/tokenizer_manager.py +88 -39
  124. sglang/srt/managers/tp_worker.py +1 -0
  125. sglang/srt/managers/utils.py +59 -1
  126. sglang/srt/mem_cache/allocator.py +10 -157
  127. sglang/srt/mem_cache/allocator_ascend.py +147 -0
  128. sglang/srt/mem_cache/chunk_cache.py +1 -1
  129. sglang/srt/mem_cache/hicache_storage.py +14 -4
  130. sglang/srt/mem_cache/memory_pool.py +3 -3
  131. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  132. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  133. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  134. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  135. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  136. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  137. sglang/srt/model_executor/cuda_graph_runner.py +33 -33
  138. sglang/srt/model_executor/forward_batch_info.py +11 -10
  139. sglang/srt/model_executor/model_runner.py +93 -78
  140. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  141. sglang/srt/model_loader/loader.py +24 -6
  142. sglang/srt/models/dbrx.py +12 -6
  143. sglang/srt/models/deepseek.py +2 -1
  144. sglang/srt/models/deepseek_nextn.py +5 -2
  145. sglang/srt/models/deepseek_v2.py +226 -223
  146. sglang/srt/models/ernie4.py +2 -2
  147. sglang/srt/models/glm4_moe.py +27 -65
  148. sglang/srt/models/glm4_moe_nextn.py +2 -1
  149. sglang/srt/models/glm4v.py +52 -1
  150. sglang/srt/models/glm4v_moe.py +8 -11
  151. sglang/srt/models/gpt_oss.py +41 -76
  152. sglang/srt/models/granitemoe.py +0 -1
  153. sglang/srt/models/grok.py +376 -48
  154. sglang/srt/models/interns1.py +12 -47
  155. sglang/srt/models/internvl.py +6 -51
  156. sglang/srt/models/llama.py +10 -2
  157. sglang/srt/models/llama4.py +18 -7
  158. sglang/srt/models/minicpm3.py +0 -1
  159. sglang/srt/models/mixtral.py +0 -2
  160. sglang/srt/models/nemotron_nas.py +435 -0
  161. sglang/srt/models/olmoe.py +0 -1
  162. sglang/srt/models/phi4mm.py +3 -21
  163. sglang/srt/models/qwen2.py +2 -2
  164. sglang/srt/models/qwen2_5_vl.py +2 -0
  165. sglang/srt/models/qwen2_moe.py +23 -23
  166. sglang/srt/models/qwen3.py +2 -2
  167. sglang/srt/models/qwen3_classification.py +84 -0
  168. sglang/srt/models/qwen3_moe.py +27 -43
  169. sglang/srt/models/step3_vl.py +8 -3
  170. sglang/srt/models/xverse_moe.py +11 -5
  171. sglang/srt/multimodal/processors/base_processor.py +3 -3
  172. sglang/srt/multimodal/processors/internvl.py +7 -2
  173. sglang/srt/multimodal/processors/llava.py +11 -7
  174. sglang/srt/offloader.py +433 -0
  175. sglang/srt/operations.py +22 -2
  176. sglang/srt/reasoning_parser.py +4 -3
  177. sglang/srt/sampling/sampling_batch_info.py +7 -4
  178. sglang/srt/server_args.py +264 -105
  179. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
  180. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  181. sglang/srt/speculative/eagle_utils.py +36 -13
  182. sglang/srt/speculative/eagle_worker.py +56 -3
  183. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  184. sglang/srt/two_batch_overlap.py +20 -19
  185. sglang/srt/utils.py +68 -70
  186. sglang/test/runners.py +8 -5
  187. sglang/test/test_block_fp8.py +5 -6
  188. sglang/test/test_block_fp8_ep.py +13 -19
  189. sglang/test/test_cutlass_moe.py +4 -6
  190. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  191. sglang/test/test_fp4_moe.py +4 -3
  192. sglang/test/test_marlin_moe.py +1 -1
  193. sglang/test/test_marlin_utils.py +1 -1
  194. sglang/test/test_utils.py +7 -0
  195. sglang/utils.py +0 -1
  196. sglang/version.py +1 -1
  197. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
  198. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
  199. sglang/srt/layers/quantization/fp4.py +0 -557
  200. sglang/srt/layers/quantization/scalar_type.py +0 -352
  201. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  202. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  203. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/mxfp4.py +222 -39
@@ -1,16 +1,28 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/mxfp4.py
 
 from __future__ import annotations
 
-import importlib.util
 import logging
 from typing import TYPE_CHECKING, List, Optional
 
 import torch
-import triton.language as tl
 from torch.nn.parameter import Parameter
 
+from sglang.srt.layers.moe.utils import get_moe_runner_backend
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     QuantizationConfig,
@@ -27,6 +39,7 @@ from sglang.srt.utils import (
     is_hip,
     is_triton_kernels_available,
     log_info_on_rank0,
+    mxfp_supported,
     next_power_of_2,
     round_up,
     set_weight_attrs,
@@ -47,9 +60,17 @@ if is_flashinfer_available():
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
-OCP_MX_BLOCK_SIZE = 32
+_is_hip = is_hip()
+
+if _is_hip:
+    # import aiter
+    from aiter import ActivationType, QuantType, dtypes
+    from aiter.fused_moe import fused_moe
+    from aiter.ops.triton.quant import dynamic_mxfp4_quant
+    from aiter.utility.fp4_utils import e8m0_shuffle
 
 
 def _swizzle_mxfp4(quant_tensor, scale, num_warps):
@@ -150,13 +171,34 @@ except AttributeError as error:
 
 class Mxfp4Config(QuantizationConfig):
 
-    def __init__(self, ignored_layers: Optional[list[str]] = None):
+    def __init__(
+        self,
+        ignored_layers: Optional[list[str]] = None,
+        is_checkpoint_mxfp4_serialized: bool = False,
+    ):
         super().__init__()
+        self.is_checkpoint_mxfp4_serialized = is_checkpoint_mxfp4_serialized
         self.ignored_layers = ignored_layers
 
     @classmethod
     def from_config(cls, config):
-        return cls()
+
+        quant_method = cls.get_from_keys(config, ["quant_method"])
+        is_checkpoint_mxfp4_serialized = "mxfp4" in quant_method
+
+        if _is_hip:
+            if mxfp_supported():
+                return cls(
+                    is_checkpoint_mxfp4_serialized=is_checkpoint_mxfp4_serialized
+                )
+            else:
+
+                platform = torch.cuda.get_device_properties(0).gcnArchName
+                raise ValueError(
+                    f"Current platform {platform} not support mxfp4 computation"
+                )
+
+        return cls(is_checkpoint_mxfp4_serialized=is_checkpoint_mxfp4_serialized)
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -174,6 +216,9 @@ class Mxfp4Config(QuantizationConfig):
     def get_config_filenames(cls) -> list[str]:
         return []
 
+    def is_static_cfg(self):
+        return self.is_checkpoint_mxfp4_serialized
+
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
@@ -189,10 +234,16 @@ class Mxfp4Config(QuantizationConfig):
                 fused_mapping=self.packed_modules_mapping,
             ):
                 return UnquantizedLinearMethod()
+            elif _is_hip:
+                return UnquantizedLinearMethod()
         elif isinstance(layer, FusedMoE):
-            return Mxfp4MoEMethod(prefix)
+            if self.is_checkpoint_mxfp4_serialized:
+                return Mxfp4MoEMethod(prefix=prefix)
+            else:
+                return Mxfp4DynamicQuantMoEMethod()
         else:
-            raise NotImplementedError("Mxfp4 attention layer is not implemented")
+            if self.is_checkpoint_mxfp4_serialized:
+                raise NotImplementedError("Mxfp4 attention layer is not implemented")
         return None
 
     def get_scaled_act_names(self) -> List[str]:
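
Note: taken together, the Mxfp4Config changes key everything off whether the checkpoint itself is already MXFP4-serialized, and ROCm builds now fail fast in from_config when the platform cannot do MXFP4 math. A standalone toy sketch of the resulting dispatch (pick_moe_method and its flags are illustrative names, not sglang APIs; only the branching mirrors the diff):

def pick_moe_method(quant_method: str, is_hip: bool, mxfp4_ok: bool) -> str:
    # Toy dispatch; only the branching mirrors the diff above.
    is_checkpoint_mxfp4_serialized = "mxfp4" in quant_method
    if is_hip and not mxfp4_ok:
        raise ValueError("platform does not support mxfp4 computation")
    if is_checkpoint_mxfp4_serialized:
        return "Mxfp4MoEMethod"  # checkpoint already stores MXFP4 weights
    return "Mxfp4DynamicQuantMoEMethod"  # bf16/fp16 weights, quantized at load time


assert pick_moe_method("mxfp4", is_hip=False, mxfp4_ok=True) == "Mxfp4MoEMethod"
assert pick_moe_method("other", is_hip=True, mxfp4_ok=True) == "Mxfp4DynamicQuantMoEMethod"
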
@@ -205,14 +256,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self,
         prefix: str,
     ):
-        from sglang.srt.managers.schedule_batch import global_server_args_dict
-
         super().__init__()
 
+        self.prefix = prefix
         self.topk_indices_dtype = None
-        self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
+        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
         self.with_bias = False
-        self.use_flashinfer = global_server_args_dict["enable_flashinfer_mxfp4_moe"]
+        self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
+        self.flashinfer_mxfp4_moe_precision = global_server_args_dict[
+            "flashinfer_mxfp4_moe_precision"
+        ]
 
         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
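
Note: the constructor now asks a central MoE-runner backend selector instead of reading raw booleans out of global_server_args_dict (only the new flashinfer_mxfp4_moe_precision knob still comes from there). A generic sketch of the enum-with-predicates pattern this implies; MoeRunnerBackend and its members below are hypothetical stand-ins, not sglang's actual enum:

from enum import Enum


class MoeRunnerBackend(Enum):  # hypothetical stand-in, not sglang's enum
    AUTO = "auto"
    TRITON_KERNEL = "triton_kernel"
    FLASHINFER_MXFP4 = "flashinfer_mxfp4"

    def is_triton_kernel(self) -> bool:
        return self is MoeRunnerBackend.TRITON_KERNEL

    def is_flashinfer_mxfp4(self) -> bool:
        return self is MoeRunnerBackend.FLASHINFER_MXFP4


backend = MoeRunnerBackend.FLASHINFER_MXFP4
assert backend.is_flashinfer_mxfp4() and not backend.is_triton_kernel()
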
@@ -256,6 +309,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size, 64
             )
+        elif has_triton_kernels:
+            # TODO: this is a hack to make
+            # intermediate_size_per_partition_after_pad the same as the
+            # per_rank_intermediate_size during weight loading
+            intermediate_size_per_partition_after_pad = round_up(
+                intermediate_size, mxfp4_block
+            )
 
         self.intermediate_size = intermediate_size_per_partition_after_pad
 
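
Note: the new branch pads the per-partition intermediate size so it divides evenly into MXFP4 blocks. A minimal sketch, assuming the usual ceiling-to-multiple behaviour for round_up and a 32-element MX block (mxfp4_block is defined outside the hunks shown here):

def round_up(x: int, m: int) -> int:
    # smallest multiple of m that is >= x (assumed behaviour of sglang's round_up)
    return (x + m - 1) // m * m


mxfp4_block = 32  # assumed: one MX block covers 32 elements
print(round_up(2880, mxfp4_block))  # 2880, already aligned
print(round_up(2900, mxfp4_block))  # 2912
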
@@ -332,8 +392,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         if self.use_flashinfer:
             log_info_on_rank0(
                 logger,
-                "Shuffling MoE weights for FlashInfer MXFP4 moe kernel, it might take a while...",
+                f"Shuffling MoE weights for FlashInfer MXFP4 moe kernel (layer: {self.prefix}), it might take a while...",
             )
+            # TODO: these values are hardcoded for now, we need to get them from the model
             layer.gemm1_alpha = Parameter(
                 torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
                 requires_grad=False,
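
Note: 1.702 is the familiar constant for approximating GELU with a sigmoid gate, gelu(x) ≈ x * sigmoid(1.702 * x), which is presumably why it is hardcoded as gemm1_alpha until the TODO above is resolved. A quick standalone check of that approximation in plain PyTorch:

import torch

x = torch.linspace(-4.0, 4.0, steps=101)
approx = x * torch.sigmoid(1.702 * x)  # sigmoid-gated approximation
exact = torch.nn.functional.gelu(x)  # exact (erf-based) GELU
print(f"max abs error: {(approx - exact).abs().max().item():.4f}")  # about 0.02
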
@@ -559,24 +620,40 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-        activation_alpha: Optional[float] = None,
-        swiglu_limit: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
+
+        from sglang.srt.layers.moe.topk import TopKOutputChecker
+
         if self.use_flashinfer:
-            # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
-            x_quant, x_scale = mxfp8_quantize(
-                x, False, alignment=self.hidden_size
-            ) # to mxfp8
-            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+            # When bf16 mode is enabled, we don't need to quantize the input,
+            # TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
+            # which can theoretically improve performance
+            if self.flashinfer_mxfp4_moe_precision == "bf16":
+                assert x.dtype == torch.bfloat16
+                x_quant = x
+                x_scale = None
+
+                # May be fused later if this code branch is frequently needed
+                origin_hidden_states_dim = x_quant.shape[-1]
+                if self.hidden_size != origin_hidden_states_dim:
+                    x_quant = torch.nn.functional.pad(
+                        x_quant,
+                        (0, self.hidden_size - origin_hidden_states_dim),
+                        mode="constant",
+                        value=0.0,
+                    )
+            elif self.flashinfer_mxfp4_moe_precision == "default":
+                x_quant, x_scale = mxfp8_quantize(x, False, alignment=self.hidden_size)
+                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+            else:
+                raise NotImplementedError
+
             assert x_quant.shape[-1] == self.hidden_size
+            assert TopKOutputChecker.format_is_bypassed(topk_output)
 
-            top_k, router_logits = topk_output
+            top_k = topk_output.topk_config.top_k
+            router_logits = topk_output.router_logits
 
             trtllm_gen_output = trtllm_fp4_block_scale_moe(
                 router_logits.to(torch.bfloat16),
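
Note: in the new bf16 path the input is not quantized at all; it is only right-padded along the hidden dimension to the kernel's padded hidden_size. A small standalone illustration of that padding convention (plain PyTorch, made-up shapes):

import torch
import torch.nn.functional as F

x = torch.randn(4, 2880, dtype=torch.bfloat16)  # (num_tokens, hidden dim), made-up sizes
padded_hidden_size = 3072  # whatever self.hidden_size was rounded up to

pad = padded_hidden_size - x.shape[-1]
if pad > 0:
    # (0, pad) pads only the last dimension, on the right, with zeros
    x = F.pad(x, (0, pad), mode="constant", value=0.0)

assert x.shape == (4, 3072) and x.dtype == torch.bfloat16
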
@@ -597,8 +674,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 None, # output2_scale_scalar
                 layer.num_experts,
                 top_k,
-                None, # n_group
-                None, # topk_group
+                None, # n_group # TODO: support n_group
+                None, # topk_group # TODO: support topk_group
                 self.intermediate_size, # padded to multiple of 256
                 layer.moe_ep_rank * layer.num_local_experts, # local_expert_offset
                 layer.num_local_experts, # local num experts
@@ -623,9 +700,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                     b1=layer.w13_weight_bias,
                     b2=layer.w2_weight_bias,
                     topk_output=topk_output,
-                    activation=activation,
-                    activation_alpha=activation_alpha,
-                    swiglu_limit=swiglu_limit,
+                    moe_runner_config=moe_runner_config,
                 )
             else:
                 return self.triton_kernel_moe_forward(
@@ -633,6 +708,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                     w1=layer.w13_weight,
                     w2=layer.w2_weight,
                     topk_output=topk_output,
+                    moe_runner_config=moe_runner_config,
                 )
         else:
             from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -642,13 +718,120 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
                 topk_output=topk_output,
+                moe_runner_config=moe_runner_config,
                 b1=layer.w13_weight_bias,
                 b2=layer.w2_weight_bias,
-                inplace=inplace,
-                activation=activation,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                no_combine=no_combine,
-                routed_scaling_factor=routed_scaling_factor,
-                activation_alpha=activation_alpha,
-                swiglu_limit=swiglu_limit,
             )
+
+
+class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # Allocate 2 scales for w1 and w3 respectively.
+        # They will be combined to a single scale after weight loading.
+        w13_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+        )
+        w2_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+        # Add the quantization method used (per tensor/grouped/channel)
+        # to ensure the weight scales are loaded in properly
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+        )
+
+        layer.w13_input_scale = None
+        layer.w2_input_scale = None
+
+    def mxfp4_quantize(self, w):
+        w_shape = w.shape
+        w_need_reshape = True if w.dim() != 2 else False
+
+        if w_need_reshape:
+            w_last_dim_size = w_shape[-1]
+            w = w.view(-1, w_last_dim_size)
+
+        w, mx_scales = dynamic_mxfp4_quant(w)
+
+        if w_need_reshape:
+            w_new_shape = w_shape[:-1] + (w.shape[-1],)
+            w = w.view(w_new_shape)
+
+        mx_scales = e8m0_shuffle(mx_scales)
+
+        return w, mx_scales
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data)
+        w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)
+
+        layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
+        layer.w13_weight_scale = torch.nn.Parameter(w13_mx_scales, requires_grad=False)
+
+        layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
+        layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        topk_output: TopKOutput,
+        moe_runner_config: MoeRunnerConfig,
+    ) -> torch.Tensor:
+        topk_weights, topk_ids, _ = topk_output
+
+        return fused_moe(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights,
+            topk_ids,
+            quant_type=QuantType.per_1x32,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            activation=(
+                ActivationType.Silu
+                if moe_runner_config.activation == "silu"
+                else ActivationType.Gelu
+            ),
+            doweight_stage1=False,
+        )
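
Note: Mxfp4DynamicQuantMoEMethod keeps bf16/fp16 master weights and quantizes them to MXFP4 at load time via aiter's dynamic_mxfp4_quant, flattening each expert weight to 2-D first and restoring the shape afterwards. A standalone sketch of the shapes this produces, assuming the OCP MX convention behind quant_type=per_1x32 (two e2m1 values packed per byte, one shared e8m0 scale per 32-element block); the real aiter layout, including the e8m0_shuffle of the scales, may differ:

import torch


def mxfp4_packed_shapes(w: torch.Tensor, block: int = 32):
    # Expected shapes after per-1x32 MXFP4 quantization of the last dim:
    # values packed two per byte, one e8m0 scale byte per 32-element block.
    *lead, k = w.shape
    assert k % block == 0
    return (*lead, k // 2), (*lead, k // block)


# made-up expert weight: (num_experts, hidden_size, intermediate_size)
w2 = torch.empty(8, 2880, 1440, dtype=torch.bfloat16)
print(mxfp4_packed_shapes(w2))  # ((8, 2880, 720), (8, 2880, 45))
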