sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +67 -43
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +88 -53
- sglang/srt/entrypoints/openai/protocol.py +7 -4
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +39 -19
- sglang/srt/entrypoints/openai/serving_completions.py +15 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -7
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +182 -49
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +68 -41
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +0 -18
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +200 -199
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +191 -139
- sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +260 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +18 -33
- sglang/srt/mem_cache/hiradix_cache.py +108 -48
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +121 -57
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +502 -77
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +75 -19
- sglang/srt/model_executor/model_runner.py +357 -30
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +346 -48
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +11 -2
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +60 -13
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +40 -9
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +355 -37
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +197 -112
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +46 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +12 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py

```diff
@@ -23,8 +23,14 @@ from sglang.srt.layers.moe import (
     get_moe_runner_backend,
     should_use_flashinfer_trtllm_moe,
 )
+from sglang.srt.layers.moe.token_dispatcher.standard import (
+    CombineInput,
+    StandardDispatcher,
+    StandardDispatchOutput,
+)
 from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
 from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
@@ -68,16 +74,6 @@ if should_use_flashinfer_trtllm_moe():
 logger = logging.getLogger(__name__)


-def _is_fp4_quantization_enabled():
-    """Check if ModelOpt FP4 quantization is enabled."""
-    try:
-        # Use the same simple check that works for class selection
-        quantization = global_server_args_dict.get("quantization")
-        return quantization == "modelopt_fp4"
-    except:
-        return False
-
-
 def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
     # Guess tokens per expert assuming perfect expert distribution first.
     num_tokens_per_expert = (num_tokens * top_k) // num_experts
@@ -152,16 +148,6 @@ class FusedMoE(torch.nn.Module):
         self.expert_map_cpu = None
         self.expert_map_gpu = None

-        self.moe_runner_config = MoeRunnerConfig(
-            activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            inplace=inplace,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
-            gemm1_alpha=gemm1_alpha,
-            gemm1_clamp_limit=gemm1_clamp_limit,
-        )
-
         enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass()

         if enable_flashinfer_cutlass_moe and quant_config is None:
@@ -196,13 +182,6 @@ class FusedMoE(torch.nn.Module):
         self.use_presharded_weights = use_presharded_weights

         self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
-        if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
-                self.use_triton_kernels
-            )
-        else:
-            self.quant_method = quant_config.get_quant_method(self, prefix)
-        assert self.quant_method is not None

         self.quant_config = quant_config
         self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
@@ -213,12 +192,40 @@ class FusedMoE(torch.nn.Module):
             and self.use_flashinfer_mxfp4_moe
         ):
             hidden_size = round_up(hidden_size, 256)
+        self.hidden_size = hidden_size
+
+        self.moe_runner_config = MoeRunnerConfig(
+            num_experts=num_experts,
+            num_local_experts=self.num_local_experts,
+            hidden_size=hidden_size,
+            intermediate_size_per_partition=self.intermediate_size_per_partition,
+            layer_id=layer_id,
+            top_k=top_k,
+            num_fused_shared_experts=num_fused_shared_experts,
+            params_dtype=params_dtype,
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            inplace=inplace,
+            no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
+            gemm1_alpha=gemm1_alpha,
+            gemm1_clamp_limit=gemm1_clamp_limit,
+        )
+
+        if quant_config is None:
+            self.quant_method: FusedMoEMethodBase = UnquantizedFusedMoEMethod(
+                self.use_triton_kernels
+            )
+        else:
+            self.quant_method: FusedMoEMethodBase = quant_config.get_quant_method(
+                self, prefix
+            )
+        assert self.quant_method is not None
+
         self.quant_method.create_weights(
             layer=self,
             num_experts=self.num_local_experts,
             hidden_size=hidden_size,
-            # FIXME: figure out which intermediate_size to use
-            intermediate_size=self.intermediate_size_per_partition,
             intermediate_size_per_partition=self.intermediate_size_per_partition,
             params_dtype=params_dtype,
             weight_loader=(
@@ -229,6 +236,9 @@ class FusedMoE(torch.nn.Module):
             with_bias=with_bias,
         )

+        self.quant_method.create_moe_runner(self, self.moe_runner_config)
+        self.dispatcher = StandardDispatcher()
+
     def _load_per_tensor_weight_scale(
         self,
         shard_id: str,
@@ -522,10 +532,12 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
+        # WARN: This makes the `expert_id` mean "local" and "global" in different cases
+        if not getattr(param, "_sglang_require_global_experts", False):
+            expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
+            if expert_id == -1:
+                return

-        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
-        if expert_id == -1:
-            return
         self._weight_loader_impl(
             param=param,
             loaded_weight=loaded_weight,
@@ -594,8 +606,10 @@ class FusedMoE(torch.nn.Module):
             loaded_weight = loaded_weight.to(param.data.device)

         if (
-            "compressed" in self.quant_method.__class__.__name__.lower()
-            or "w4afp8" in self.quant_config.get_name()
+            (
+                "compressed" in self.quant_method.__class__.__name__.lower()
+                or "w4afp8" in self.quant_config.get_name()
+            )
             and (param.data[expert_id] != 1).any()
             and ((param.data[expert_id] - loaded_weight).abs() > 1e-5).any()
         ):
@@ -811,16 +825,17 @@ class FusedMoE(torch.nn.Module):
            elif TopKOutputChecker.format_is_triton_kernel(topk_output):
                raise NotImplementedError()

-
-
+        dispatch_output = self.dispatcher.dispatch(
+            hidden_states=hidden_states, topk_output=topk_output
+        )

-
-
-
-
-
-
-
+        # TODO: consider using symmetric memory
+        combine_input = self.quant_method.apply(
+            layer=self,
+            dispatch_output=dispatch_output,
+        )
+
+        final_hidden_states = self.dispatcher.combine(combine_input)

         final_hidden_states = final_hidden_states[
             ..., :origin_hidden_states_dim
@@ -953,9 +968,9 @@ class FlashInferFusedMoE(FusedMoE):
         # Matrix multiply.
         final_hidden_states = self.quant_method.apply_with_router_logits(
             layer=self,
-
-
-
+            dispatch_output=StandardDispatchOutput(
+                hidden_states=hidden_states, topk_output=topk_output
+            ),
         )

         if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
@@ -1055,16 +1070,3 @@ class FlashInferFP4MoE(FusedMoE):
         )[0]

         return result
-
-
-def get_fused_moe_impl_class():
-    """Factory function to get the appropriate FusedMoE implementation class."""
-    if should_use_flashinfer_trtllm_moe() and _is_fp4_quantization_enabled():
-        # Use FP4 variant when FP4 quantization is enabled
-        return FlashInferFP4MoE
-    elif should_use_flashinfer_trtllm_moe():
-        # Use regular FlashInfer variant for non-FP4 FlashInfer cases
-        return FlashInferFusedMoE
-    else:
-        # Default case
-        return FusedMoE
```
sglang/srt/layers/moe/moe_runner/base.py

