sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff shows the changes between publicly available package versions as they appear in their respective public registries, and is provided for informational purposes only.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/gptq.py

@@ -45,7 +45,10 @@ from sglang.srt.layers.quantization.utils import (
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.
+    from sglang.srt.layers.moe.token_dispatcher import (
+        StandardDispatchOutput,
+        CombineInput,
+    )
 
 from sglang.srt.utils import is_cuda
 
@@ -838,19 +841,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         from sglang.srt.layers.linear import set_weight_attrs
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
-
-
-        self.is_k_full = (not self.quant_config.desc_act) or (
-            intermediate_size_per_partition == intermediate_size
-        )
+        self.is_k_full = (not self.quant_config.desc_act) or layer.moe_tp_size == 1
 
         if self.quant_config.group_size != -1:
             scales_size13 = hidden_size // self.quant_config.group_size
-
-
-
-
-            )
+            if self.quant_config.desc_act:
+                w2_scales_size = intermediate_size_per_partition
+            else:
+                w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
             scales_size2 = w2_scales_size // self.quant_config.group_size
             strategy = FusedMoeWeightScaleSupported.GROUP.value
         else:
@@ -1052,17 +1050,26 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         )
         replace_parameter(layer, "w2_scales", marlin_w2_scales)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         # Delay the import to avoid circular dependency
 
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
         # The input must currently be float16
@@ -1071,7 +1078,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
 
         topk_weights, topk_ids, router_logits = topk_output
 
-
+        output = fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
@@ -1087,3 +1094,4 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             num_bits=self.quant_config.weight_bits,
             is_k_full=self.is_k_full,
         ).to(orig_dtype)
+        return StandardCombineInput(hidden_states=output)
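Across the quantization methods touched in this diff, `apply` moves from a loose `(x, topk_output, moe_runner_config)` call to a two-step contract: `create_moe_runner` captures the config once at setup, and `apply` consumes a `StandardDispatchOutput` and returns a `CombineInput`. A minimal sketch of that contract with simplified stand-in classes (the real types live in `sglang.srt.layers.moe.token_dispatcher` and `sglang.srt.layers.moe.moe_runner`):

```python
# Minimal sketch of the new MoE quant-method contract shown in the hunks
# above; the dataclasses here are simplified stand-ins, not the real types.
from dataclasses import dataclass
import torch

@dataclass
class StandardDispatchOutput:
    hidden_states: torch.Tensor  # (num_tokens, hidden_size)
    topk_output: tuple           # (topk_weights, topk_ids, router_logits)

@dataclass
class StandardCombineInput:
    hidden_states: torch.Tensor  # expert outputs, ready for the combine step

class ToyMoEMethod:
    def create_moe_runner(self, layer, moe_runner_config):
        # Called once at setup; apply() no longer receives the config per call.
        self.moe_runner_config = moe_runner_config

    def apply(self, layer, dispatch_output: StandardDispatchOutput) -> StandardCombineInput:
        x = dispatch_output.hidden_states
        # A real method would run the quantized experts on x here using
        # dispatch_output.topk_output; this placeholder passes x through.
        return StandardCombineInput(hidden_states=x)
```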
sglang/srt/layers/quantization/modelopt_quant.py

@@ -10,10 +10,14 @@ from torch.nn.parameter import Parameter
 from sglang.srt.distributed import get_tp_group
 from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
 from sglang.srt.layers.moe import (
+    MoeRunner,
+    MoeRunnerBackend,
+    MoeRunnerConfig,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -35,12 +39,14 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_cuda, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-    from sglang.srt.layers.moe.
-
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 if is_cuda():
     from sgl_kernel import scaled_fp4_quant
@@ -68,6 +74,10 @@ except ImportError:
 # Initialize logger for the module
 logger = logging.getLogger(__name__)
 
+CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
+    "SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true"
+)
+
 # Supported activation schemes for the current configuration
 ACTIVATION_SCHEMES = ["static"]
 
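This new module-level flag gates the CuteDSL input-scale handling added later in the file. A minimal sketch of how such a boolean env flag behaves, assuming `get_bool_env_var` accepts common truthy strings case-insensitively (the real helper lives in `sglang.srt.utils` and may differ in accepted values):

```python
# Hedged sketch of a get_bool_env_var-style helper; illustrative only.
import os

def get_bool_env_var_sketch(name: str, default: str = "false") -> bool:
    # Treat "true"/"1" (any case) as True, everything else as False.
    return os.getenv(name, default).lower() in ("true", "1")

# Usage mirroring the hunk above: defaults to True unless the user exports
# SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE=false before launching the server.
CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var_sketch(
    "SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true"
)
```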
@@ -322,7 +332,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -338,7 +348,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
 
         w13_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts,
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
@@ -348,7 +361,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
 
         w2_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts,
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
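These two hunks size the expert weights by the per-tensor-parallel partition rather than the full intermediate dimension. A sketch of the resulting shapes, assuming the intermediate dimension divides evenly across `moe_tp_size` ranks (the concrete sizes here are illustrative):

```python
# Hedged sketch of the expert weight shapes after this change, assuming the
# intermediate dimension is split evenly across moe_tp_size TP ranks.
import torch

num_experts, hidden_size, intermediate_size, moe_tp_size = 8, 4096, 14336, 4
intermediate_size_per_partition = intermediate_size // moe_tp_size  # 3584

# w13 packs the w1 and w3 projections along dim 1 (w1 rows first, then w3).
w13_weight = torch.empty(
    num_experts, 2 * intermediate_size_per_partition, hidden_size,
    dtype=torch.float8_e4m3fn,
)
# w2 is the down projection back to hidden_size.
w2_weight = torch.empty(
    num_experts, hidden_size, intermediate_size_per_partition,
    dtype=torch.float8_e4m3fn,
)
```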
@@ -414,28 +430,28 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         max_w13_scales = layer.w13_weight_scale.max(dim=1).values
 
         # Requantize each expert's weights using the combined scale
-        # w13_weight has shape (num_experts, 2 *
-        # where the first
-
+        # w13_weight has shape (num_experts, 2 * intermediate_size_per_partition, hidden_size)
+        # where the first intermediate_size_per_partition rows are w1, the next are w3
+        intermediate_size_per_partition = layer.w13_weight.shape[1] // 2
         for expert_id in range(layer.w13_weight.shape[0]):
             start = 0
             for shard_id in range(2):  # w1 and w3
                 # Dequantize using the original scale for this shard
                 dq_weight = per_tensor_dequantize(
                     layer.w13_weight[expert_id][
-                        start : start +
+                        start : start + intermediate_size_per_partition, :
                     ],
                     layer.w13_weight_scale[expert_id][shard_id],
                 )
                 # Requantize using the combined max scale
                 (
                     layer.w13_weight[expert_id][
-                        start : start +
+                        start : start + intermediate_size_per_partition, :
                     ],
                     _,
                 ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
 
-                start +=
+                start += intermediate_size_per_partition
 
         # Update the scale parameter to be per-expert instead of per-shard
         layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
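The loop collapses the two per-shard scales into one per-expert scale: each shard is dequantized with its original scale and requantized with the shared max. A plain-torch sketch of that scale merge (the real code uses `per_tensor_dequantize` and `scaled_fp8_quant` from sgl_kernel):

```python
# Hedged plain-torch sketch of merging two per-shard fp8 scales into one
# per-expert scale; illustrative, not the sgl_kernel fast path.
import torch

def merge_shard_scales(q_w1: torch.Tensor, s_w1: float,
                       q_w3: torch.Tensor, s_w3: float):
    s_max = max(s_w1, s_w3)                          # combined per-expert scale
    dq_w1 = q_w1.float() * s_w1                      # dequantize with old scales
    dq_w3 = q_w3.float() * s_w3
    rq_w1 = (dq_w1 / s_max).to(torch.float8_e4m3fn)  # requantize with s_max
    rq_w3 = (dq_w3 / s_max).to(torch.float8_e4m3fn)
    return rq_w1, rq_w3, s_max
```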
@@ -457,29 +473,31 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w2_input_scale.max(), requires_grad=False
         )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-
-
-
-
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_fp8_w8a8=True,
-            per_channel_quant=False,
-
+            per_channel_quant=False,
+            w13_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
-
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )
 
+        return self.runner.run(dispatch_output, quant_info)
+
 
 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""
@@ -628,16 +646,21 @@ class ModelOptFp4Config(QuantizationConfig):
     def is_layer_excluded(self, prefix: str, exclude_modules: list):
         import regex as re
 
+        fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
+        prefix_split = prefix.split(".")
         for pattern in exclude_modules:
             regex_str = pattern.replace(".", r"\.").replace("*", r".*")
+            pattern_split = pattern.split(".")
             if re.fullmatch(regex_str, prefix):
                 return True
-
-
-
-
-
-
+            elif (
+                pattern_split[-1] in fused_patterns
+                and pattern_split[-1] in prefix_split[-1]
+            ):
+                # Check if the last part of the excluded pattern is contained in the last part of the prefix
+                # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
+                # e.g., model.layers.{i}.self_attn.{fused_weight_name}
+                assert len(prefix_split) == 5 and len(pattern_split) == 5
                 return True
         return False
 
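The new `elif` branch catches exclusion patterns whose leaf module was fused into a combined weight at load time, so the plain regex no longer matches. A worked example of both branches (the pattern and prefix strings here are illustrative):

```python
# Worked example of the two matching branches in is_layer_excluded above.
import re  # the real code uses the third-party "regex" module, same behavior here

exclude = "model.layers.*.self_attn.kv_a_proj_with_mqa"
regex_str = exclude.replace(".", r"\.").replace("*", r".*")

# Branch 1: the prefix still carries the original module name.
assert re.fullmatch(regex_str, "model.layers.0.self_attn.kv_a_proj_with_mqa")

# Branch 2: the module was fused, so the regex no longer matches, but the
# pattern's leaf ("kv_a_proj_with_mqa") is a substring of the fused leaf.
prefix = "model.layers.0.self_attn.fused_qkv_a_proj_with_mqa"
assert re.fullmatch(regex_str, prefix) is None
assert "kv_a_proj_with_mqa" in prefix.split(".")[-1]
```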
@@ -859,6 +882,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         """Access the global enable_flashinfer_cutlass_moe setting."""
         return get_moe_runner_backend().is_flashinfer_cutlass()
 
+    @property
+    def enable_flashinfer_cutedsl_moe(self) -> bool:
+        from sglang.srt.layers.moe import get_moe_runner_backend
+
+        """Access the global enable_flashinfer_cutedsl_moe setting."""
+        return get_moe_runner_backend().is_flashinfer_cutedsl()
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -970,15 +1000,17 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
 
         w13_input_scale = PerTensorScaleParameter(
-            data=torch.empty(layer.
+            data=torch.empty(layer.num_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
+        w13_input_scale._sglang_require_global_experts = True
         layer.register_parameter("w13_input_scale", w13_input_scale)
 
         w2_input_scale = PerTensorScaleParameter(
-            data=torch.empty(layer.
+            data=torch.empty(layer.num_experts, dtype=torch.float32),
             weight_loader=weight_loader,
        )
+        w2_input_scale._sglang_require_global_experts = True
         layer.register_parameter("w2_input_scale", w2_input_scale)
 
     def swizzle_blockscale(self, scale: torch.Tensor):
@@ -1161,6 +1193,33 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe:
             w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
             w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
+        elif self.enable_flashinfer_cutedsl_moe:
+            # All-expert-one-input-scale is mathematically different from default per-expert-input-scale
+            # Thus we allow users to switch the flag to do thorough testing
+            if CUTEDSL_MOE_SCALAR_INPUT_SCALE:
+                w13_input_scale = (
+                    layer.w13_input_scale.max()
+                    .to(torch.float32)
+                    .repeat(layer.w13_input_scale.shape[0])
+                )
+            else:
+                w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
+                    torch.float32
+                )
+
+            w2_input_scale = layer.w2_input_scale
+
+            def _slice_scale(w):
+                assert w.shape == (layer.num_experts,)
+                assert layer.moe_ep_size * layer.num_local_experts == layer.num_experts
+                return w[
+                    layer.moe_ep_rank
+                    * layer.num_local_experts : (layer.moe_ep_rank + 1)
+                    * layer.num_local_experts
+                ]
+
+            w13_input_scale = _slice_scale(w13_input_scale)
+            w2_input_scale = _slice_scale(w2_input_scale)
         else:
             w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
             w2_input_scale = layer.w2_input_scale
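`_slice_scale` trims the globally loaded per-expert input scales down to the experts owned by the current expert-parallel rank. A small numeric example:

```python
# Numeric example of the _slice_scale logic above.
import torch

num_experts, moe_ep_size = 8, 4
num_local_experts = num_experts // moe_ep_size       # 2 experts per EP rank
w = torch.arange(num_experts, dtype=torch.float32)   # one scale per global expert

moe_ep_rank = 2
local = w[moe_ep_rank * num_local_experts : (moe_ep_rank + 1) * num_local_experts]
print(local)  # tensor([4., 5.]) -> rank 2 owns experts 4 and 5
```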
@@ -1243,8 +1302,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 layer.w13_weight_scale,
             )
 
-            logger.info_once("Applied flashinfer weight processing for both w13 and w2")
-
         else:
             # CUTLASS processing - handle w13 and w2 separately
 
@@ -1261,7 +1318,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
 
             # Both flashinfer cutlass and regular cutlass use same processing for w2
-            logger.info_once("Applied weight processing for both w13 and w2")
 
             # Set up CUTLASS MoE parameters
             device = layer.w13_weight.device
@@ -1278,21 +1334,32 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
         return self.enable_flashinfer_cutlass_moe
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: FusedMoE,
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
+        moe_runner_config = self.moe_runner_config
+
         # Check if this is a FlashInferFP4MoE layer that should handle its own forward
         if hasattr(layer, "gemm1_weights_fp4_shuffled"):
             # This layer was processed with flashinfer TRTLLM - delegate to its own forward
-            return layer.forward(x, topk_output)
+            return StandardCombineInput(hidden_states=layer.forward(x, topk_output))
 
         if self.enable_flashinfer_cutlass_moe:
             assert (
@@ -1345,13 +1412,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 tp_rank=layer.moe_tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
             )[0]
-            # Scale by routed_scaling_factor is fused into select_experts.
             if should_use_flashinfer_cutlass_moe_fp4_allgather():
                 output, global_output = get_local_dp_buffer(), output
                 get_tp_group().reduce_scatterv(
                     global_output, output=output, sizes=get_dp_global_num_tokens()
                 )
-            return output
+            return StandardCombineInput(hidden_states=output)
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 
@@ -1372,4 +1438,38 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
         # Scale by routed_scaling_factor is fused into select_experts.
-        return output
+        return StandardCombineInput(hidden_states=output)
+
+    def apply_without_routing_weights(
+        self,
+        layer: FusedMoE,
+        x: torch.Tensor,
+        masked_m: torch.Tensor,
+        moe_runner_config: MoeRunnerConfig,
+    ) -> torch.Tensor:
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
+
+        assert self.enable_flashinfer_cutedsl_moe, "only support flashinfer cutedsl moe"
+        assert (
+            not moe_runner_config.apply_router_weight_on_input
+        ), "apply_router_weight_on_input is not supported for Flashinfer"
+
+        from sglang.srt.layers.moe.flashinfer_cutedsl_moe import (
+            flashinfer_cutedsl_moe_masked,
+        )
+
+        out = flashinfer_cutedsl_moe_masked(
+            hidden_states=x,
+            input_global_scale=layer.w13_input_scale_quant,
+            w1=layer.w13_weight,
+            w1_blockscale=layer.w13_blockscale_swizzled,
+            w1_alpha=layer.g1_alphas,
+            w2=layer.w2_weight,
+            a2_global_scale=layer.w2_input_scale_quant,
+            w2_blockscale=layer.w2_blockscale_swizzled,
+            w2_alpha=layer.g2_alphas,
+            masked_m=masked_m,
+        )
+        return out
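The new `apply_without_routing_weights` path feeds pre-dispatched tensors straight to the CuteDSL masked kernel. A shape sketch, assuming an expert-major layout where `masked_m` counts the valid rows per local expert as in masked grouped-GEMM kernels (this layout is an assumption, not confirmed by the diff):

```python
# Hedged shape sketch for the masked CuteDSL path; the expert-major layout
# and masked_m convention here are assumptions for illustration.
import torch

num_local_experts, max_tokens_per_expert, hidden_size = 4, 128, 4096
x = torch.empty(num_local_experts, max_tokens_per_expert, hidden_size)
masked_m = torch.tensor([7, 0, 128, 33])  # rows beyond masked_m[e] are padding
# out = method.apply_without_routing_weights(layer, x, masked_m, cfg)
# out would keep the same (num_local_experts, max_tokens_per_expert, hidden_size)
# layout, with routing weights applied later by the caller.
```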
sglang/srt/layers/quantization/moe_wna16.py

@@ -9,6 +9,8 @@ import torch
 
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import get_tp_group
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
|
|
22
24
|
logger = logging.getLogger(__name__)
|
23
25
|
|
24
26
|
if TYPE_CHECKING:
|
25
|
-
from sglang.srt.layers.moe.
|
26
|
-
|
27
|
+
from sglang.srt.layers.moe.token_dispatcher import (
|
28
|
+
CombineInput,
|
29
|
+
StandardDispatchOutput,
|
30
|
+
)
|
27
31
|
|
28
32
|
|
29
33
|
def get_weight_perm(num_bits: int):
|
@@ -349,37 +353,36 @@ class MoeWNA16Method(FusedMoEMethodBase):
             layer.register_parameter(key, param)
             set_weight_attrs(param, extra_weight_attrs)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        # avoid circular import
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
         weight_bits = self.quant_config.weight_bits
         has_zp = self.quant_config.has_zp
 
-
-
-            layer.
-            layer.w2_qweight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_qweight,
+            w2_weight=layer.w2_qweight,
             use_int4_w4a16=weight_bits == 4,
             use_int8_w8a16=weight_bits == 8,
-
+            w13_scale=layer.w13_scales,
             w2_scale=layer.w2_scales,
-
+            w13_zp=layer.w13_qzeros if has_zp else None,
             w2_zp=layer.w2_qzeros if has_zp else None,
             block_shape=[0, layer.group_size],
         )
+        return self.runner.run(dispatch_output, quant_info)
 
     @staticmethod
     def get_weight_loader(layer, weight_loader):