sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/modelopt_quant.py
@@ -2,12 +2,13 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
 
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -29,6 +30,7 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import is_cuda, next_power_of_2
 
 if TYPE_CHECKING:
@@ -39,6 +41,7 @@ if is_cuda():
 
 try:
     from flashinfer import mm_fp4 as fp4_gemm
+    from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_sf_a
 
     enable_flashinfer_fp4_gemm = True
 except ImportError:
@@ -47,6 +50,9 @@ except ImportError:
     else:
         fp4_gemm = None
         enable_flashinfer_fp4_gemm = False
+    reorder_rows_for_gated_act_gemm = None
+    shuffle_matrix_a = None
+    shuffle_matrix_sf_a = None
 
 try:
     from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
@@ -527,6 +533,7 @@ class ModelOptFp4Config(QuantizationConfig):
     ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE
 
         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
@@ -536,6 +543,9 @@ class ModelOptFp4Config(QuantizationConfig):
             return ModelOptFp4LinearMethod(self)
         if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)
+        elif isinstance(layer, FlashInferFP4MoE):
+            # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
+            return ModelOptNvFp4FusedMoEMethod(self)
         elif isinstance(layer, FusedMoE):
             return ModelOptNvFp4FusedMoEMethod(self)
         return None
@@ -667,9 +677,9 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         padded_scales = padded_scales.permute((0, 1, 4, 3, 2, 5))
         padded_scales = padded_scales.contiguous().cuda()
         padded_scales = (
-            padded_scales.reshape(M, K)
+            padded_scales.reshape(M_padded, K_padded)
             if scale_ndim == 2
-            else padded_scales.reshape(B, M, K)
+            else padded_scales.reshape(B, M_padded, K_padded)
         )
         layer.weight_scale_interleaved = Parameter(padded_scales, requires_grad=False)
 
@@ -726,7 +736,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " quantization. Please use Blackwell and"
                 " above."
             )
-        self.enable_flashinfer_cutlass_moe = False
+        self.enable_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
+
+    @property
+    def enable_flashinfer_cutlass_moe(self) -> bool:
+        """Access the global enable_flashinfer_cutlass_moe setting."""
+        return global_server_args_dict.get("enable_flashinfer_cutlass_moe", False)
 
     def create_weights(
         self,
@@ -743,16 +758,18 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " dynamic quantization is not supported."
             )
 
-        layer.num_experts = num_experts
+        # TODO(ch-wan): check if this is needed
+        layer.intermediate_size_per_partition = intermediate_size_per_partition
         layer.params_dtype = params_dtype
         layer.quant_config = self.quant_config
+
         weight_dtype = torch.uint8
         weight_scale_dtype = torch.float8_e4m3fn
         weight_loader = extra_weight_attrs.get("weight_loader")
         # GEMM 1
         w13_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts,
+                layer.num_local_experts,
                 2 * intermediate_size_per_partition,
                 # 2 fp4 items are packed in the input dimension
                 hidden_size // 2,
@@ -767,7 +784,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         # GEMM 2
         w2_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts,
+                layer.num_local_experts,
                 hidden_size,
                 # 2 fp4 items are packed in the input dimension
                 intermediate_size_per_partition // 2,
