sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/srt/configs/model_config.py +35 -0
  3. sglang/srt/conversation.py +9 -5
  4. sglang/srt/disaggregation/base/conn.py +5 -2
  5. sglang/srt/disaggregation/decode.py +6 -1
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  8. sglang/srt/disaggregation/prefill.py +2 -0
  9. sglang/srt/distributed/parallel_state.py +11 -9
  10. sglang/srt/entrypoints/context.py +244 -0
  11. sglang/srt/entrypoints/engine.py +4 -3
  12. sglang/srt/entrypoints/harmony_utils.py +370 -0
  13. sglang/srt/entrypoints/http_server.py +71 -0
  14. sglang/srt/entrypoints/openai/protocol.py +227 -1
  15. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  16. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  18. sglang/srt/entrypoints/tool.py +87 -0
  19. sglang/srt/eplb/expert_location.py +5 -1
  20. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  21. sglang/srt/hf_transformers_utils.py +30 -3
  22. sglang/srt/jinja_template_utils.py +8 -1
  23. sglang/srt/layers/attention/aiter_backend.py +5 -8
  24. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  25. sglang/srt/layers/attention/triton_backend.py +85 -14
  26. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  28. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  29. sglang/srt/layers/attention/vision.py +13 -5
  30. sglang/srt/layers/communicator.py +21 -4
  31. sglang/srt/layers/dp_attention.py +12 -0
  32. sglang/srt/layers/linear.py +2 -7
  33. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  34. sglang/srt/layers/moe/ep_moe/layer.py +77 -73
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
  37. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  38. sglang/srt/layers/moe/topk.py +12 -3
  39. sglang/srt/layers/moe/utils.py +16 -0
  40. sglang/srt/layers/quantization/__init__.py +22 -0
  41. sglang/srt/layers/quantization/fp4.py +557 -0
  42. sglang/srt/layers/quantization/fp8.py +3 -6
  43. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  44. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  45. sglang/srt/layers/quantization/mxfp4.py +651 -0
  46. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  47. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  48. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  49. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  50. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  51. sglang/srt/layers/quantization/quark/utils.py +107 -0
  52. sglang/srt/layers/quantization/unquant.py +60 -6
  53. sglang/srt/layers/quantization/w4afp8.py +1 -1
  54. sglang/srt/layers/rotary_embedding.py +225 -1
  55. sglang/srt/layers/utils.py +9 -0
  56. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  57. sglang/srt/lora/lora_manager.py +70 -14
  58. sglang/srt/lora/lora_registry.py +3 -2
  59. sglang/srt/lora/mem_pool.py +43 -5
  60. sglang/srt/managers/cache_controller.py +55 -30
  61. sglang/srt/managers/detokenizer_manager.py +1 -1
  62. sglang/srt/managers/io_struct.py +15 -3
  63. sglang/srt/managers/mm_utils.py +5 -11
  64. sglang/srt/managers/schedule_batch.py +28 -7
  65. sglang/srt/managers/scheduler.py +26 -12
  66. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  67. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  68. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  69. sglang/srt/managers/template_manager.py +35 -1
  70. sglang/srt/managers/tokenizer_manager.py +24 -6
  71. sglang/srt/managers/tp_worker.py +3 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  73. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  74. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  75. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  76. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  77. sglang/srt/model_executor/cuda_graph_runner.py +7 -6
  78. sglang/srt/model_executor/forward_batch_info.py +35 -14
  79. sglang/srt/model_executor/model_runner.py +19 -2
  80. sglang/srt/model_loader/weight_utils.py +10 -0
  81. sglang/srt/models/bailing_moe.py +425 -0
  82. sglang/srt/models/deepseek_v2.py +72 -33
  83. sglang/srt/models/ernie4.py +426 -0
  84. sglang/srt/models/ernie4_eagle.py +203 -0
  85. sglang/srt/models/gemma3n_mm.py +39 -0
  86. sglang/srt/models/glm4_moe.py +24 -12
  87. sglang/srt/models/gpt_oss.py +1134 -0
  88. sglang/srt/models/qwen2.py +6 -0
  89. sglang/srt/models/qwen2_moe.py +6 -0
  90. sglang/srt/models/qwen3_moe.py +32 -6
  91. sglang/srt/models/step3_vl.py +9 -0
  92. sglang/srt/models/transformers.py +2 -5
  93. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  94. sglang/srt/reasoning_parser.py +18 -39
  95. sglang/srt/server_args.py +142 -7
  96. sglang/srt/two_batch_overlap.py +157 -5
  97. sglang/srt/utils.py +38 -2
  98. sglang/test/runners.py +2 -2
  99. sglang/test/test_utils.py +1 -1
  100. sglang/version.py +1 -1
  101. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
  102. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
  103. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  104. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  105. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
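Only one file is expanded in the diff below. Based on the "Adapted from .../modelopt.py" header and the ModelOptFp4Config / ModelOptNvFp4FusedMoEMethod classes, the hunks that follow appear to come from sglang/srt/layers/quantization/modelopt_quant.py (entry 44 above, +259 -64).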
@@ -1,13 +1,15 @@
  # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
  from __future__ import annotations

+ import importlib.util
  import logging
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

  import torch
  from torch.nn.parameter import Parameter

  from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+ from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
  from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
  from sglang.srt.layers.quantization.base_config import (
  FusedMoEMethodBase,
@@ -29,6 +31,7 @@ from sglang.srt.layers.quantization.utils import (
  requantize_with_max_scale,
  )
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.utils import is_cuda, next_power_of_2

  if TYPE_CHECKING:
@@ -39,6 +42,11 @@ if is_cuda():

  try:
  from flashinfer import mm_fp4 as fp4_gemm
+ from flashinfer import (
+ reorder_rows_for_gated_act_gemm,
+ shuffle_matrix_a,
+ shuffle_matrix_sf_a,
+ )

  enable_flashinfer_fp4_gemm = True
  except ImportError:
@@ -47,6 +55,9 @@ except ImportError:
  else:
  fp4_gemm = None
  enable_flashinfer_fp4_gemm = False
+ reorder_rows_for_gated_act_gemm = None
+ shuffle_matrix_a = None
+ shuffle_matrix_sf_a = None

  try:
  from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
@@ -527,6 +538,7 @@ class ModelOptFp4Config(QuantizationConfig):
  ) -> Optional[QuantizeMethodBase]:
  from sglang.srt.layers.linear import LinearBase
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+ from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE

  if isinstance(layer, LinearBase):
  if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
@@ -536,6 +548,9 @@ class ModelOptFp4Config(QuantizationConfig):
  return ModelOptFp4LinearMethod(self)
  if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
  return ModelOptFp8KVCacheMethod(self)
+ elif isinstance(layer, FlashInferFP4MoE):
+ # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
+ return ModelOptNvFp4FusedMoEMethod(self)
  elif isinstance(layer, FusedMoE):
  return ModelOptNvFp4FusedMoEMethod(self)
  return None
@@ -726,7 +741,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
  " quantization. Please use Blackwell and"
  " above."
  )
- self.enable_flashinfer_cutlass_moe = False
+ self.enable_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
+
+ @property
+ def enable_flashinfer_cutlass_moe(self) -> bool:
+ """Access the global enable_flashinfer_cutlass_moe setting."""
+ return global_server_args_dict.get("enable_flashinfer_cutlass_moe", False)

  def create_weights(
  self,
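The hunk above replaces the hard-coded self.enable_flashinfer_cutlass_moe = False with a TRT-LLM flag set once in __init__ and a property that reads the CUTLASS flag live from global_server_args_dict. A minimal sketch (the helper name select_moe_kernel is illustrative, not part of the diff) of how the two flags map onto the kernel paths visible in the later hunks:

# Illustrative sketch only; the real branching lives in
# process_weights_after_loading / apply further down in this diff.
def select_moe_kernel(method) -> str:
    if method.enable_flashinfer_trtllm_moe:
        # set once in __init__ via should_use_flashinfer_trtllm_moe();
        # weights get pre-shuffled and apply() delegates to layer.forward()
        return "flashinfer-trtllm"
    if method.enable_flashinfer_cutlass_moe:
        # property: looked up in global_server_args_dict at call time
        return "flashinfer-cutlass"  # flashinfer_cutlass_fused_moe
    return "cutlass"  # swizzled blockscales + cutlass_moe_fp4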
@@ -743,16 +763,18 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
  " dynamic quantization is not supported."
  )

- layer.num_experts = num_experts
+ # TODO(ch-wan): check if this is needed
+ layer.intermediate_size_per_partition = intermediate_size_per_partition
  layer.params_dtype = params_dtype
  layer.quant_config = self.quant_config
+
  weight_dtype = torch.uint8
  weight_scale_dtype = torch.float8_e4m3fn
  weight_loader = extra_weight_attrs.get("weight_loader")
  # GEMM 1
  w13_weight = ModelWeightParameter(
  data=torch.empty(
- num_experts,
+ layer.num_local_experts,
  2 * intermediate_size_per_partition,
  # 2 fp4 items are packed in the input dimension
  hidden_size // 2,
@@ -767,7 +789,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
  # GEMM 2
  w2_weight = ModelWeightParameter(
  data=torch.empty(
- num_experts,
+ layer.num_local_experts,
  hidden_size,
  # 2 fp4 items are packed in the input dimension
  intermediate_size_per_partition // 2,
@@ -781,7 +803,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):

  w13_weight_scale = ModelWeightParameter(
  data=torch.empty(
- num_experts,
+ layer.num_local_experts,
  2 * intermediate_size_per_partition,
  # 2 fp4 items are packed in the input dimension
  hidden_size // self.quant_config.group_size,
@@ -795,7 +817,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):

  w2_weight_scale = ModelWeightParameter(
  data=torch.empty(
- num_experts,
+ layer.num_local_experts,
  hidden_size,
  # 2 fp4 items are packed in the input dimension
  intermediate_size_per_partition // self.quant_config.group_size,
@@ -814,13 +836,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
  )

  w13_weight_scale_2 = PerTensorScaleParameter(
- data=torch.empty(num_experts, 2, dtype=torch.float32),
+ data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
  weight_loader=weight_loader,
  )
  layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)

  w2_weight_scale_2 = PerTensorScaleParameter(
- data=torch.empty(num_experts, dtype=torch.float32),
+ data=torch.empty(layer.num_local_experts, dtype=torch.float32),
  weight_loader=weight_loader,
  )
  layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
@@ -830,18 +852,18 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
  )

  w13_input_scale = PerTensorScaleParameter(
- data=torch.empty(num_experts, 2, dtype=torch.float32),
+ data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
  weight_loader=weight_loader,
  )
  layer.register_parameter("w13_input_scale", w13_input_scale)

  w2_input_scale = PerTensorScaleParameter(
- data=torch.empty(num_experts, dtype=torch.float32),
+ data=torch.empty(layer.num_local_experts, dtype=torch.float32),
  weight_loader=weight_loader,
  )
  layer.register_parameter("w2_input_scale", w2_input_scale)

- def swizzle_blockscale(self, scale: torch.tensor):
+ def swizzle_blockscale(self, scale: torch.Tensor):
  assert scale.dtype == torch.float8_e4m3fn
  # Pad and blockwise interleave weight_scale
  scale_ndim = scale.ndim
@@ -866,9 +888,125 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
  else swizzled_scale.reshape(B, M, K)
  )

+ def prepare_static_weights_for_kernel(
+ self,
+ # args_dequant,
+ # args,
+ gemm1_weights,
+ gemm2_weights,
+ gemm1_scales_linear_fp4_bytes,
+ gemm2_scales_linear_fp4_bytes,
+ hidden_size,
+ intermediate_size,
+ num_experts,
+ ):
+ from flashinfer import (
+ RoutingMethodType,
+ e2m1_and_ufp8sf_scale_to_float,
+ fp4_quantize,
+ next_positive_power_of_2,
+ reorder_rows_for_gated_act_gemm,
+ shuffle_matrix_a,
+ shuffle_matrix_sf_a,
+ )
+
+ """Prepare quantized weights for kernel (done offline with weights)."""
+ epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
+
+ # Convert quantized weights to proper formats
+ gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
+ num_experts, 2 * intermediate_size, hidden_size // 2
+ ) # packed fp4
+ gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
+ torch.float8_e4m3fn
+ ).reshape(
+ num_experts, 2 * intermediate_size, hidden_size // 16
+ ) # fp8 scaling factors
+
+ gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
+ num_experts, hidden_size, intermediate_size // 2
+ ) # packed fp4
+ gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
+ torch.float8_e4m3fn
+ ).reshape(
+ num_experts, hidden_size, intermediate_size // 16
+ ) # fp8 scaling factors
+
+ # Reorder rows of W1 and scales for fused gated activation
+ gemm1_weights_fp4_interleaved = []
+ gemm1_scales_fp4_interleaved = []
+ for i in range(num_experts):
+ gemm1_weights_fp4_interleaved.append(
+ reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())
+ )
+ gemm1_scales_fp4_interleaved.append(
+ reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone())
+ )
+
+ # Stack weights and scales for all experts
+ gemm1_weights_fp4_interleaved = torch.stack(
+ gemm1_weights_fp4_interleaved
+ ).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
+ gemm1_scales_fp4_interleaved = torch.stack(
+ gemm1_scales_fp4_interleaved
+ ).reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
+
+ # Shuffle weights and scaling factors for transposed mma output
+ gemm1_weights_fp4_shuffled = []
+ gemm1_scales_fp4_shuffled = []
+ gemm2_weights_fp4_shuffled = []
+ gemm2_scales_fp4_shuffled = []
+ for i in range(num_experts):
+ gemm1_weights_fp4_shuffled.append(
+ shuffle_matrix_a(
+ gemm1_weights_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
+ )
+ )
+ gemm1_scales_fp4_shuffled.append(
+ shuffle_matrix_sf_a(
+ gemm1_scales_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
+ )
+ )
+
+ gemm2_weights_fp4_shuffled.append(
+ shuffle_matrix_a(
+ gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m
+ )
+ )
+ gemm2_scales_fp4_shuffled.append(
+ shuffle_matrix_sf_a(
+ gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m
+ )
+ )
+
+ # Stack weights for all experts
+ gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
+ gemm1_scales_fp4_shuffled = (
+ torch.stack(gemm1_scales_fp4_shuffled)
+ .view(torch.float8_e4m3fn)
+ .reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
+ )
+
+ gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
+ gemm2_scales_fp4_shuffled = (
+ torch.stack(gemm2_scales_fp4_shuffled)
+ .view(torch.float8_e4m3fn)
+ .reshape(num_experts, hidden_size, intermediate_size // 16)
+ )
+ return (
+ gemm1_weights_fp4_shuffled,
+ gemm1_scales_fp4_shuffled,
+ gemm2_weights_fp4_shuffled,
+ gemm2_scales_fp4_shuffled,
+ )
+
  def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+ """Process FP4 MoE weights after loading from serialized checkpoint.

- # GEMM 1
+ Only supports pre-quantized checkpoints with FP8 weights and scales.
+ """
+
+ # GEMM 1 scale processing
  if not torch.allclose(
  layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
  ):
@@ -880,73 +1018,123 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
  w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
  layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)

- if self.enable_flashinfer_cutlass_moe:
+ # Calculate input scales based on strategy
+ if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe:
  w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
+ w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
  else:
  w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
+ w2_input_scale = layer.w2_input_scale
+
+ # Create shared parameters
  layer.g1_alphas = Parameter(
  (w13_input_scale * w13_weight_scale_2).to(torch.float32),
  requires_grad=False,
  )
-
- assert (
- layer.w13_weight_scale.shape[2] % 16 == 0
- ), "Expected weight_scale.dim(1) to be divisible by 16"
- assert (
- layer.w13_weight_scale.dtype == torch.float8_e4m3fn
- ), "Weight Blockscale must be represented as FP8-E4M3"
- w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
-
- layer.w13_blockscale_swizzled = Parameter(
- w13_blockscale_swizzled, requires_grad=False
+ layer.g2_alphas = Parameter(
+ (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
+ requires_grad=False,
  )
- del layer.w13_weight_scale
-
- # This is for quantization, so we need to invert it.
  layer.w13_input_scale_quant = Parameter(
  (1 / w13_input_scale).to(torch.float32), requires_grad=False
  )
+ layer.w2_input_scale_quant = Parameter(
+ (1 / w2_input_scale).to(torch.float32), requires_grad=False
+ )

- layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
+ # Validate weight scales
+ for name, weight_scale in [
+ ("w13", layer.w13_weight_scale),
+ ("w2", layer.w2_weight_scale),
+ ]:
+ assert (
+ weight_scale.shape[2] % 16 == 0
+ ), f"Expected {name}_weight_scale.dim(2) to be divisible by 16"
+ assert (
+ weight_scale.dtype == torch.float8_e4m3fn
+ ), f"{name} Weight Blockscale must be represented as FP8-E4M3"
+
+ # Weight processing based on strategy
+ if (
+ self.enable_flashinfer_trtllm_moe
+ and reorder_rows_for_gated_act_gemm is not None
+ and shuffle_matrix_sf_a is not None
+ ):
+ # FlashInfer TRTLLM processing - handles both w13 and w2
+ (
+ gemm1_weights_fp4_shuffled,
+ gemm1_scales_fp4_shuffled,
+ gemm2_weights_fp4_shuffled,
+ gemm2_scales_fp4_shuffled,
+ ) = self.prepare_static_weights_for_kernel(
+ layer.w13_weight,
+ layer.w2_weight,
+ layer.w13_weight_scale,
+ layer.w2_weight_scale,
+ layer.w2_weight.size(-2), # hidden_size
+ layer.w13_weight.size(-2) // 2, # intermediate_size
+ layer.w13_weight.size(0), # num_experts
+ )

- # GEMM 2
- if self.enable_flashinfer_cutlass_moe:
- w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
- else:
- w2_input_scale = layer.w2_input_scale
+ # Set flashinfer parameters
+ layer.gemm1_weights_fp4_shuffled = Parameter(
+ gemm1_weights_fp4_shuffled, requires_grad=False
+ )
+ layer.gemm2_weights_fp4_shuffled = Parameter(
+ gemm2_weights_fp4_shuffled, requires_grad=False
+ )
+ layer.gemm1_scales_fp4_shuffled = Parameter(
+ gemm1_scales_fp4_shuffled, requires_grad=False
+ )
+ layer.gemm2_scales_fp4_shuffled = Parameter(
+ gemm2_scales_fp4_shuffled, requires_grad=False
+ )

- layer.g2_alphas = Parameter(
- (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
- requires_grad=False,
- )
+ # Additional parameter needed for TRT-LLM
+ layer.g1_scale_c = Parameter(
+ (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
+ requires_grad=False,
+ )

- # This is for quantization, so we need to invert it.
- layer.w2_input_scale_quant = Parameter(
- (1 / w2_input_scale).to(torch.float32), requires_grad=False
- )
+ # Clean up weights that won't be used by TRT-LLM
+ del (
+ layer.w2_weight,
+ layer.w2_weight_scale,
+ layer.w13_weight,
+ layer.w13_weight_scale,
+ )

- assert (
- layer.w2_weight_scale.shape[2] % 16 == 0
- ), "Expected weight_scale.dim(1) to be divisible by 16"
- assert (
- layer.w2_weight_scale.dtype == torch.float8_e4m3fn
- ), "Weight Blockscale must be represented as FP8-E4M3"
- w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
+ logger.info_once("Applied flashinfer weight processing for both w13 and w2")

- layer.w2_blockscale_swizzled = Parameter(
- w2_blockscale_swizzled, requires_grad=False
- )
- del layer.w2_weight_scale
- layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+ else:
+ # CUTLASS processing - handle w13 and w2 separately
+
+ # Process w13 weights
+ w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
+ layer.w13_blockscale_swizzled = Parameter(
+ w13_blockscale_swizzled, requires_grad=False
+ )
+ layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
+
+ # Process w2 weights
+ w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
+ layer.w2_blockscale_swizzled = Parameter(
+ w2_blockscale_swizzled, requires_grad=False
+ )
+ layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+
+ # Both flashinfer cutlass and regular cutlass use same processing for w2
+ logger.info_once("Applied weight processing for both w13 and w2")

- device = layer.w13_weight.device
- layer.cutlass_moe_params = CutlassMoEParams(
- CutlassMoEType.BlockscaledFP4,
- device,
- num_experts=layer.num_experts, # global num experts
- intermediate_size_per_partition=layer.w2_weight.shape[2] * 2, # n
- hidden_size=layer.w13_weight.shape[2] * 2,
- ) # k
+ # Set up CUTLASS MoE parameters
+ device = layer.w13_weight.device
+ layer.cutlass_moe_params = CutlassMoEParams(
+ CutlassMoEType.BlockscaledFP4,
+ device,
+ num_experts=layer.num_experts, # global num experts
+ intermediate_size_per_partition=layer.w2_weight.shape[2] * 2, # n
+ hidden_size=layer.w13_weight.shape[2] * 2,
+ ) # k

  @property
  def load_up_proj_weight_first(self) -> bool:
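As a reading aid for the TRT-LLM branch above: the arguments passed to prepare_static_weights_for_kernel are derived from the packed fp4 tensors declared in create_weights (two fp4 values per byte on the last dimension). A short sketch of that shape arithmetic, assuming a layer already populated by create_weights:

# Sketch only; layouts follow the ModelWeightParameter definitions earlier in this diff.
# w13_weight: [num_local_experts, 2 * intermediate_size, hidden_size // 2]  (packed fp4)
# w2_weight:  [num_local_experts, hidden_size, intermediate_size // 2]      (packed fp4)
hidden_size = layer.w2_weight.size(-2)
intermediate_size = layer.w13_weight.size(-2) // 2
num_experts = layer.w13_weight.size(0)  # experts local to this rank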
@@ -971,13 +1159,20 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
  ) -> torch.Tensor:
  assert activation == "silu", "Only SiLU activation is supported."

+ # Check if this is a FlashInferFP4MoE layer that should handle its own forward
+ if hasattr(layer, "gemm1_weights_fp4_shuffled"):
+ # This layer was processed with flashinfer TRTLLM - delegate to its own forward
+ return layer.forward(x, topk_output)
+
  if self.enable_flashinfer_cutlass_moe:
  assert (
  not apply_router_weight_on_input
  ), "apply_router_weight_on_input is not supported for Flashinfer"
  # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
  # and fp4 quantized weights loaded from the checkpoint
- topk_weights, topk_ids, _ = topk_output
+
+ topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
+
  output = flashinfer_cutlass_fused_moe(
  x,
  topk_ids.to(torch.int),
@@ -1005,7 +1200,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):

  from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4

- topk_weights, topk_ids, _ = topk_output
+ topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
  output = cutlass_moe_fp4(
  a=x,
  a1_gscale=layer.w13_input_scale_quant,
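Both apply() hunks replace positional unpacking (topk_weights, topk_ids, _ = topk_output) with attribute access, which implies topk_output is now a structured result. A minimal sketch of a compatible shape; the real type lives in sglang/srt/layers/moe/topk.py and may carry more fields:

from typing import NamedTuple

import torch

class TopKOutputSketch(NamedTuple):
    # field names taken from the diff; illustrative only
    topk_weights: torch.Tensor  # [num_tokens, top_k] routing weights
    topk_ids: torch.Tensor      # [num_tokens, top_k] expert indices

# new-style access used in the hunks above:
# topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids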