```diff
@@ -1,9 +1,41 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Optional
+from typing import TYPE_CHECKING, Callable, Optional, Tuple, TypeGuard
+
+import torch
+
+from sglang.srt.layers.moe.utils import MoeA2ABackend, MoeRunnerBackend
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner.triton import (
+        TritonRunnerCore,
+        TritonRunnerInput,
+        TritonRunnerOutput,
+    )
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        CombineInputFormat,
+        DispatchOutput,
+        DispatchOutputFormat,
+    )


 @dataclass
 class MoeRunnerConfig:
+
+    # MoE parameters
+    num_experts: Optional[int] = None
+    num_local_experts: Optional[int] = None
+    hidden_size: Optional[int] = None
+    intermediate_size_per_partition: Optional[int] = None
+    layer_id: Optional[int] = None
+    top_k: Optional[int] = None
+    num_fused_shared_experts: Optional[int] = None
+    params_dtype: Optional[torch.dtype] = None
+
+    # Runner configuration
     activation: str = "silu"
     apply_router_weight_on_input: bool = False
     inplace: bool = True
@@ -11,3 +43,244 @@ class MoeRunnerConfig:
     routed_scaling_factor: Optional[float] = None
     gemm1_alpha: Optional[float] = None
     gemm1_clamp_limit: Optional[float] = None
+
+
+@dataclass
+class RunnerInput(ABC):
+
+    @property
+    @abstractmethod
+    def runner_backend(self) -> MoeRunnerBackend: ...
+
+    def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerInput]:
+        return self.runner_backend == MoeRunnerBackend.TRITON
+
+
+class RunnerOutput(ABC):
+
+    @property
+    @abstractmethod
+    def runner_backend(self) -> MoeRunnerBackend: ...
+
+    def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerOutput]:
+        return self.runner_backend == MoeRunnerBackend.TRITON
+
+
+@dataclass
+class MoeQuantInfo(ABC):
+    """Moe quantization data."""
+
+    pass
+
+
+class MoeRunnerCore(ABC):
+
+    def __init__(self, config: MoeRunnerConfig):
+        self.config = config
+
+    @abstractmethod
+    def run(
+        self, runner_input: RunnerInput, quant_info: MoeQuantInfo, running_state: dict
+    ) -> RunnerOutput:
+        pass
+
+    @property
+    @abstractmethod
+    def runner_backend(self) -> MoeRunnerBackend: ...
+
+    def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerCore]:
+        return self.runner_backend == MoeRunnerBackend.TRITON
+
+
+class FusedOpPool:
+
+    _fused_funcs: dict[str, Callable] = {}
+
+    @classmethod
+    def register_fused_func(
+        cls, a2a_backend_name: str, runner_backend_name: str, fused_func: Callable
+    ):
+        key = (a2a_backend_name, runner_backend_name)
+        if key in cls._fused_funcs:
+            raise ValueError(
+                f"Fused function for {a2a_backend_name} to {runner_backend_name} is already registered."
+            )
+        assert MoeA2ABackend(
+            a2a_backend_name
+        ), f"Invalid dispatch name: {a2a_backend_name}"
+        assert MoeRunnerBackend(
+            runner_backend_name
+        ), f"Invalid runner name: {runner_backend_name}"
+        cls._fused_funcs[key] = fused_func
+
+    @classmethod
+    def get_fused_func(cls, dispatch_name: str, runner_name: str) -> Optional[Callable]:
+        key = (dispatch_name, runner_name)
+        fused_func = cls._fused_funcs.get(key)
+        return fused_func
+
+
+class PermuteMethodPool:
+
+    _pre_permute_methods: dict[
+        Tuple[DispatchOutputFormat, MoeRunnerBackend], Callable
+    ] = {}
+    _post_permute_methods: dict[
+        Tuple[MoeRunnerBackend, CombineInputFormat], Callable
+    ] = {}
+
+    @classmethod
+    def register_pre_permute(
+        cls,
+        dispatch_output_name: str,
+        runner_backend_name: str,
+        permute_func: Callable,
+    ):
+        """
+        Register a customized pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+        :param dispatch_output_name: The DispatchOutputFormat name.
+        :param runner_backend_name: The MoeRunnerBackend name.
+        :param permute_func: The permute function to register.
+        """
+        # TODO: check if registration is valid
+        key = (dispatch_output_name, runner_backend_name)
+        if key in cls._pre_permute_methods:
+            raise ValueError(
+                f"Pre-permute method for {dispatch_output_name} to {runner_backend_name} is already registered."
+            )
+        cls._pre_permute_methods[key] = permute_func
+
+    @classmethod
+    def register_post_permute(
+        cls,
+        runner_backend_name: str,
+        combine_input_name: str,
+        permute_func: Callable,
+    ):
+        """
+        Register a customized post-permute function for the given MoeRunnerBackend and CombineInputFormat.
+
+        :param runner_backend_name: The MoeRunnerBackend name.
+        :param combine_input_name: The CombineInputFormat name.
+        :param permute_func: The permute function to register.
+        """
+        # TODO: check if registration is valid
+        key = (runner_backend_name, combine_input_name)
+        if key in cls._post_permute_methods:
+            raise ValueError(
+                f"Post-permute method for {runner_backend_name} to {combine_input_name} is already registered."
+            )
+        cls._post_permute_methods[key] = permute_func
+
+    @classmethod
+    def get_pre_permute(
+        cls,
+        dispatch_output_format: DispatchOutputFormat,
+        runner_input_format: MoeRunnerBackend,
+    ) -> Callable:
+        """
+        Retrieve the pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+        :param dispatch_output_format: The DispatchOutputFormat type.
+        :param runner_input_format: The MoeRunnerBackend type.
+        :return: The registered permute function or None if not found.
+        """
+        key = (dispatch_output_format, runner_input_format)
+        pre_permute_func = cls._pre_permute_methods.get(key)
+        assert (
+            pre_permute_func is not None
+        ), f"Pre-permute function for {dispatch_output_format} to {runner_input_format} is not registered"
+        return pre_permute_func
+
+    @classmethod
+    def get_post_permute(
+        cls,
+        runner_output_format: MoeRunnerBackend,
+        combine_input_format: CombineInputFormat,
+    ) -> Callable:
+        """
+        Retrieve the post-permute function for the given MoeRunnerBackend and CombineInputFormat.
+
+        :param runner_output_format: The MoeRunnerBackend type.
+        :param combine_input_format: The CombineInputFormat type.
+        :return: The registered permute function or None if not found.
+        """
+        key = (runner_output_format, combine_input_format)
+        post_permute_func = cls._post_permute_methods.get(key)
+        assert (
+            post_permute_func is not None
+        ), f"Post-permute function for {runner_output_format} to {combine_input_format} is not registered"
+        return post_permute_func
+
+
+def register_fused_func(
+    a2a_backend_name: str,
+    runner_backend_name: str,
+) -> Callable:
+    """
+    Decorator to register a fused function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+    :param a2a_backend_name: The A2A backend name.
+    :param runner_backend_name: The MoeRunnerBackend name.
+    :return: The decorator function.
+    """
+
+    def decorator(fused_func: Callable):
+        FusedOpPool.register_fused_func(
+            a2a_backend_name, runner_backend_name, fused_func
+        )
+        return fused_func
+
+    return decorator
+
+
+def register_pre_permute(
+    dispatch_output_name: str,
+    runner_backend_name: str,
+) -> Callable:
+    """
+    Decorator to register a pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+    :param dispatch_output_name: The DispatchOutputFormat name.
+    :param runner_backend_name: The MoeRunnerBackend name.
+    :return: The decorator function.
+    """
+
+    def decorator(
+        permute_func: Callable[
+            [DispatchOutput, MoeQuantInfo, MoeRunnerConfig, dict], RunnerInput
+        ]
+    ) -> Callable:
+
+        PermuteMethodPool.register_pre_permute(
+            dispatch_output_name, runner_backend_name, permute_func
+        )
+        return permute_func
+
+    return decorator
+
+
+def register_post_permute(
+    runner_backend_name: str,
+    combine_input_name: str,
+) -> Callable:
+    """
+    Decorator to register a post-permute function for the given MoeRunnerBackend and CombineInputFormat.
+
+    :param runner_backend_name: The MoeRunnerBackend name.
+    :param combine_input_name: The CombineInputFormat name.
+    :return: The decorator function.
+    """
+
+    def decorator(
+        permute_func: Callable[
+            [RunnerOutput, MoeQuantInfo, MoeRunnerConfig, dict], CombineInput
+        ]
+    ) -> Callable:
+        PermuteMethodPool.register_post_permute(
+            runner_backend_name, combine_input_name, permute_func
+        )
+        return permute_func
+
+    return decorator
```
sglang/srt/layers/moe/moe_runner/runner.py

