sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- sglang/bench_one_batch.py +113 -17
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +11 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +4 -3
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +71 -0
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/vision.py +13 -5
- sglang/srt/layers/communicator.py +21 -4
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +2 -7
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +77 -73
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +3 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +55 -30
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +28 -7
- sglang/srt/managers/scheduler.py +26 -12
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +24 -6
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +7 -6
- sglang/srt/model_executor/forward_batch_info.py +35 -14
- sglang/srt/model_executor/model_runner.py +19 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +72 -33
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +24 -12
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +142 -7
- sglang/srt/two_batch_overlap.py +157 -5
- sglang/srt/utils.py +38 -2
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/modelopt_quant.py

@@ -1,13 +1,15 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
 from __future__ import annotations
 
+import importlib.util
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
 
 import torch
 from torch.nn.parameter import Parameter
 
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -29,6 +31,7 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import is_cuda, next_power_of_2
 
 if TYPE_CHECKING:
@@ -39,6 +42,11 @@ if is_cuda():
 
 try:
     from flashinfer import mm_fp4 as fp4_gemm
+    from flashinfer import (
+        reorder_rows_for_gated_act_gemm,
+        shuffle_matrix_a,
+        shuffle_matrix_sf_a,
+    )
 
     enable_flashinfer_fp4_gemm = True
 except ImportError:
@@ -47,6 +55,9 @@ except ImportError:
 else:
     fp4_gemm = None
     enable_flashinfer_fp4_gemm = False
+    reorder_rows_for_gated_act_gemm = None
+    shuffle_matrix_a = None
+    shuffle_matrix_sf_a = None
 
 try:
     from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
@@ -527,6 +538,7 @@ class ModelOptFp4Config(QuantizationConfig):
     ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE
 
         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
@@ -536,6 +548,9 @@ class ModelOptFp4Config(QuantizationConfig):
             return ModelOptFp4LinearMethod(self)
         if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)
+        elif isinstance(layer, FlashInferFP4MoE):
+            # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
+            return ModelOptNvFp4FusedMoEMethod(self)
         elif isinstance(layer, FusedMoE):
             return ModelOptNvFp4FusedMoEMethod(self)
         return None
@@ -726,7 +741,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " quantization. Please use Blackwell and"
                 " above."
             )
-        self.
+        self.enable_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
+
+    @property
+    def enable_flashinfer_cutlass_moe(self) -> bool:
+        """Access the global enable_flashinfer_cutlass_moe setting."""
+        return global_server_args_dict.get("enable_flashinfer_cutlass_moe", False)
 
     def create_weights(
         self,
@@ -743,16 +763,18 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " dynamic quantization is not supported."
             )
 
-
+        # TODO(ch-wan): check if this is needed
+        layer.intermediate_size_per_partition = intermediate_size_per_partition
         layer.params_dtype = params_dtype
         layer.quant_config = self.quant_config
