sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported public registries. It is provided for informational purposes only.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +376 -48
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
@@ -16,14 +16,13 @@

  from __future__ import annotations

- import importlib.util
  import logging
  from typing import TYPE_CHECKING, List, Optional

  import torch
- import triton.language as tl
  from torch.nn.parameter import Parameter

+ from sglang.srt.layers.moe.utils import get_moe_runner_backend
  from sglang.srt.layers.quantization.base_config import (
      FusedMoEMethodBase,
      QuantizationConfig,
@@ -40,6 +39,7 @@ from sglang.srt.utils import (
      is_hip,
      is_triton_kernels_available,
      log_info_on_rank0,
+     mxfp_supported,
      next_power_of_2,
      round_up,
      set_weight_attrs,
@@ -60,9 +60,17 @@ if is_flashinfer_available():
  logger = logging.getLogger(__name__)

  if TYPE_CHECKING:
+     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
      from sglang.srt.layers.moe.topk import TopKOutput

- OCP_MX_BLOCK_SIZE = 32
+ _is_hip = is_hip()
+
+ if _is_hip:
+     # import aiter
+     from aiter import ActivationType, QuantType, dtypes
+     from aiter.fused_moe import fused_moe
+     from aiter.ops.triton.quant import dynamic_mxfp4_quant
+     from aiter.utility.fp4_utils import e8m0_shuffle


  def _swizzle_mxfp4(quant_tensor, scale, num_warps):
@@ -163,13 +171,34 @@ except AttributeError as error:

  class Mxfp4Config(QuantizationConfig):

-     def __init__(self, ignored_layers: Optional[list[str]] = None):
+     def __init__(
+         self,
+         ignored_layers: Optional[list[str]] = None,
+         is_checkpoint_mxfp4_serialized: bool = False,
+     ):
          super().__init__()
+         self.is_checkpoint_mxfp4_serialized = is_checkpoint_mxfp4_serialized
          self.ignored_layers = ignored_layers

      @classmethod
      def from_config(cls, config):
-         return cls()
+
+         quant_method = cls.get_from_keys(config, ["quant_method"])
+         is_checkpoint_mxfp4_serialized = "mxfp4" in quant_method
+
+         if _is_hip:
+             if mxfp_supported():
+                 return cls(
+                     is_checkpoint_mxfp4_serialized=is_checkpoint_mxfp4_serialized
+                 )
+             else:
+
+                 platform = torch.cuda.get_device_properties(0).gcnArchName
+                 raise ValueError(
+                     f"Current platform {platform} not support mxfp4 computation"
+                 )
+
+         return cls(is_checkpoint_mxfp4_serialized=is_checkpoint_mxfp4_serialized)

      @classmethod
      def get_min_capability(cls) -> int:
@@ -187,6 +216,9 @@ class Mxfp4Config(QuantizationConfig):
      def get_config_filenames(cls) -> list[str]:
          return []

+     def is_static_cfg(self):
+         return self.is_checkpoint_mxfp4_serialized
+
      def get_quant_method(
          self, layer: torch.nn.Module, prefix: str
      ) -> Optional["QuantizeMethodBase"]:
@@ -202,10 +234,16 @@ class Mxfp4Config(QuantizationConfig):
                  fused_mapping=self.packed_modules_mapping,
              ):
                  return UnquantizedLinearMethod()
+             elif _is_hip:
+                 return UnquantizedLinearMethod()
          elif isinstance(layer, FusedMoE):
-             return Mxfp4MoEMethod(prefix)
+             if self.is_checkpoint_mxfp4_serialized:
+                 return Mxfp4MoEMethod(prefix=prefix)
+             else:
+                 return Mxfp4DynamicQuantMoEMethod()
          else:
-             raise NotImplementedError("Mxfp4 attention layer is not implemented")
+             if self.is_checkpoint_mxfp4_serialized:
+                 raise NotImplementedError("Mxfp4 attention layer is not implemented")
          return None

      def get_scaled_act_names(self) -> List[str]:
@@ -218,15 +256,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
          self,
          prefix: str,
      ):
-         from sglang.srt.managers.schedule_batch import global_server_args_dict
-
          super().__init__()

          self.prefix = prefix
          self.topk_indices_dtype = None
-         self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
+         self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
          self.with_bias = False
-         self.use_flashinfer = global_server_args_dict["enable_flashinfer_mxfp4_moe"]
+         self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
+         self.flashinfer_mxfp4_moe_precision = global_server_args_dict[
+             "flashinfer_mxfp4_moe_precision"
+         ]

          self.triton_kernel_moe_forward = None
          self.triton_kernel_moe_with_bias_forward = None
@@ -270,6 +309,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
              intermediate_size_per_partition_after_pad = round_up(
                  intermediate_size, 64
              )
+         elif has_triton_kernels:
+             # TODO: this is a hack to make
+             # intermediate_size_per_partition_after_pad the same as the
+             # per_rank_intermediate_size during weight loading
+             intermediate_size_per_partition_after_pad = round_up(
+                 intermediate_size, mxfp4_block
+             )

          self.intermediate_size = intermediate_size_per_partition_after_pad

@@ -348,6 +394,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                  logger,
                  f"Shuffling MoE weights for FlashInfer MXFP4 moe kernel (layer: {self.prefix}), it might take a while...",
              )
+             # TODO: these values are hardcoded for now, we need to get them from the model
              layer.gemm1_alpha = Parameter(
                  torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
                  requires_grad=False,
@@ -573,24 +620,40 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
          layer: torch.nn.Module,
          x: torch.Tensor,
          topk_output: TopKOutput,
-         *,
-         activation: str = "silu",
-         apply_router_weight_on_input: bool = False,
-         inplace: bool = True,
-         no_combine: bool = False,
-         routed_scaling_factor: Optional[float] = None,
-         activation_alpha: Optional[float] = None,
-         swiglu_limit: Optional[float] = None,
+         moe_runner_config: MoeRunnerConfig,
      ) -> torch.Tensor:
+
+         from sglang.srt.layers.moe.topk import TopKOutputChecker
+
          if self.use_flashinfer:
-             # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
-             x_quant, x_scale = mxfp8_quantize(
-                 x, False, alignment=self.hidden_size
-             ) # to mxfp8
-             x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+             # When bf16 mode is enabled, we don't need to quantize the input,
+             # TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
+             # which can theoretically improve performance
+             if self.flashinfer_mxfp4_moe_precision == "bf16":
+                 assert x.dtype == torch.bfloat16
+                 x_quant = x
+                 x_scale = None
+
+                 # May be fused later if this code branch is frequently needed
+                 origin_hidden_states_dim = x_quant.shape[-1]
+                 if self.hidden_size != origin_hidden_states_dim:
+                     x_quant = torch.nn.functional.pad(
+                         x_quant,
+                         (0, self.hidden_size - origin_hidden_states_dim),
+                         mode="constant",
+                         value=0.0,
+                     )
+             elif self.flashinfer_mxfp4_moe_precision == "default":
+                 x_quant, x_scale = mxfp8_quantize(x, False, alignment=self.hidden_size)
+                 x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+             else:
+                 raise NotImplementedError
+
              assert x_quant.shape[-1] == self.hidden_size
+             assert TopKOutputChecker.format_is_bypassed(topk_output)

-             top_k, router_logits = topk_output
+             top_k = topk_output.topk_config.top_k
+             router_logits = topk_output.router_logits

              trtllm_gen_output = trtllm_fp4_block_scale_moe(
                  router_logits.to(torch.bfloat16),
@@ -611,8 +674,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                  None, # output2_scale_scalar
                  layer.num_experts,
                  top_k,
-                 None, # n_group
-                 None, # topk_group
+                 None, # n_group # TODO: support n_group
+                 None, # topk_group # TODO: support topk_group
                  self.intermediate_size, # padded to multiple of 256
                  layer.moe_ep_rank * layer.num_local_experts, # local_expert_offset
                  layer.num_local_experts, # local num experts
@@ -637,9 +700,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                      b1=layer.w13_weight_bias,
                      b2=layer.w2_weight_bias,
                      topk_output=topk_output,
-                     activation=activation,
-                     activation_alpha=activation_alpha,
-                     swiglu_limit=swiglu_limit,
+                     moe_runner_config=moe_runner_config,
                  )
              else:
                  return self.triton_kernel_moe_forward(
@@ -647,6 +708,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                      w1=layer.w13_weight,
                      w2=layer.w2_weight,
                      topk_output=topk_output,
+                     moe_runner_config=moe_runner_config,
                  )
          else:
              from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -656,13 +718,120 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                  w1=layer.w13_weight,
                  w2=layer.w2_weight,
                  topk_output=topk_output,
+                 moe_runner_config=moe_runner_config,
                  b1=layer.w13_weight_bias,
                  b2=layer.w2_weight_bias,
-                 inplace=inplace,
-                 activation=activation,
-                 apply_router_weight_on_input=apply_router_weight_on_input,
-                 no_combine=no_combine,
-                 routed_scaling_factor=routed_scaling_factor,
-                 activation_alpha=activation_alpha,
-                 swiglu_limit=swiglu_limit,
              )
+
+
+ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
+     def create_weights(
+         self,
+         layer: torch.nn.Module,
+         num_experts: int,
+         hidden_size: int,
+         intermediate_size_per_partition: int,
+         params_dtype: torch.dtype,
+         **extra_weight_attrs,
+     ):
+
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+         w13_weight = torch.nn.Parameter(
+             torch.empty(
+                 num_experts,
+                 2 * intermediate_size_per_partition,
+                 hidden_size,
+                 dtype=params_dtype,
+             ),
+             requires_grad=False,
+         )
+         w2_weight = torch.nn.Parameter(
+             torch.empty(
+                 num_experts,
+                 hidden_size,
+                 intermediate_size_per_partition,
+                 dtype=params_dtype,
+             ),
+             requires_grad=False,
+         )
+
+         layer.register_parameter("w13_weight", w13_weight)
+         set_weight_attrs(w13_weight, extra_weight_attrs)
+
+         layer.register_parameter("w2_weight", w2_weight)
+         set_weight_attrs(w2_weight, extra_weight_attrs)
+
+         # Allocate 2 scales for w1 and w3 respectively.
+         # They will be combined to a single scale after weight loading.
+         w13_weight_scale = torch.nn.Parameter(
+             torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+         )
+         w2_weight_scale = torch.nn.Parameter(
+             torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+         )
+         layer.register_parameter("w13_weight_scale", w13_weight_scale)
+         layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+         # Add the quantization method used (per tensor/grouped/channel)
+         # to ensure the weight scales are loaded in properly
+         extra_weight_attrs.update(
+             {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+         )
+
+         layer.w13_input_scale = None
+         layer.w2_input_scale = None
+
+     def mxfp4_quantize(self, w):
+         w_shape = w.shape
+         w_need_reshape = True if w.dim() != 2 else False
+
+         if w_need_reshape:
+             w_last_dim_size = w_shape[-1]
+             w = w.view(-1, w_last_dim_size)
+
+         w, mx_scales = dynamic_mxfp4_quant(w)
+
+         if w_need_reshape:
+             w_new_shape = w_shape[:-1] + (w.shape[-1],)
+             w = w.view(w_new_shape)
+
+         mx_scales = e8m0_shuffle(mx_scales)
+
+         return w, mx_scales
+
+     def process_weights_after_loading(self, layer: Module) -> None:
+         w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data)
+         w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)
+
+         layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
+         layer.w13_weight_scale = torch.nn.Parameter(w13_mx_scales, requires_grad=False)
+
+         layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
+         layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         topk_output: TopKOutput,
+         moe_runner_config: MoeRunnerConfig,
+     ) -> torch.Tensor:
+         topk_weights, topk_ids, _ = topk_output
+
+         return fused_moe(
+             x,
+             layer.w13_weight,
+             layer.w2_weight,
+             topk_weights,
+             topk_ids,
+             quant_type=QuantType.per_1x32,
+             w1_scale=layer.w13_weight_scale,
+             w2_scale=layer.w2_weight_scale,
+             activation=(
+                 ActivationType.Silu
+                 if moe_runner_config.activation == "silu"
+                 else ActivationType.Gelu
+             ),
+             doweight_stage1=False,
+         )
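
Note on the new ROCm path: Mxfp4DynamicQuantMoEMethod keeps the checkpoint weights in their original dtype and quantizes them to MXFP4 at load time with aiter's dynamic_mxfp4_quant, using per-32-element blocks (QuantType.per_1x32) that share one power-of-two E8M0 scale; e8m0_shuffle then reorders the scales into the layout the fused MoE kernel expects. The snippet below is only an illustrative, framework-free sketch of what MXFP4 block quantization does. It returns dequantized values and raw block exponents rather than packed 4-bit data and shuffled scale bytes, and it is not the aiter implementation.

import torch

# Representable magnitudes of the 4-bit E2M1 (MXFP4) element format.
FP4_E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def ref_mxfp4_quantize(w: torch.Tensor, block_size: int = 32):
    """Illustrative MXFP4 block quantization (returns dequantized values).

    Each group of `block_size` consecutive values along the last dimension
    shares one power-of-two scale (the E8M0 exponent); every element is then
    rounded to the nearest representable E2M1 magnitude.
    """
    assert w.shape[-1] % block_size == 0
    orig_shape = w.shape
    blocks = w.reshape(-1, block_size).float()

    # Pick the shared scale so the largest |value| in each block maps near 6.0,
    # the maximum E2M1 magnitude.
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-30)
    exponent = torch.floor(torch.log2(amax / 6.0))
    scale = torch.exp2(exponent)

    # Round each scaled element to the nearest E2M1 magnitude, keep the sign.
    scaled = blocks / scale
    idx = (scaled.abs().unsqueeze(-1) - FP4_E2M1_VALUES).abs().argmin(dim=-1)
    dequant = FP4_E2M1_VALUES[idx] * scaled.sign() * scale

    return dequant.reshape(orig_shape).to(w.dtype), exponent.squeeze(-1)

For example, ref_mxfp4_quantize(torch.randn(8, 64)) collapses each 32-wide block onto the E2M1 grid scaled by that block's exponent, which is the same information the real kernel stores in packed form.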
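
Note on the signature change: throughout this file the per-call keyword arguments (activation, apply_router_weight_on_input, inplace, no_combine, routed_scaling_factor, activation_alpha, swiglu_limit) are folded into a single moe_runner_config argument of type MoeRunnerConfig, which is defined in sglang/srt/layers/moe/moe_runner/base.py (added in this release) and not shown in this diff. The sketch below is only a plausible shape for such an object, inferred from the removed keyword arguments; apart from activation, which the ROCm path above reads, the field names are assumptions rather than the actual sglang definition.

from dataclasses import dataclass
from typing import Optional


@dataclass
class MoeRunnerConfigSketch:
    """Hypothetical stand-in for sglang's MoeRunnerConfig (not the real class).

    The fields mirror the keyword arguments removed from Mxfp4MoEMethod.apply
    in this diff; only `activation` is actually referenced by the new ROCm
    code path shown above.
    """

    activation: str = "silu"
    apply_router_weight_on_input: bool = False
    inplace: bool = True
    no_combine: bool = False
    routed_scaling_factor: Optional[float] = None
    activation_alpha: Optional[float] = None
    swiglu_limit: Optional[float] = None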