sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +376 -48
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/modelopt_quant.py

@@ -7,8 +7,13 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
+from sglang.srt.distributed import get_tp_group
+from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
+from sglang.srt.layers.moe import (
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+    should_use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
-from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -30,10 +35,11 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import is_cuda, next_power_of_2
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
 if is_cuda():
@@ -105,18 +111,52 @@ class ModelOptFp8Config(QuantizationConfig):
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
-        quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
-        kv_cache_quant_method = cls.get_from_keys(config, ["quantization"]).get(
-            "kv_cache_quant_algo"
-        )
-        exclude_modules = cls.get_from_keys(config, ["quantization"]).get(
-            "exclude_modules"
-        )
+        # Handle two different config formats:
+        # 1. hf_quant_config.json format: {"quantization": {"quant_algo": "FP8", ...}}
+        # 2. config.json quantization_config format: {"quant_algo": "FP8", ...}
+        # In future modelopt will deprecate hf_quant_config.json, and only keep config.json.
+        # For legacy reasons, we keep hf_quant_config.json for now.
+
+        # Initialize variables
+        kv_cache_quant_method = None
+        exclude_modules = None
+
+        # Try flat format first (config.json quantization_config - preferred format)
+        quant_method = config.get("quant_algo")
+        if quant_method is not None:
+            # Flat format (config.json quantization_config)
+            # For kv_cache, check if kv_cache_scheme exists and extract algo
+            kv_cache_scheme = config.get("kv_cache_scheme")
+            if (
+                kv_cache_scheme
+                and kv_cache_scheme.get("type") == "float"
+                and kv_cache_scheme.get("num_bits") == 8
+            ):
+                kv_cache_quant_method = "FP8"
 
+            # Map 'ignore' field to 'exclude_modules'
+            exclude_modules = config.get("ignore")
+        else:
+            # Fall back to nested format (hf_quant_config.json - legacy format)
+            try:
+                quantization_section = cls.get_from_keys(config, ["quantization"])
+                quant_method = quantization_section.get("quant_algo")
+                kv_cache_quant_method = quantization_section.get("kv_cache_quant_algo")
+                exclude_modules = quantization_section.get("exclude_modules")
+            except ValueError:
+                raise ValueError(
+                    "Cannot find 'quant_algo' in the model's quantization config. "
+                    "Expected either flat format (config.json) or nested format (hf_quant_config.json)."
+                )
+        if quant_method is None:
+            raise ValueError(
+                "Cannot find 'quant_algo' in the model's quantization config. "
+            )
         if "FP8" not in quant_method:
             raise ValueError(
-                "ModelOpt only supports static FP8 quantization in SGLang. "
-                "Check the `hf_quant_config.json` file for your model's configuration."
+                "ModelOptFp8Config only supports static FP8 quantization in SGLang. "
+                "For FP4 quantization, use ModelOptFp4Config. "
+                "Check the quantization config for your model's configuration."
             )
 
         return cls(
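For reference, a minimal illustration of the two config shapes this parser now accepts. The field names come from the hunk above; the values (and the module path in the import) are illustrative, not taken from a real checkpoint:

    from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config

    # Flat format, as found under "quantization_config" in config.json
    flat_cfg = {
        "quant_algo": "FP8",
        "kv_cache_scheme": {"type": "float", "num_bits": 8},  # detected as FP8 KV cache
        "ignore": ["lm_head"],                                 # mapped to exclude_modules
    }

    # Legacy nested format from hf_quant_config.json
    nested_cfg = {
        "quantization": {
            "quant_algo": "FP8",
            "kv_cache_quant_algo": "FP8",
            "exclude_modules": ["lm_head"],
        }
    }

    # Both now resolve to equivalent ModelOptFp8Config instances
    cfg_a = ModelOptFp8Config.from_config(flat_cfg)
    cfg_b = ModelOptFp8Config.from_config(nested_cfg)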
@@ -422,12 +462,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 
@@ -436,15 +471,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight,
             layer.w2_weight,
             topk_output=topk_output,
-            inplace=inplace,
-            activation=activation,
+            moe_runner_config=moe_runner_config,
             use_fp8_w8a8=True,
             per_channel_quant=False,  # ModelOpt uses per-tensor quantization
             w1_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            no_combine=no_combine,
         )


@@ -486,22 +519,63 @@ class ModelOptFp4Config(QuantizationConfig):
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
-        quant_config = cls.get_from_keys(config, ["quantization"])
-        quant_method = quant_config["quant_algo"]
+        # Handle two different config formats:
+        # 1. hf_quant_config.json format: {"quantization": {"quant_algo": "NVFP4", ...}}
+        # 2. config.json quantization_config format: {"quant_algo": "NVFP4", ...}
+        # In future modelopt will deprecate hf_quant_config.json, and only keep config.json.
+        # For legacy reasons, we keep hf_quant_config.json for now.
+
+        # Initialize variables
+        kv_cache_quant_algo = None
+        group_size = None
+        exclude_modules = []
+
+        # Try flat format first (config.json quantization_config - preferred format)
+        quant_method = config.get("quant_algo")
+        if quant_method is not None:
+            # Flat format (config.json quantization_config)
+            # Note: FP4 models in config.json format may not have all the detailed fields
+            # that are present in hf_quant_config.json, so we need to handle defaults
+            kv_cache_quant_algo = config.get("kv_cache_quant_algo")
+            if not kv_cache_quant_algo:
+                # For config.json format, derive from kv_cache_scheme if available
+                kv_cache_scheme = config.get("kv_cache_scheme")
+                if (
+                    kv_cache_scheme
+                    and kv_cache_scheme.get("type") == "float"
+                    and kv_cache_scheme.get("num_bits") == 8
+                ):
+                    kv_cache_quant_algo = "FP8"
+                else:
+                    kv_cache_quant_algo = "auto"
+
+            group_size = config.get("group_size")
+            exclude_modules = config.get("ignore", [])
+        else:
+            # Fall back to nested format (hf_quant_config.json - legacy format)
+            try:
+                quant_config = cls.get_from_keys(config, ["quantization"])
+                quant_method = quant_config["quant_algo"]
+                kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo")
+                if not kv_cache_quant_algo:
+                    kv_cache_quant_algo = "auto"
+                group_size = quant_config.get("group_size")
+                exclude_modules = quant_config.get("exclude_modules", [])
+            except (ValueError, KeyError):
+                raise ValueError(
+                    "Cannot find 'quant_algo' in the model's quantization config. "
+                    "Expected either flat format (config.json) or nested format (hf_quant_config.json)."
+                )
+
         if not quant_method in ["FP8", "NVFP4"]:
             raise ValueError(
                 f"ModelOpt currently only supports: FP8, NVFP4"
                 " quantizations in sglang. Please check the "
-                "`hf_quant_config.json` file for your model's "
-                "quant configuration."
+                "quantization config for your model's configuration."
             )
         is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
-        kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
-        if not kv_cache_quant_algo:
-            kv_cache_quant_algo = "auto"
-        group_size = quant_config["group_size"]
-        exclude_modules = quant_config["exclude_modules"]
-        if not (group_size and kv_cache_quant_algo and exclude_modules):
+
+        if not (group_size and kv_cache_quant_algo) or exclude_modules is None:
             logger.warning(
                 f"group_size: {group_size},"
                 f"kv_cache_quant_algo: {kv_cache_quant_algo},"
@@ -509,8 +583,7 @@ class ModelOptFp4Config(QuantizationConfig):
             )
             raise ValueError(
                 "NVFP4 quantization requires group size and "
-                "kv_cache_quant_algo specified in "
-                "hf_quant_config.json"
+                "kv_cache_quant_algo specified in the quantization config"
             )
         return cls(
             is_checkpoint_nvfp4_serialized,
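As with the FP8 path above, a small illustrative flat-format NVFP4 config that this branch now accepts. The values are invented for the example (group_size in particular); when kv_cache_scheme is absent, kv_cache_quant_algo falls back to "auto":

    from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config

    nvfp4_flat = {
        "quant_algo": "NVFP4",
        "group_size": 16,  # required; from_config raises ValueError if it is missing
        "kv_cache_scheme": {"type": "float", "num_bits": 8},  # -> kv_cache_quant_algo = "FP8"
        "ignore": ["lm_head"],                                 # -> exclude_modules
    }
    cfg = ModelOptFp4Config.from_config(nvfp4_flat)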
@@ -741,8 +814,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
     @property
     def enable_flashinfer_cutlass_moe(self) -> bool:
+        from sglang.srt.layers.moe import get_moe_runner_backend
+
         """Access the global enable_flashinfer_cutlass_moe setting."""
-        return global_server_args_dict.get("enable_flashinfer_cutlass_moe", False)
+        return get_moe_runner_backend().is_flashinfer_cutlass()
 
     def create_weights(
         self,
@@ -811,6 +886,11 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
 
+        # Only use `swizzle_blockscale` for shapes, not for real content
+        layer.w13_blockscale_swizzled = Parameter(
+            self.swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
+        )
+
         w2_weight_scale = ModelWeightParameter(
             data=torch.empty(
                 layer.num_local_experts,
@@ -825,6 +905,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
 
+        layer.w2_blockscale_swizzled = Parameter(
+            self.swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
+        )
+
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
         extra_weight_attrs.update(
@@ -1128,16 +1212,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
         # Process w13 weights
         w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
-        layer.w13_blockscale_swizzled = Parameter(
-            w13_blockscale_swizzled, requires_grad=False
-        )
+        layer.w13_blockscale_swizzled.data.copy_(w13_blockscale_swizzled)
         layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
 
         # Process w2 weights
         w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
-        layer.w2_blockscale_swizzled = Parameter(
-            w2_blockscale_swizzled, requires_grad=False
-        )
+        layer.w2_blockscale_swizzled.data.copy_(w2_blockscale_swizzled)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
 
         # Both flashinfer cutlass and regular cutlass use same processing for w2
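Taken together, these hunks switch the swizzled block scales to an allocate-then-fill pattern: create_weights registers a correctly shaped placeholder Parameter ("only for shapes, not for real content"), and process_weights_after_loading copies the real swizzled data into it in place instead of rebinding the attribute. A minimal sketch of that pattern outside SGLang (class and method names here are illustrative only):

    import torch
    from torch.nn.parameter import Parameter

    class ExampleLayer(torch.nn.Module):
        def create_weights(self, scale_shape):
            # Register a placeholder with the final shape; contents are not meaningful yet.
            self.scale_swizzled = Parameter(torch.empty(scale_shape), requires_grad=False)

        def process_weights_after_loading(self, real_swizzled_scale):
            # Fill in place so the Parameter object (and its storage) stays the same.
            self.scale_swizzled.data.copy_(real_swizzled_scale)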
@@ -1160,21 +1240,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
     def apply(
         self,
-        layer: torch.nn.Module,
+        layer: FusedMoE,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
-        ep_rank: Optional[int] = None,
-        ep_size: Optional[int] = None,
-        tp_rank: Optional[int] = None,
-        tp_size: Optional[int] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
-        assert activation == "silu", "Only SiLU activation is supported."
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
 
         # Check if this is a FlashInferFP4MoE layer that should handle its own forward
         if hasattr(layer, "gemm1_weights_fp4_shuffled"):
@@ -1183,20 +1256,41 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
         if self.enable_flashinfer_cutlass_moe:
             assert (
-                not apply_router_weight_on_input
+                not moe_runner_config.apply_router_weight_on_input
             ), "apply_router_weight_on_input is not supported for Flashinfer"
             # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
             # and fp4 quantized weights loaded from the checkpoint
-
             topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
 
+            output_dtype = x.dtype
+            x_sf = None
+            if should_use_flashinfer_cutlass_moe_fp4_allgather():
+                from flashinfer import fp4_quantize, nvfp4_block_scale_interleave
+
+                # Quantize before comm, swizzle after.
+                if x.shape[0] > 0:
+                    x, x_sf = fp4_quantize(
+                        x, layer.w13_input_scale_quant, is_sf_swizzled_layout=False
+                    )
+                else:
+                    x_col = x.shape[1]
+                    x = torch.zeros(0, x_col // 2, dtype=torch.uint8, device=x.device)
+                    x_sf = torch.zeros(
+                        0, x_col // 16, dtype=torch.uint8, device=x.device
+                    )
+                topk_weights, topk_ids, x, x_sf = get_tp_group().all_gatherv(
+                    [topk_weights, topk_ids, x, x_sf], sizes=get_dp_global_num_tokens()
+                )
+                x_sf = nvfp4_block_scale_interleave(x_sf)
+
             output = flashinfer_cutlass_fused_moe(
-                x,
-                topk_ids.to(torch.int),
-                topk_weights,
-                layer.w13_weight.view(torch.long),
-                layer.w2_weight.view(torch.long),
-                x.dtype,
+                input=x,
+                token_selected_experts=topk_ids.to(torch.int),
+                token_final_scales=topk_weights,
+                fc1_expert_weights=layer.w13_weight.view(torch.long),
+                fc2_expert_weights=layer.w2_weight.view(torch.long),
+                output_dtype=output_dtype,
+                input_sf=x_sf,
                 quant_scales=[
                     layer.w13_input_scale_quant,
                     layer.w13_blockscale_swizzled.view(torch.int32),
@@ -1205,14 +1299,18 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                     layer.w2_blockscale_swizzled.view(torch.int32),
                     layer.g2_alphas,
                 ],
-                ep_size=ep_size,
-                ep_rank=ep_rank,
-                tp_size=tp_size,
-                tp_rank=tp_rank,
+                ep_size=layer.moe_ep_size,
+                ep_rank=layer.moe_ep_rank,
+                tp_size=layer.moe_tp_size,
+                tp_rank=layer.moe_tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
             )[0]
-            if routed_scaling_factor is not None:
-                output *= routed_scaling_factor
+            # Scale by routed_scaling_factor is fused into select_experts.
+            if should_use_flashinfer_cutlass_moe_fp4_allgather():
+                output, global_output = get_local_dp_buffer(), output
+                get_tp_group().reduce_scatterv(
+                    global_output, output=output, sizes=get_dp_global_num_tokens()
+                )
             return output
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
@@ -1231,8 +1329,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             params=layer.cutlass_moe_params,
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
-        if routed_scaling_factor is not None:
-            output *= routed_scaling_factor
+        # Scale by routed_scaling_factor is fused into select_experts.
         return output
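The hunks above add a data-parallel all-gather path around the FlashInfer Cutlass FP4 MoE kernel. Condensed into pseudocode for readability (reusing only the helper names that appear in this diff; signatures are not verified beyond what is shown here):

    # Pseudocode of the new DP path, per rank:
    if should_use_flashinfer_cutlass_moe_fp4_allgather():
        # 1. Quantize the local tokens to FP4 first, so the all-gather moves
        #    quantized data instead of the full BF16/FP16 activations.
        x, x_sf = fp4_quantize(x, layer.w13_input_scale_quant, is_sf_swizzled_layout=False)
        # 2. Gather every rank's tokens, scales, and routing results into one global batch.
        topk_weights, topk_ids, x, x_sf = get_tp_group().all_gatherv(
            [topk_weights, topk_ids, x, x_sf], sizes=get_dp_global_num_tokens()
        )
        # 3. Swizzle the gathered block scales once, then run the fused MoE on the global batch.
        x_sf = nvfp4_block_scale_interleave(x_sf)
        global_output = flashinfer_cutlass_fused_moe(...)  # arguments elided
        # 4. Scatter each rank's slice of the result back into its local DP buffer.
        output = get_local_dp_buffer()
        get_tp_group().reduce_scatterv(
            global_output, output=output, sizes=get_dp_global_num_tokens()
        )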
sglang/srt/layers/quantization/moe_wna16.py

@@ -22,6 +22,7 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
 
@@ -353,17 +354,14 @@ class MoeWNA16Method(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         # avoid circular import
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 
-        assert activation == "silu", "Only SiLU activation is supported."
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
 
         weight_bits = self.quant_config.weight_bits
         has_zp = self.quant_config.has_zp
@@ -373,8 +371,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
             layer.w13_qweight,
             layer.w2_qweight,
             topk_output=topk_output,
-            inplace=inplace,
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            moe_runner_config=moe_runner_config,
             use_int4_w4a16=weight_bits == 4,
             use_int8_w8a16=weight_bits == 8,
             w1_scale=layer.w13_scales,
@@ -382,8 +379,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
             w1_zp=layer.w13_qzeros if has_zp else None,
             w2_zp=layer.w2_qzeros if has_zp else None,
             block_shape=[0, layer.group_size],
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )
 
     @staticmethod
@@ -486,16 +481,16 @@ class MoeWNA16Method(FusedMoEMethodBase):
         )
 
         if "w13_qzeros" in weight_name:
-            tensor = loaded_weight.view(layer.tp_size, -1, loaded_weight.size(1))[
-                tp_rank
-            ]
+            tensor = loaded_weight.view(
+                layer.moe_tp_size, -1, loaded_weight.size(1)
+            )[tp_rank]
             if shard_id == "w1":
                 param.data[expert_id, : shard_size // 2] = tensor
             else:
                 param.data[expert_id, shard_size // 2 :] = tensor
         elif "w2_qzeros" in weight_name:
             param.data[expert_id] = loaded_weight.view(
-                loaded_weight.size(0), layer.tp_size, -1
+                loaded_weight.size(0), layer.moe_tp_size, -1
             )[:, tp_rank]
         else:
             weight_loader(param, loaded_weight, weight_name, shard_id, expert_id)
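Across both files, the quant-method `apply` hooks drop their long keyword-argument tails in favor of a single `MoeRunnerConfig`. A hedged sketch of what that migration looks like for a custom `FusedMoEMethodBase` subclass (only `activation` and `apply_router_weight_on_input` are fields confirmed by this diff; everything else is assumed for illustration):

    # Before (0.5.0rc2-style signature):
    #     def apply(self, layer, x, topk_output, *, activation="silu",
    #               apply_router_weight_on_input=False, inplace=True, ...): ...

    # After (0.5.1-style signature): per-call options arrive on one config object.
    def apply(self, layer, x, topk_output, moe_runner_config):
        assert moe_runner_config.activation == "silu", "Only SiLU activation is supported."
        if moe_runner_config.apply_router_weight_on_input:
            raise NotImplementedError("router weight on input not supported here")
        ...  # dispatch to the fused kernel, forwarding moe_runner_config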