+
         weight_dtype = torch.uint8
         weight_scale_dtype = torch.float8_e4m3fn
         weight_loader = extra_weight_attrs.get("weight_loader")
         # GEMM 1
         w13_weight = ModelWeightParameter(
             data=torch.empty(
-
+                layer.num_local_experts,
                 2 * intermediate_size_per_partition,
                 # 2 fp4 items are packed in the input dimension
                 hidden_size // 2,
@@ -767,7 +789,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         # GEMM 2
         w2_weight = ModelWeightParameter(
             data=torch.empty(
-
+                layer.num_local_experts,
                 hidden_size,
                 # 2 fp4 items are packed in the input dimension
                 intermediate_size_per_partition // 2,
@@ -781,7 +803,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
         w13_weight_scale = ModelWeightParameter(
             data=torch.empty(
-
+                layer.num_local_experts,
                 2 * intermediate_size_per_partition,
                 # 2 fp4 items are packed in the input dimension
                 hidden_size // self.quant_config.group_size,
@@ -795,7 +817,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
         w2_weight_scale = ModelWeightParameter(
             data=torch.empty(
-
+                layer.num_local_experts,
                 hidden_size,
                 # 2 fp4 items are packed in the input dimension
                 intermediate_size_per_partition // self.quant_config.group_size,
@@ -814,13 +836,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
 
         w13_weight_scale_2 = PerTensorScaleParameter(
-            data=torch.empty(
+            data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
 
         w2_weight_scale_2 = PerTensorScaleParameter(
-            data=torch.empty(
+            data=torch.empty(layer.num_local_experts, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
@@ -830,18 +852,18 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
 
         w13_input_scale = PerTensorScaleParameter(
-            data=torch.empty(
+            data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w13_input_scale", w13_input_scale)
 
         w2_input_scale = PerTensorScaleParameter(
-            data=torch.empty(
+            data=torch.empty(layer.num_local_experts, dtype=torch.float32),
             weight_loader=weight_loader,
        )
         layer.register_parameter("w2_input_scale", w2_input_scale)
 
-    def swizzle_blockscale(self, scale: torch.
+    def swizzle_blockscale(self, scale: torch.Tensor):
         assert scale.dtype == torch.float8_e4m3fn
         # Pad and blockwise interleave weight_scale
         scale_ndim = scale.ndim
@@ -866,9 +888,125 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             else swizzled_scale.reshape(B, M, K)
         )
 
+    def prepare_static_weights_for_kernel(
+        self,
+        # args_dequant,
+        # args,
+        gemm1_weights,
+        gemm2_weights,
+        gemm1_scales_linear_fp4_bytes,
+        gemm2_scales_linear_fp4_bytes,
+        hidden_size,
+        intermediate_size,
+        num_experts,
+    ):
+        from flashinfer import (
+            RoutingMethodType,
+            e2m1_and_ufp8sf_scale_to_float,
+            fp4_quantize,
+            next_positive_power_of_2,
+            reorder_rows_for_gated_act_gemm,
+            shuffle_matrix_a,
+            shuffle_matrix_sf_a,
+        )
+
+        """Prepare quantized weights for kernel (done offline with weights)."""
+        epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
+
+        # Convert quantized weights to proper formats
+        gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
+            num_experts, 2 * intermediate_size, hidden_size // 2
+        )  # packed fp4
+        gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
+            torch.float8_e4m3fn
+        ).reshape(
+            num_experts, 2 * intermediate_size, hidden_size // 16
+        )  # fp8 scaling factors
+
+        gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
+            num_experts, hidden_size, intermediate_size // 2
+        )  # packed fp4
+        gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
+            torch.float8_e4m3fn
+        ).reshape(
+            num_experts, hidden_size, intermediate_size // 16
+        )  # fp8 scaling factors
+
+        # Reorder rows of W1 and scales for fused gated activation
+        gemm1_weights_fp4_interleaved = []
+        gemm1_scales_fp4_interleaved = []
+        for i in range(num_experts):
+            gemm1_weights_fp4_interleaved.append(
+                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())
+            )
+            gemm1_scales_fp4_interleaved.append(
+                reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone())
+            )
+
+        # Stack weights and scales for all experts
+        gemm1_weights_fp4_interleaved = torch.stack(
+            gemm1_weights_fp4_interleaved
+        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
+        gemm1_scales_fp4_interleaved = torch.stack(
+            gemm1_scales_fp4_interleaved
+        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
+
+        # Shuffle weights and scaling factors for transposed mma output
+        gemm1_weights_fp4_shuffled = []
+        gemm1_scales_fp4_shuffled = []
+        gemm2_weights_fp4_shuffled = []
+        gemm2_scales_fp4_shuffled = []
+        for i in range(num_experts):
+            gemm1_weights_fp4_shuffled.append(
+                shuffle_matrix_a(
+                    gemm1_weights_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
+                )
+            )
+            gemm1_scales_fp4_shuffled.append(
+                shuffle_matrix_sf_a(
+                    gemm1_scales_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
+                )
+            )
+
+            gemm2_weights_fp4_shuffled.append(
+                shuffle_matrix_a(
+                    gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m
+                )
+            )
+            gemm2_scales_fp4_shuffled.append(
+                shuffle_matrix_sf_a(
+                    gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m
+                )
+            )
+
+        # Stack weights for all experts
+        gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
+        gemm1_scales_fp4_shuffled = (
+            torch.stack(gemm1_scales_fp4_shuffled)
+            .view(torch.float8_e4m3fn)
+            .reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
+        )
+
+        gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
+        gemm2_scales_fp4_shuffled = (
+            torch.stack(gemm2_scales_fp4_shuffled)
+            .view(torch.float8_e4m3fn)
+            .reshape(num_experts, hidden_size, intermediate_size // 16)
+        )
+        return (
+            gemm1_weights_fp4_shuffled,
+            gemm1_scales_fp4_shuffled,
+            gemm2_weights_fp4_shuffled,
+            gemm2_scales_fp4_shuffled,
+        )
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """Process FP4 MoE weights after loading from serialized checkpoint.
 
-
+        Only supports pre-quantized checkpoints with FP8 weights and scales.
+        """
+
+        # GEMM 1 scale processing
         if not torch.allclose(
             layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
         ):
@@ -880,73 +1018,123 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
 
-
+        # Calculate input scales based on strategy
+        if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe:
             w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
+            w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
         else:
             w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
+            w2_input_scale = layer.w2_input_scale
+
+        # Create shared parameters
         layer.g1_alphas = Parameter(
             (w13_input_scale * w13_weight_scale_2).to(torch.float32),
             requires_grad=False,
         )
-
-
-
-        ), "Expected weight_scale.dim(1) to be divisible by 16"
-        assert (
-            layer.w13_weight_scale.dtype == torch.float8_e4m3fn
-        ), "Weight Blockscale must be represented as FP8-E4M3"
-        w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
-
-        layer.w13_blockscale_swizzled = Parameter(
-            w13_blockscale_swizzled, requires_grad=False
+        layer.g2_alphas = Parameter(
+            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
+            requires_grad=False,
         )
-        del layer.w13_weight_scale
-
-        # This is for quantization, so we need to invert it.
         layer.w13_input_scale_quant = Parameter(
             (1 / w13_input_scale).to(torch.float32), requires_grad=False
        )
+        layer.w2_input_scale_quant = Parameter(
+            (1 / w2_input_scale).to(torch.float32), requires_grad=False
+        )
 
-
+        # Validate weight scales
+        for name, weight_scale in [
+            ("w13", layer.w13_weight_scale),
+            ("w2", layer.w2_weight_scale),
+        ]:
+            assert (
+                weight_scale.shape[2] % 16 == 0
+            ), f"Expected {name}_weight_scale.dim(2) to be divisible by 16"
+            assert (
+                weight_scale.dtype == torch.float8_e4m3fn
+            ), f"{name} Weight Blockscale must be represented as FP8-E4M3"
+
+        # Weight processing based on strategy
+        if (
+            self.enable_flashinfer_trtllm_moe
+            and reorder_rows_for_gated_act_gemm is not None
+            and shuffle_matrix_sf_a is not None
+        ):
+            # FlashInfer TRTLLM processing - handles both w13 and w2
+            (
+                gemm1_weights_fp4_shuffled,
+                gemm1_scales_fp4_shuffled,
+                gemm2_weights_fp4_shuffled,
+                gemm2_scales_fp4_shuffled,
+            ) = self.prepare_static_weights_for_kernel(
+                layer.w13_weight,
+                layer.w2_weight,
+                layer.w13_weight_scale,
+                layer.w2_weight_scale,
+                layer.w2_weight.size(-2),  # hidden_size
+                layer.w13_weight.size(-2) // 2,  # intermediate_size
+                layer.w13_weight.size(0),  # num_experts
+            )
 
-
-
-
-
-
+            # Set flashinfer parameters
+            layer.gemm1_weights_fp4_shuffled = Parameter(
+                gemm1_weights_fp4_shuffled, requires_grad=False
+            )
+            layer.gemm2_weights_fp4_shuffled = Parameter(
+                gemm2_weights_fp4_shuffled, requires_grad=False
+            )
+            layer.gemm1_scales_fp4_shuffled = Parameter(
+                gemm1_scales_fp4_shuffled, requires_grad=False
+            )
+            layer.gemm2_scales_fp4_shuffled = Parameter(
+                gemm2_scales_fp4_shuffled, requires_grad=False
+            )
 
-
-
-
-
+            # Additional parameter needed for TRT-LLM
+            layer.g1_scale_c = Parameter(
+                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
+                requires_grad=False,
+            )
 
-
-
-
-
+            # Clean up weights that won't be used by TRT-LLM
+            del (
+                layer.w2_weight,
+                layer.w2_weight_scale,
+                layer.w13_weight,
+                layer.w13_weight_scale,
+            )
 
-
-            layer.w2_weight_scale.shape[2] % 16 == 0
-        ), "Expected weight_scale.dim(1) to be divisible by 16"
-        assert (
-            layer.w2_weight_scale.dtype == torch.float8_e4m3fn
-        ), "Weight Blockscale must be represented as FP8-E4M3"
-        w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
+            logger.info_once("Applied flashinfer weight processing for both w13 and w2")
 
-
-
-
-
-
+        else:
+            # CUTLASS processing - handle w13 and w2 separately
+
+            # Process w13 weights
+            w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
+            layer.w13_blockscale_swizzled = Parameter(
+                w13_blockscale_swizzled, requires_grad=False
+            )
+            layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
+
+            # Process w2 weights
+            w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
+            layer.w2_blockscale_swizzled = Parameter(
+                w2_blockscale_swizzled, requires_grad=False
+            )
+            layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+
+            # Both flashinfer cutlass and regular cutlass use same processing for w2
+            logger.info_once("Applied weight processing for both w13 and w2")
 
-
-
-
-
-
-
-
-
+        # Set up CUTLASS MoE parameters
+        device = layer.w13_weight.device
+        layer.cutlass_moe_params = CutlassMoEParams(
+            CutlassMoEType.BlockscaledFP4,
+            device,
+            num_experts=layer.num_experts,  # global num experts
+            intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,  # n
+            hidden_size=layer.w13_weight.shape[2] * 2,
+        )  # k
 
     @property
     def load_up_proj_weight_first(self) -> bool:
@@ -971,13 +1159,20 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     ) -> torch.Tensor:
         assert activation == "silu", "Only SiLU activation is supported."
 
+        # Check if this is a FlashInferFP4MoE layer that should handle its own forward
+        if hasattr(layer, "gemm1_weights_fp4_shuffled"):
+            # This layer was processed with flashinfer TRTLLM - delegate to its own forward
+            return layer.forward(x, topk_output)
+
         if self.enable_flashinfer_cutlass_moe:
             assert (
                 not apply_router_weight_on_input
            ), "apply_router_weight_on_input is not supported for Flashinfer"
             # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
             # and fp4 quantized weights loaded from the checkpoint
-
+
+            topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
+
             output = flashinfer_cutlass_fused_moe(
                 x,
                 topk_ids.to(torch.int),
@@ -1005,7 +1200,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 
-        topk_weights, topk_ids
+        topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
         output = cutlass_moe_fp4(
             a=x,
             a1_gscale=layer.w13_input_scale_quant,
|