```diff
@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+import logging
+import os
+from typing import TYPE_CHECKING
+
+from sglang.srt.layers.moe.moe_runner.base import (
+    FusedOpPool,
+    MoeRunnerConfig,
+    PermuteMethodPool,
+)
+from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
+from sglang.srt.layers.moe.utils import get_moe_a2a_backend
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner.base import MoeQuantInfo
+    from sglang.srt.layers.moe.token_dispatcher.base import CombineInput, DispatchOutput
+    from sglang.srt.layers.moe.utils import MoeRunnerBackend
+
+logger = logging.getLogger(__name__)
+
+
+class MoeRunner:
+
+    def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig):
+        self.runner_backend = runner_backend
+        self.config = config
+
+        self.fused_func = None
+
+        if runner_backend.is_triton():
+            self.runner_core = TritonRunnerCore(config)
+        else:
+            raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")
+
+        a2a_backend_name = get_moe_a2a_backend().value
+        runner_backend_name = runner_backend.value
+
+        self.fused_func = FusedOpPool.get_fused_func(
+            a2a_backend_name, runner_backend_name
+        )
+
+        SGLANG_CI_DISABLE_MOE_FUSED_FUNC = os.environ.get(
+            "SGLANG_CI_DISABLE_MOE_FUSED_FUNC", "0"
+        )
+        if SGLANG_CI_DISABLE_MOE_FUSED_FUNC == "1":
+            logger.info(
+                "SGLANG_CI_DISABLE_MOE_FUSED_FUNC is set to 1, disabling fused func"
+            )
+            self.fused_func = None
+
+    def run(
+        self, dispatch_output: DispatchOutput, quant_info: MoeQuantInfo
+    ) -> CombineInput:
+
+        if self.fused_func is not None:
+            return self.fused_func(dispatch_output, quant_info, self.config)
+
+        dispatch_format = dispatch_output.format.value
+        runner_format = self.runner_core.runner_backend.value
+        self.pre_permute_func = PermuteMethodPool.get_pre_permute(
+            dispatch_format, runner_format
+        )
+
+        running_state = {}
+        runner_input = self.pre_permute_func(
+            dispatch_output, quant_info, self.config, running_state
+        )
+        runner_output = self.runner_core.run(runner_input, quant_info, running_state)
+
+        runner_format = self.runner_core.runner_backend.value
+        combine_format = dispatch_output.format.value
+        self.post_permute_func = PermuteMethodPool.get_post_permute(
+            runner_format, combine_format
+        )
+        combine_input = self.post_permute_func(
+            runner_output, quant_info, self.config, running_state
+        )
+
+        return combine_input
```