@@ -781,7 +798,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
         w13_weight_scale = ModelWeightParameter(
             data=torch.empty(
-                num_experts,
+                layer.num_local_experts,
                 2 * intermediate_size_per_partition,
                 # 2 fp4 items are packed in the input dimension
                 hidden_size // self.quant_config.group_size,
@@ -795,7 +812,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
         w2_weight_scale = ModelWeightParameter(
             data=torch.empty(
-                num_experts,
+                layer.num_local_experts,
                 hidden_size,
                 # 2 fp4 items are packed in the input dimension
                 intermediate_size_per_partition // self.quant_config.group_size,
@@ -814,13 +831,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
 
         w13_weight_scale_2 = PerTensorScaleParameter(
-            data=torch.empty(num_experts, 2, dtype=torch.float32),
+            data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
 
         w2_weight_scale_2 = PerTensorScaleParameter(
-            data=torch.empty(num_experts, dtype=torch.float32),
+            data=torch.empty(layer.num_local_experts, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
@@ -830,18 +847,18 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
 
         w13_input_scale = PerTensorScaleParameter(
-            data=torch.empty(num_experts, 2, dtype=torch.float32),
+            data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w13_input_scale", w13_input_scale)
 
         w2_input_scale = PerTensorScaleParameter(
-            data=torch.empty(num_experts, dtype=torch.float32),
+            data=torch.empty(layer.num_local_experts, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w2_input_scale", w2_input_scale)
 
-    def swizzle_blockscale(self, scale: torch.tensor):
+    def swizzle_blockscale(self, scale: torch.Tensor):
         assert scale.dtype == torch.float8_e4m3fn
         # Pad and blockwise interleave weight_scale
         scale_ndim = scale.ndim
@@ -861,14 +878,130 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
         swizzled_scale = swizzled_scale.contiguous().cuda()
         return (
-            swizzled_scale.reshape(M, K)
+            swizzled_scale.reshape(M_padded, K_padded)
             if scale_ndim == 2
-            else swizzled_scale.reshape(B, M, K)
+            else swizzled_scale.reshape(B, M_padded, K_padded)
+        )
+
+    def prepare_static_weights_for_kernel(
+        self,
+        # args_dequant,
+        # args,
+        gemm1_weights,
+        gemm2_weights,
+        gemm1_scales_linear_fp4_bytes,
+        gemm2_scales_linear_fp4_bytes,
+        hidden_size,
+        intermediate_size,
+        num_experts,
+    ):
+        from flashinfer import (
+            RoutingMethodType,
+            e2m1_and_ufp8sf_scale_to_float,
+            fp4_quantize,
+            next_positive_power_of_2,
+            reorder_rows_for_gated_act_gemm,
+            shuffle_matrix_a,
+            shuffle_matrix_sf_a,
+        )
+
+        """Prepare quantized weights for kernel (done offline with weights)."""
+        epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
+
+        # Convert quantized weights to proper formats
+        gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
+            num_experts, 2 * intermediate_size, hidden_size // 2
+        )  # packed fp4
+        gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
+            torch.float8_e4m3fn
+        ).reshape(
+            num_experts, 2 * intermediate_size, hidden_size // 16
+        )  # fp8 scaling factors
+
+        gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
+            num_experts, hidden_size, intermediate_size // 2
+        )  # packed fp4
+        gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
+            torch.float8_e4m3fn
+        ).reshape(
+            num_experts, hidden_size, intermediate_size // 16
+        )  # fp8 scaling factors
+
+        # Reorder rows of W1 and scales for fused gated activation
+        gemm1_weights_fp4_interleaved = []
+        gemm1_scales_fp4_interleaved = []
+        for i in range(num_experts):
+            gemm1_weights_fp4_interleaved.append(
+                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())
+            )
+            gemm1_scales_fp4_interleaved.append(
+                reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone())
+            )
+
+        # Stack weights and scales for all experts
+        gemm1_weights_fp4_interleaved = torch.stack(
+            gemm1_weights_fp4_interleaved
+        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
+        gemm1_scales_fp4_interleaved = torch.stack(
+            gemm1_scales_fp4_interleaved
+        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
+
+        # Shuffle weights and scaling factors for transposed mma output
+        gemm1_weights_fp4_shuffled = []
+        gemm1_scales_fp4_shuffled = []
+        gemm2_weights_fp4_shuffled = []
+        gemm2_scales_fp4_shuffled = []
+        for i in range(num_experts):
+            gemm1_weights_fp4_shuffled.append(
+                shuffle_matrix_a(
+                    gemm1_weights_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
+                )
+            )
+            gemm1_scales_fp4_shuffled.append(
+                shuffle_matrix_sf_a(
+                    gemm1_scales_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
+                )
+            )
+
+            gemm2_weights_fp4_shuffled.append(
+                shuffle_matrix_a(
+                    gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m
+                )
+            )
+            gemm2_scales_fp4_shuffled.append(
+                shuffle_matrix_sf_a(
+                    gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m
+                )
+            )
+
+        # Stack weights for all experts
+        gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
+        gemm1_scales_fp4_shuffled = (
+            torch.stack(gemm1_scales_fp4_shuffled)
+            .view(torch.float8_e4m3fn)
+            .reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
+        )
+
+        gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
+        gemm2_scales_fp4_shuffled = (
+            torch.stack(gemm2_scales_fp4_shuffled)
+            .view(torch.float8_e4m3fn)
+            .reshape(num_experts, hidden_size, intermediate_size // 16)
+        )
+        return (
+            gemm1_weights_fp4_shuffled,
+            gemm1_scales_fp4_shuffled,
+            gemm2_weights_fp4_shuffled,
+            gemm2_scales_fp4_shuffled,
         )
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """Process FP4 MoE weights after loading from serialized checkpoint.
 
-        # GEMM 1
+        Only supports pre-quantized checkpoints with FP8 weights and scales.
+        """
+
+        # GEMM 1 scale processing
         if not torch.allclose(
             layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
         ):
@@ -880,73 +1013,123 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
 
-        if self.enable_flashinfer_cutlass_moe:
+        # Calculate input scales based on strategy
+        if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe:
             w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
+            w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
         else:
             w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
+            w2_input_scale = layer.w2_input_scale
+
+        # Create shared parameters
         layer.g1_alphas = Parameter(
             (w13_input_scale * w13_weight_scale_2).to(torch.float32),
             requires_grad=False,
         )
-
-        assert (
-            layer.w13_weight_scale.shape[2] % 16 == 0
-        ), "Expected weight_scale.dim(1) to be divisible by 16"
-        assert (
-            layer.w13_weight_scale.dtype == torch.float8_e4m3fn
-        ), "Weight Blockscale must be represented as FP8-E4M3"
-        w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
-
-        layer.w13_blockscale_swizzled = Parameter(
-            w13_blockscale_swizzled, requires_grad=False
+        layer.g2_alphas = Parameter(
+            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
+            requires_grad=False,
         )
-        del layer.w13_weight_scale
-
-        # This is for quantization, so we need to invert it.
         layer.w13_input_scale_quant = Parameter(
             (1 / w13_input_scale).to(torch.float32), requires_grad=False
         )
+        layer.w2_input_scale_quant = Parameter(
+            (1 / w2_input_scale).to(torch.float32), requires_grad=False
+        )
 
-        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
+        # Validate weight scales
+        for name, weight_scale in [
+            ("w13", layer.w13_weight_scale),
+            ("w2", layer.w2_weight_scale),
+        ]:
+            assert (
+                weight_scale.shape[2] % 16 == 0
+            ), f"Expected {name}_weight_scale.dim(2) to be divisible by 16"
+            assert (
+                weight_scale.dtype == torch.float8_e4m3fn
+            ), f"{name} Weight Blockscale must be represented as FP8-E4M3"
+
+        # Weight processing based on strategy
+        if (
+            self.enable_flashinfer_trtllm_moe
+            and reorder_rows_for_gated_act_gemm is not None
+            and shuffle_matrix_sf_a is not None
+        ):
+            # FlashInfer TRTLLM processing - handles both w13 and w2
+            (
+                gemm1_weights_fp4_shuffled,
+                gemm1_scales_fp4_shuffled,
+                gemm2_weights_fp4_shuffled,
+                gemm2_scales_fp4_shuffled,
+            ) = self.prepare_static_weights_for_kernel(
+                layer.w13_weight,
+                layer.w2_weight,
+                layer.w13_weight_scale,
+                layer.w2_weight_scale,
+                layer.w2_weight.size(-2),  # hidden_size
+                layer.w13_weight.size(-2) // 2,  # intermediate_size
+                layer.w13_weight.size(0),  # num_experts
+            )
 
-        # GEMM 2
-        if self.enable_flashinfer_cutlass_moe:
-            w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
-        else:
-            w2_input_scale = layer.w2_input_scale
+            # Set flashinfer parameters
+            layer.gemm1_weights_fp4_shuffled = Parameter(
+                gemm1_weights_fp4_shuffled, requires_grad=False
+            )
+            layer.gemm2_weights_fp4_shuffled = Parameter(
+                gemm2_weights_fp4_shuffled, requires_grad=False
+            )
+            layer.gemm1_scales_fp4_shuffled = Parameter(
+                gemm1_scales_fp4_shuffled, requires_grad=False
+            )
+            layer.gemm2_scales_fp4_shuffled = Parameter(
+                gemm2_scales_fp4_shuffled, requires_grad=False
+            )
 
-        layer.g2_alphas = Parameter(
-            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
-            requires_grad=False,
-        )
+            # Additional parameter needed for TRT-LLM
+            layer.g1_scale_c = Parameter(
+                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
+                requires_grad=False,
+            )
 
-        # This is for quantization, so we need to invert it.
-        layer.w2_input_scale_quant = Parameter(
-            (1 / w2_input_scale).to(torch.float32), requires_grad=False
-        )
+            # Clean up weights that won't be used by TRT-LLM
+            del (
+                layer.w2_weight,
+                layer.w2_weight_scale,
+                layer.w13_weight,
+                layer.w13_weight_scale,
+            )
 
-        assert (
-            layer.w2_weight_scale.shape[2] % 16 == 0
-        ), "Expected weight_scale.dim(1) to be divisible by 16"
-        assert (
-            layer.w2_weight_scale.dtype == torch.float8_e4m3fn
-        ), "Weight Blockscale must be represented as FP8-E4M3"
-        w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
+            logger.info_once("Applied flashinfer weight processing for both w13 and w2")
 
-        layer.w2_blockscale_swizzled = Parameter(
-            w2_blockscale_swizzled, requires_grad=False
-        )
-        del layer.w2_weight_scale
-        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+        else:
+            # CUTLASS processing - handle w13 and w2 separately
+
+            # Process w13 weights
+            w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
+            layer.w13_blockscale_swizzled = Parameter(
+                w13_blockscale_swizzled, requires_grad=False
+            )
+            layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
+
+            # Process w2 weights
+            w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
+            layer.w2_blockscale_swizzled = Parameter(
+                w2_blockscale_swizzled, requires_grad=False
+            )
+            layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+
+            # Both flashinfer cutlass and regular cutlass use same processing for w2
+            logger.info_once("Applied weight processing for both w13 and w2")
 
-        device = layer.w13_weight.device
-        layer.cutlass_moe_params = CutlassMoEParams(
-            CutlassMoEType.BlockscaledFP4,
-            device,
-            num_experts=layer.num_experts,  # global num experts
-            intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,  # n
-            hidden_size=layer.w13_weight.shape[2] * 2,
-        )  # k
+            # Set up CUTLASS MoE parameters
+            device = layer.w13_weight.device
+            layer.cutlass_moe_params = CutlassMoEParams(
+                CutlassMoEType.BlockscaledFP4,
+                device,
+                num_experts=layer.num_experts,  # global num experts
+                intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,  # n
+                hidden_size=layer.w13_weight.shape[2] * 2,
+            )  # k
 
     @property
     def load_up_proj_weight_first(self) -> bool:
@@ -971,13 +1154,20 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     ) -> torch.Tensor:
         assert activation == "silu", "Only SiLU activation is supported."
 
+        # Check if this is a FlashInferFP4MoE layer that should handle its own forward
+        if hasattr(layer, "gemm1_weights_fp4_shuffled"):
+            # This layer was processed with flashinfer TRTLLM - delegate to its own forward
+            return layer.forward(x, topk_output)
+
         if self.enable_flashinfer_cutlass_moe:
             assert (
                 not apply_router_weight_on_input
             ), "apply_router_weight_on_input is not supported for Flashinfer"
             # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
             # and fp4 quantized weights loaded from the checkpoint
-            topk_weights, topk_ids, _ = topk_output
+
+            topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
+
             output = flashinfer_cutlass_fused_moe(
                 x,
                 topk_ids.to(torch.int),
@@ -1005,7 +1195,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 
-        topk_weights, topk_ids, _ = topk_output
+        topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
         output = cutlass_moe_fp4(
             a=x,
             a1_gscale=layer.w13_input_scale_quant,