sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -2,12 +2,13 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, Any,
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
 
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -29,6 +30,7 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import is_cuda, next_power_of_2
 
 if TYPE_CHECKING:
@@ -39,6 +41,7 @@ if is_cuda():
 
 try:
     from flashinfer import mm_fp4 as fp4_gemm
+    from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_sf_a
 
     enable_flashinfer_fp4_gemm = True
 except ImportError:
@@ -47,6 +50,9 @@ except ImportError:
 else:
     fp4_gemm = None
     enable_flashinfer_fp4_gemm = False
+    reorder_rows_for_gated_act_gemm = None
+    shuffle_matrix_a = None
+    shuffle_matrix_sf_a = None
 
 try:
     from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
@@ -527,6 +533,7 @@ class ModelOptFp4Config(QuantizationConfig):
     ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE
 
         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
@@ -536,6 +543,9 @@ class ModelOptFp4Config(QuantizationConfig):
             return ModelOptFp4LinearMethod(self)
         if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)
+        elif isinstance(layer, FlashInferFP4MoE):
+            # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
+            return ModelOptNvFp4FusedMoEMethod(self)
         elif isinstance(layer, FusedMoE):
             return ModelOptNvFp4FusedMoEMethod(self)
         return None
@@ -667,9 +677,9 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         padded_scales = padded_scales.permute((0, 1, 4, 3, 2, 5))
         padded_scales = padded_scales.contiguous().cuda()
         padded_scales = (
-            padded_scales.reshape(
+            padded_scales.reshape(M_padded, K_padded)
             if scale_ndim == 2
-            else padded_scales.reshape(B,
+            else padded_scales.reshape(B, M_padded, K_padded)
         )
         layer.weight_scale_interleaved = Parameter(padded_scales, requires_grad=False)
 
@@ -726,7 +736,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " quantization. Please use Blackwell and"
                 " above."
             )
-        self.
+        self.enable_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
+
+    @property
+    def enable_flashinfer_cutlass_moe(self) -> bool:
+        """Access the global enable_flashinfer_cutlass_moe setting."""
+        return global_server_args_dict.get("enable_flashinfer_cutlass_moe", False)
 
     def create_weights(
         self,
@@ -743,16 +758,18 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 " dynamic quantization is not supported."
             )
 
-
+        # TODO(ch-wan): check if this is needed
+        layer.intermediate_size_per_partition = intermediate_size_per_partition
         layer.params_dtype = params_dtype
         layer.quant_config = self.quant_config
+
         weight_dtype = torch.uint8
         weight_scale_dtype = torch.float8_e4m3fn
         weight_loader = extra_weight_attrs.get("weight_loader")
         # GEMM 1
         w13_weight = ModelWeightParameter(
             data=torch.empty(
-
+                layer.num_local_experts,
                 2 * intermediate_size_per_partition,
                 # 2 fp4 items are packed in the input dimension
                 hidden_size // 2,
@@ -767,7 +784,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         # GEMM 2
         w2_weight = ModelWeightParameter(
             data=torch.empty(
-
+                layer.num_local_experts,
                 hidden_size,
                 # 2 fp4 items are packed in the input dimension
                 intermediate_size_per_partition // 2,
@@ -781,7 +798,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
         w13_weight_scale = ModelWeightParameter(
             data=torch.empty(
-
+                layer.num_local_experts,
                 2 * intermediate_size_per_partition,
                 # 2 fp4 items are packed in the input dimension
                 hidden_size // self.quant_config.group_size,
@@ -795,7 +812,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
         w2_weight_scale = ModelWeightParameter(
             data=torch.empty(
-
+                layer.num_local_experts,
                 hidden_size,
                 # 2 fp4 items are packed in the input dimension
                 intermediate_size_per_partition // self.quant_config.group_size,
@@ -814,13 +831,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
 
         w13_weight_scale_2 = PerTensorScaleParameter(
-            data=torch.empty(
+            data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
 
         w2_weight_scale_2 = PerTensorScaleParameter(
-            data=torch.empty(
+            data=torch.empty(layer.num_local_experts, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
@@ -830,18 +847,18 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
 
         w13_input_scale = PerTensorScaleParameter(
-            data=torch.empty(
+            data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32),
            weight_loader=weight_loader,
         )
         layer.register_parameter("w13_input_scale", w13_input_scale)
 
         w2_input_scale = PerTensorScaleParameter(
-            data=torch.empty(
+            data=torch.empty(layer.num_local_experts, dtype=torch.float32),
             weight_loader=weight_loader,
         )
         layer.register_parameter("w2_input_scale", w2_input_scale)
 
-    def swizzle_blockscale(self, scale: torch.
+    def swizzle_blockscale(self, scale: torch.Tensor):
         assert scale.dtype == torch.float8_e4m3fn
         # Pad and blockwise interleave weight_scale
         scale_ndim = scale.ndim
@@ -861,14 +878,130 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
         swizzled_scale = swizzled_scale.contiguous().cuda()
         return (
-            swizzled_scale.reshape(
+            swizzled_scale.reshape(M_padded, K_padded)
             if scale_ndim == 2
-            else swizzled_scale.reshape(B,
+            else swizzled_scale.reshape(B, M_padded, K_padded)
+        )
+
+    def prepare_static_weights_for_kernel(
+        self,
+        # args_dequant,
+        # args,
+        gemm1_weights,
+        gemm2_weights,
+        gemm1_scales_linear_fp4_bytes,
+        gemm2_scales_linear_fp4_bytes,
+        hidden_size,
+        intermediate_size,
+        num_experts,
+    ):
+        from flashinfer import (
+            RoutingMethodType,
+            e2m1_and_ufp8sf_scale_to_float,
+            fp4_quantize,
+            next_positive_power_of_2,
+            reorder_rows_for_gated_act_gemm,
+            shuffle_matrix_a,
+            shuffle_matrix_sf_a,
+        )
+
+        """Prepare quantized weights for kernel (done offline with weights)."""
+        epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
+
+        # Convert quantized weights to proper formats
+        gemm1_weights_fp4 = gemm1_weights.view(torch.float8_e4m3fn).reshape(
+            num_experts, 2 * intermediate_size, hidden_size // 2
+        )  # packed fp4
+        gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view(
+            torch.float8_e4m3fn
+        ).reshape(
+            num_experts, 2 * intermediate_size, hidden_size // 16
+        )  # fp8 scaling factors
+
+        gemm2_weights_fp4 = gemm2_weights.view(torch.float8_e4m3fn).reshape(
+            num_experts, hidden_size, intermediate_size // 2
+        )  # packed fp4
+        gemm2_scales_linear_fp4 = gemm2_scales_linear_fp4_bytes.view(
+            torch.float8_e4m3fn
+        ).reshape(
+            num_experts, hidden_size, intermediate_size // 16
+        )  # fp8 scaling factors
+
+        # Reorder rows of W1 and scales for fused gated activation
+        gemm1_weights_fp4_interleaved = []
+        gemm1_scales_fp4_interleaved = []
+        for i in range(num_experts):
+            gemm1_weights_fp4_interleaved.append(
+                reorder_rows_for_gated_act_gemm(gemm1_weights_fp4[i].clone())
+            )
+            gemm1_scales_fp4_interleaved.append(
+                reorder_rows_for_gated_act_gemm(gemm1_scales_linear_fp4[i].clone())
+            )
+
+        # Stack weights and scales for all experts
+        gemm1_weights_fp4_interleaved = torch.stack(
+            gemm1_weights_fp4_interleaved
+        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
+        gemm1_scales_fp4_interleaved = torch.stack(
+            gemm1_scales_fp4_interleaved
+        ).reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
+
+        # Shuffle weights and scaling factors for transposed mma output
+        gemm1_weights_fp4_shuffled = []
+        gemm1_scales_fp4_shuffled = []
+        gemm2_weights_fp4_shuffled = []
+        gemm2_scales_fp4_shuffled = []
+        for i in range(num_experts):
+            gemm1_weights_fp4_shuffled.append(
+                shuffle_matrix_a(
+                    gemm1_weights_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
+                )
+            )
+            gemm1_scales_fp4_shuffled.append(
+                shuffle_matrix_sf_a(
+                    gemm1_scales_fp4_interleaved[i].view(torch.uint8), epilogue_tile_m
+                )
+            )
+
+            gemm2_weights_fp4_shuffled.append(
+                shuffle_matrix_a(
+                    gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m
+                )
+            )
+            gemm2_scales_fp4_shuffled.append(
+                shuffle_matrix_sf_a(
+                    gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m
+                )
+            )
+
+        # Stack weights for all experts
+        gemm1_weights_fp4_shuffled = torch.stack(gemm1_weights_fp4_shuffled)
+        gemm1_scales_fp4_shuffled = (
+            torch.stack(gemm1_scales_fp4_shuffled)
+            .view(torch.float8_e4m3fn)
+            .reshape(num_experts, 2 * intermediate_size, hidden_size // 16)
+        )
+
+        gemm2_weights_fp4_shuffled = torch.stack(gemm2_weights_fp4_shuffled)
+        gemm2_scales_fp4_shuffled = (
+            torch.stack(gemm2_scales_fp4_shuffled)
+            .view(torch.float8_e4m3fn)
+            .reshape(num_experts, hidden_size, intermediate_size // 16)
+        )
+        return (
+            gemm1_weights_fp4_shuffled,
+            gemm1_scales_fp4_shuffled,
+            gemm2_weights_fp4_shuffled,
+            gemm2_scales_fp4_shuffled,
         )
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """Process FP4 MoE weights after loading from serialized checkpoint.
 
-
+        Only supports pre-quantized checkpoints with FP8 weights and scales.
+        """
+
+        # GEMM 1 scale processing
         if not torch.allclose(
             layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
         ):
@@ -880,73 +1013,123 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
 
-
+        # Calculate input scales based on strategy
+        if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe:
             w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
+            w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
         else:
             w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
+            w2_input_scale = layer.w2_input_scale
+
+        # Create shared parameters
         layer.g1_alphas = Parameter(
             (w13_input_scale * w13_weight_scale_2).to(torch.float32),
             requires_grad=False,
         )
-
-
-
-        ), "Expected weight_scale.dim(1) to be divisible by 16"
-        assert (
-            layer.w13_weight_scale.dtype == torch.float8_e4m3fn
-        ), "Weight Blockscale must be represented as FP8-E4M3"
-        w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
-
-        layer.w13_blockscale_swizzled = Parameter(
-            w13_blockscale_swizzled, requires_grad=False
+        layer.g2_alphas = Parameter(
+            (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
+            requires_grad=False,
         )
-        del layer.w13_weight_scale
-
-        # This is for quantization, so we need to invert it.
         layer.w13_input_scale_quant = Parameter(
             (1 / w13_input_scale).to(torch.float32), requires_grad=False
         )
+        layer.w2_input_scale_quant = Parameter(
+            (1 / w2_input_scale).to(torch.float32), requires_grad=False
+        )
 
-
+        # Validate weight scales
+        for name, weight_scale in [
+            ("w13", layer.w13_weight_scale),
+            ("w2", layer.w2_weight_scale),
+        ]:
+            assert (
+                weight_scale.shape[2] % 16 == 0
+            ), f"Expected {name}_weight_scale.dim(2) to be divisible by 16"
+            assert (
+                weight_scale.dtype == torch.float8_e4m3fn
+            ), f"{name} Weight Blockscale must be represented as FP8-E4M3"
+
+        # Weight processing based on strategy
+        if (
+            self.enable_flashinfer_trtllm_moe
+            and reorder_rows_for_gated_act_gemm is not None
+            and shuffle_matrix_sf_a is not None
+        ):
+            # FlashInfer TRTLLM processing - handles both w13 and w2
+            (
+                gemm1_weights_fp4_shuffled,
+                gemm1_scales_fp4_shuffled,
+                gemm2_weights_fp4_shuffled,
+                gemm2_scales_fp4_shuffled,
+            ) = self.prepare_static_weights_for_kernel(
+                layer.w13_weight,
+                layer.w2_weight,
+                layer.w13_weight_scale,
+                layer.w2_weight_scale,
+                layer.w2_weight.size(-2),  # hidden_size
+                layer.w13_weight.size(-2) // 2,  # intermediate_size
+                layer.w13_weight.size(0),  # num_experts
+            )
 
-
-
-
-
-
+            # Set flashinfer parameters
+            layer.gemm1_weights_fp4_shuffled = Parameter(
+                gemm1_weights_fp4_shuffled, requires_grad=False
+            )
+            layer.gemm2_weights_fp4_shuffled = Parameter(
+                gemm2_weights_fp4_shuffled, requires_grad=False
+            )
+            layer.gemm1_scales_fp4_shuffled = Parameter(
+                gemm1_scales_fp4_shuffled, requires_grad=False
+            )
+            layer.gemm2_scales_fp4_shuffled = Parameter(
+                gemm2_scales_fp4_shuffled, requires_grad=False
+            )
 
-
-
-
-
+            # Additional parameter needed for TRT-LLM
+            layer.g1_scale_c = Parameter(
+                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
+                requires_grad=False,
+            )
 
-
-
-
-
+            # Clean up weights that won't be used by TRT-LLM
+            del (
+                layer.w2_weight,
+                layer.w2_weight_scale,
+                layer.w13_weight,
+                layer.w13_weight_scale,
+            )
 
-
-            layer.w2_weight_scale.shape[2] % 16 == 0
-        ), "Expected weight_scale.dim(1) to be divisible by 16"
-        assert (
-            layer.w2_weight_scale.dtype == torch.float8_e4m3fn
-        ), "Weight Blockscale must be represented as FP8-E4M3"
-        w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
+            logger.info_once("Applied flashinfer weight processing for both w13 and w2")
 
-
-
-
-
-
+        else:
+            # CUTLASS processing - handle w13 and w2 separately
+
+            # Process w13 weights
+            w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
+            layer.w13_blockscale_swizzled = Parameter(
+                w13_blockscale_swizzled, requires_grad=False
+            )
+            layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
+
+            # Process w2 weights
+            w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
+            layer.w2_blockscale_swizzled = Parameter(
+                w2_blockscale_swizzled, requires_grad=False
+            )
+            layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+
+            # Both flashinfer cutlass and regular cutlass use same processing for w2
+            logger.info_once("Applied weight processing for both w13 and w2")
 
-
-
-
-
-
-
-
-
+            # Set up CUTLASS MoE parameters
+            device = layer.w13_weight.device
+            layer.cutlass_moe_params = CutlassMoEParams(
+                CutlassMoEType.BlockscaledFP4,
+                device,
+                num_experts=layer.num_experts,  # global num experts
+                intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,  # n
+                hidden_size=layer.w13_weight.shape[2] * 2,
+            )  # k
 
     @property
     def load_up_proj_weight_first(self) -> bool:
@@ -971,13 +1154,20 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
     ) -> torch.Tensor:
         assert activation == "silu", "Only SiLU activation is supported."
 
+        # Check if this is a FlashInferFP4MoE layer that should handle its own forward
+        if hasattr(layer, "gemm1_weights_fp4_shuffled"):
+            # This layer was processed with flashinfer TRTLLM - delegate to its own forward
+            return layer.forward(x, topk_output)
+
         if self.enable_flashinfer_cutlass_moe:
            assert (
                 not apply_router_weight_on_input
             ), "apply_router_weight_on_input is not supported for Flashinfer"
             # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
             # and fp4 quantized weights loaded from the checkpoint
-
+
+            topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
+
             output = flashinfer_cutlass_fused_moe(
                 x,
                 topk_ids.to(torch.int),
@@ -1005,7 +1195,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 
-        topk_weights, topk_ids
+        topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
        output = cutlass_moe_fp4(
             a=x,
             a1_gscale=layer.w13_input_scale_quant,