sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +192 -113
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +132 -57
- sglang/srt/entrypoints/openai/protocol.py +115 -7
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +207 -58
- sglang/srt/entrypoints/openai/serving_completions.py +17 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +49 -4
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +106 -82
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +53 -7
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +225 -57
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +78 -49
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +215 -314
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +147 -19
- sglang/srt/managers/scheduler.py +501 -304
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +321 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +15 -21
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +58 -34
- sglang/srt/mem_cache/hiradix_cache.py +227 -80
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -223
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +519 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +55 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +98 -57
- sglang/srt/model_executor/model_runner.py +433 -158
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +833 -152
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +14 -5
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +124 -14
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +26 -5
- sglang/srt/models/qwen3_moe.py +71 -12
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +10 -3
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1030 -254
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +253 -136
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +445 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +22 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0

sglang/srt/layers/quantization/modelopt_quant.py

```diff
@@ -10,10 +10,14 @@ from torch.nn.parameter import Parameter
 from sglang.srt.distributed import get_tp_group
 from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
 from sglang.srt.layers.moe import (
+    MoeRunner,
+    MoeRunnerBackend,
+    MoeRunnerConfig,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -35,12 +39,15 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_cuda, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-    from sglang.srt.layers.moe.
-
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
+    from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
 
 if is_cuda():
     from sgl_kernel import scaled_fp4_quant
@@ -68,6 +75,17 @@ except ImportError:
 # Initialize logger for the module
 logger = logging.getLogger(__name__)
 
+CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
+    "SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true"
+)
+USE_CUTLASS_BACKEND_FOR_FP4_GEMM = get_bool_env_var(
+    "SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM"
+)
+# TODO make it true by default when the DeepEP PR is merged
+CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
+    "SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH", "false"
+)
+
 # Supported activation schemes for the current configuration
 ACTIVATION_SCHEMES = ["static"]
 
```
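The three flags above are resolved once at import time from environment variables. The sketch below shows how such a boolean flag can be read; `read_bool_env_var` is a hypothetical stand-in for `sglang.srt.utils.get_bool_env_var`, whose exact truthy-string handling is assumed here rather than taken from the source.

```python
# Hypothetical stand-in for sglang.srt.utils.get_bool_env_var (assumed semantics).
import os


def read_bool_env_var(name: str, default: str = "false") -> bool:
    # Treat common truthy spellings as True; anything else as False.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes", "on")


# Mirrors the new module-level flag added in the hunk above.
CUTEDSL_MOE_NVFP4_DISPATCH = read_bool_env_var("SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH", "false")
```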
```diff
@@ -322,7 +340,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -338,7 +356,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
 
         w13_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts,
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
@@ -348,7 +369,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
 
         w2_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts,
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
@@ -414,28 +438,28 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         max_w13_scales = layer.w13_weight_scale.max(dim=1).values
 
         # Requantize each expert's weights using the combined scale
-        # w13_weight has shape (num_experts, 2 *
-        # where the first
-
+        # w13_weight has shape (num_experts, 2 * intermediate_size_per_partition, hidden_size)
+        # where the first intermediate_size_per_partition rows are w1, the next are w3
+        intermediate_size_per_partition = layer.w13_weight.shape[1] // 2
         for expert_id in range(layer.w13_weight.shape[0]):
             start = 0
             for shard_id in range(2):  # w1 and w3
                 # Dequantize using the original scale for this shard
                 dq_weight = per_tensor_dequantize(
                     layer.w13_weight[expert_id][
-                        start : start +
+                        start : start + intermediate_size_per_partition, :
                     ],
                     layer.w13_weight_scale[expert_id][shard_id],
                 )
                 # Requantize using the combined max scale
                 (
                     layer.w13_weight[expert_id][
-                        start : start +
+                        start : start + intermediate_size_per_partition, :
                     ],
                     _,
                 ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
 
-                start +=
+                start += intermediate_size_per_partition
 
         # Update the scale parameter to be per-expert instead of per-shard
         layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
```
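The requantization loop above folds the two per-shard scales of each expert's fused w1/w3 weight into a single per-expert scale: each shard is dequantized with its original scale and requantized with the shared maximum. A hedged, self-contained sketch of that round trip follows; FP8 storage is approximated by a clamp, so this illustrates the scale arithmetic only, not the kernel-level quantization.

```python
# Hedged sketch of requantizing two shards onto one combined (max) scale.
# FP8 e4m3 storage is approximated by clamping to +-448 for illustration.
import torch

FP8_MAX = 448.0


def dequant(w_q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return w_q.float() * scale  # per-tensor dequantization


def requant(w: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return (w / scale).clamp(-FP8_MAX, FP8_MAX)  # per-tensor requantization


w1_q, s1 = torch.randn(4, 8), torch.tensor(0.01)  # shard quantized with scale s1
w3_q, s3 = torch.randn(4, 8), torch.tensor(0.02)  # shard quantized with scale s3
s_max = torch.max(s1, s3)                         # combined per-expert scale

w1_q_new = requant(dequant(w1_q, s1), s_max)
w3_q_new = requant(dequant(w3_q, s3), s_max)

# One shared scale now reproduces both shards (up to quantization error).
torch.testing.assert_close(dequant(w1_q_new, s_max), dequant(w1_q, s1))
```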
```diff
@@ -457,29 +481,31 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
                 layer.w2_input_scale.max(), requires_grad=False
             )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-
-
-
-
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_fp8_w8a8=True,
-            per_channel_quant=False,
-
+            per_channel_quant=False,
+            w13_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
-
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )
 
+        return self.runner.run(dispatch_output, quant_info)
+
 
 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""
```
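With this hunk, `apply` no longer receives hidden states, the top-k routing result, and a `MoeRunnerConfig` directly; it consumes a `StandardDispatchOutput` and returns a `CombineInput`, with the Triton path delegated to `MoeRunner.run`. The sketch below illustrates the assumed call order from a caller's perspective; the driver function itself is hypothetical, and only the method names and argument types come from the diff.

```python
# Hypothetical driver illustrating the assumed call order of the refactored
# interface: create_moe_runner once at setup, apply per forward pass.
import torch


def run_fp8_moe(method, layer: torch.nn.Module, moe_runner_config, dispatch_output):
    if not hasattr(method, "runner"):
        # One-time setup: builds a Triton-backed MoeRunner from the config.
        method.create_moe_runner(layer, moe_runner_config)
    # Hidden states and top-k routing travel inside dispatch_output
    # (a StandardDispatchOutput); the result comes back as a CombineInput.
    return method.apply(layer, dispatch_output)
```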
```diff
@@ -517,6 +543,39 @@ class ModelOptFp4Config(QuantizationConfig):
     def get_config_filenames(cls) -> List[str]:
         return ["hf_quant_config.json"]
 
+    @staticmethod
+    def common_group_size(cfg: dict) -> int:
+        """Return the unique group_size across the config; raise if missing/mismatched."""
+        sizes = set()
+
+        # Top-level and 'quantization' block
+        v = cfg.get("group_size")
+        if isinstance(v, int):
+            sizes.add(v)
+        q = cfg.get("quantization")
+        if isinstance(q, dict):
+            v = q.get("group_size")
+            if isinstance(v, int):
+                sizes.add(v)
+
+        # config_groups: accept group-level or nested dicts (e.g., weights/input_activations)
+        for g in (cfg.get("config_groups") or {}).values():
+            if isinstance(g, dict):
+                v = g.get("group_size")
+                if isinstance(v, int):
+                    sizes.add(v)
+                for sub in g.values():
+                    if isinstance(sub, dict):
+                        v = sub.get("group_size")
+                        if isinstance(v, int):
+                            sizes.add(v)
+
+        if not sizes:
+            raise ValueError("No group_size found in config.")
+        if len(sizes) > 1:
+            raise ValueError(f"Inconsistent group_size values: {sorted(sizes)}")
+        return next(iter(sizes))
+
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
         # Handle two different config formats:
```
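A usage example for the `common_group_size` helper added above, with a toy config in the flat `config_groups` style it accepts (the field layout of a real checkpoint config may differ):

```python
cfg = {
    "config_groups": {
        "group_0": {
            "weights": {"group_size": 16},
            "input_activations": {"group_size": 16},
        }
    }
}
assert ModelOptFp4Config.common_group_size(cfg) == 16

# Conflicting values raise, so a silently wrong group size cannot slip through.
try:
    ModelOptFp4Config.common_group_size(
        {"group_size": 16, "quantization": {"group_size": 32}}
    )
except ValueError as err:
    print(err)  # Inconsistent group_size values: [16, 32]
```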
```diff
@@ -549,7 +608,7 @@ class ModelOptFp4Config(QuantizationConfig):
             else:
                 kv_cache_quant_algo = "auto"
 
-            group_size =
+            group_size = ModelOptFp4Config.common_group_size(config)
             exclude_modules = config.get("ignore", [])
         else:
             # Fall back to nested format (hf_quant_config.json - legacy format)
@@ -559,7 +618,7 @@ class ModelOptFp4Config(QuantizationConfig):
                 kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo")
                 if not kv_cache_quant_algo:
                     kv_cache_quant_algo = "auto"
-                group_size =
+                group_size = ModelOptFp4Config.common_group_size(config)
                 exclude_modules = quant_config.get("exclude_modules", [])
             except (ValueError, KeyError):
                 raise ValueError(
@@ -595,16 +654,21 @@ class ModelOptFp4Config(QuantizationConfig):
     def is_layer_excluded(self, prefix: str, exclude_modules: list):
         import regex as re
 
+        fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
+        prefix_split = prefix.split(".")
         for pattern in exclude_modules:
             regex_str = pattern.replace(".", r"\.").replace("*", r".*")
+            pattern_split = pattern.split(".")
             if re.fullmatch(regex_str, prefix):
                 return True
-
-
-
-
-
-
+            elif (
+                pattern_split[-1] in fused_patterns
+                and pattern_split[-1] in prefix_split[-1]
+            ):
+                # Check if the last part of the excluded pattern is contained in the last part of the prefix
+                # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
+                # e.g., model.layers.{i}.self_attn.{fused_weight_name}
+                assert len(prefix_split) == 5 and len(pattern_split) == 5
                 return True
         return False
 
```
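The exclusion check above turns each exclude pattern into a regular expression (`.` escaped, `*` becomes `.*`) and requires a full match against the module prefix, with a new fallback for fused projection modules. The snippet below demonstrates just the glob-to-regex matching; it uses the standard-library `re` for brevity, whereas the source imports the third-party `regex` package under the same name.

```python
import re


def matches(pattern: str, prefix: str) -> bool:
    # Same translation as is_layer_excluded: escape '.', turn '*' into '.*'.
    regex_str = pattern.replace(".", r"\.").replace("*", r".*")
    return re.fullmatch(regex_str, prefix) is not None


assert matches("model.layers.*.self_attn.q_a_proj", "model.layers.3.self_attn.q_a_proj")
assert not matches("model.layers.*.self_attn.q_a_proj", "model.layers.3.mlp.gate_proj")
```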
```diff
@@ -788,14 +852,25 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         if enable_flashinfer_fp4_gemm:
             w = layer.weight.T
             w_scale_interleaved = layer.weight_scale_interleaved.T
-
-
-
-
-
-
-
-
+            if USE_CUTLASS_BACKEND_FOR_FP4_GEMM:
+                out = fp4_gemm(
+                    x_fp4,
+                    w,
+                    x_scale_interleaved,
+                    w_scale_interleaved,
+                    layer.alpha,
+                    output_dtype,
+                    backend="cutlass",
+                )
+            else:
+                out = fp4_gemm(
+                    x_fp4,
+                    w,
+                    x_scale_interleaved,
+                    w_scale_interleaved,
+                    layer.alpha,
+                    output_dtype,
+                )
             if bias is not None:
                 out = out + bias
             return out.view(*output_shape)
@@ -826,6 +901,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         """Access the global enable_flashinfer_cutlass_moe setting."""
         return get_moe_runner_backend().is_flashinfer_cutlass()
 
+    @property
+    def enable_flashinfer_cutedsl_moe(self) -> bool:
+        from sglang.srt.layers.moe import get_moe_runner_backend
+
+        """Access the global enable_flashinfer_cutedsl_moe setting."""
+        return get_moe_runner_backend().is_flashinfer_cutedsl()
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -937,15 +1019,17 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
 
         w13_input_scale = PerTensorScaleParameter(
-            data=torch.empty(layer.
+            data=torch.empty(layer.num_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
+        w13_input_scale._sglang_require_global_experts = True
         layer.register_parameter("w13_input_scale", w13_input_scale)
 
         w2_input_scale = PerTensorScaleParameter(
-            data=torch.empty(layer.
+            data=torch.empty(layer.num_experts, dtype=torch.float32),
             weight_loader=weight_loader,
         )
+        w2_input_scale._sglang_require_global_experts = True
         layer.register_parameter("w2_input_scale", w2_input_scale)
 
     def swizzle_blockscale(self, scale: torch.Tensor):
@@ -1128,6 +1212,37 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe:
             w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
             w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
+        elif self.enable_flashinfer_cutedsl_moe:
+            # All-expert-one-input-scale is mathematically different from default per-expert-input-scale
+            # Thus we allow users to switch the flag to do thorough testing
+            if CUTEDSL_MOE_SCALAR_INPUT_SCALE:
+                w13_input_scale = (
+                    layer.w13_input_scale.max()
+                    .to(torch.float32)
+                    .repeat(layer.w13_input_scale.shape[0])
+                )
+            else:
+                w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
+                    torch.float32
+                )
+
+            w2_input_scale = layer.w2_input_scale
+
+            def _slice_scale(w):
+                assert w.shape == (layer.num_experts,)
+                assert layer.moe_ep_size * layer.num_local_experts == layer.num_experts
+                return w[
+                    layer.moe_ep_rank
+                    * layer.num_local_experts : (layer.moe_ep_rank + 1)
+                    * layer.num_local_experts
+                ]
+
+            w13_input_scale = _slice_scale(w13_input_scale)
+            w2_input_scale = _slice_scale(w2_input_scale)
+
+            if CUTEDSL_MOE_NVFP4_DISPATCH:
+                assert torch.all(w13_input_scale == w13_input_scale[0])
+                w13_input_scale = w13_input_scale[0]
         else:
             w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
             w2_input_scale = layer.w2_input_scale
```
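The `_slice_scale` helper above keeps only the contiguous block of per-expert input scales owned by the current expert-parallel rank. A hedged sketch of that slicing arithmetic, with made-up sizes:

```python
import torch

num_experts, ep_size = 8, 4
num_local_experts = num_experts // ep_size               # 2 experts per EP rank
scales = torch.arange(num_experts, dtype=torch.float32)  # one scale per global expert

for ep_rank in range(ep_size):
    local = scales[ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts]
    print(ep_rank, local.tolist())
# 0 [0.0, 1.0]
# 1 [2.0, 3.0]
# 2 [4.0, 5.0]
# 3 [6.0, 7.0]
```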
```diff
@@ -1210,8 +1325,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 layer.w13_weight_scale,
             )
 
-            logger.info_once("Applied flashinfer weight processing for both w13 and w2")
-
         else:
             # CUTLASS processing - handle w13 and w2 separately
 
@@ -1228,7 +1341,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
 
             # Both flashinfer cutlass and regular cutlass use same processing for w2
-            logger.info_once("Applied weight processing for both w13 and w2")
 
         # Set up CUTLASS MoE parameters
         device = layer.w13_weight.device
@@ -1245,21 +1357,32 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
         return self.enable_flashinfer_cutlass_moe
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: FusedMoE,
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
+        moe_runner_config = self.moe_runner_config
+
         # Check if this is a FlashInferFP4MoE layer that should handle its own forward
         if hasattr(layer, "gemm1_weights_fp4_shuffled"):
             # This layer was processed with flashinfer TRTLLM - delegate to its own forward
-            return layer.forward(x, topk_output)
+            return StandardCombineInput(hidden_states=layer.forward(x, topk_output))
 
         if self.enable_flashinfer_cutlass_moe:
             assert (
@@ -1312,13 +1435,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 tp_rank=layer.moe_tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
             )[0]
-            # Scale by routed_scaling_factor is fused into select_experts.
             if should_use_flashinfer_cutlass_moe_fp4_allgather():
                 output, global_output = get_local_dp_buffer(), output
                 get_tp_group().reduce_scatterv(
                     global_output, output=output, sizes=get_dp_global_num_tokens()
                 )
-            return output
+            return StandardCombineInput(hidden_states=output)
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 
@@ -1339,4 +1461,50 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
         # Scale by routed_scaling_factor is fused into select_experts.
-        return output
+        return StandardCombineInput(hidden_states=output)
+
+    def apply_without_routing_weights(
+        self,
+        layer: FusedMoE,
+        x: torch.Tensor,
+        masked_m: torch.Tensor,
+        moe_runner_config: MoeRunnerConfig,
+        down_gemm_overlap_args: Optional["DownGemmOverlapArgs"],
+    ) -> torch.Tensor:
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
+
+        assert self.enable_flashinfer_cutedsl_moe, "only support flashinfer cutedsl moe"
+        assert (
+            not moe_runner_config.apply_router_weight_on_input
+        ), "apply_router_weight_on_input is not supported for Flashinfer"
+
+        from sglang.srt.layers.moe.flashinfer_cutedsl_moe import (
+            flashinfer_cutedsl_moe_masked,
+        )
+
+        out = flashinfer_cutedsl_moe_masked(
+            hidden_states=x,
+            input_global_scale=(
+                None if CUTEDSL_MOE_NVFP4_DISPATCH else layer.w13_input_scale_quant
+            ),
+            w1=layer.w13_weight,
+            w1_blockscale=layer.w13_blockscale_swizzled,
+            w1_alpha=layer.g1_alphas,
+            w2=layer.w2_weight,
+            a2_global_scale=layer.w2_input_scale_quant,
+            w2_blockscale=layer.w2_blockscale_swizzled,
+            w2_alpha=layer.g2_alphas,
+            masked_m=masked_m,
+            **(
+                dict(
+                    down_sm_count=down_gemm_overlap_args.num_sms,
+                    down_signals=down_gemm_overlap_args.signal,
+                    down_start_event=down_gemm_overlap_args.start_event,
+                )
+                if down_gemm_overlap_args is not None
+                else {}
+            ),
+        )
+        return out
```

sglang/srt/layers/quantization/moe_wna16.py

```diff
@@ -9,6 +9,8 @@ import torch
 
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import get_tp_group
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -22,8 +24,10 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.
-
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 
 def get_weight_perm(num_bits: int):
@@ -349,37 +353,36 @@ class MoeWNA16Method(FusedMoEMethodBase):
             layer.register_parameter(key, param)
             set_weight_attrs(param, extra_weight_attrs)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        # avoid circular import
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
         weight_bits = self.quant_config.weight_bits
         has_zp = self.quant_config.has_zp
 
-
-
-            layer.
-            layer.w2_qweight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_qweight,
+            w2_weight=layer.w2_qweight,
             use_int4_w4a16=weight_bits == 4,
             use_int8_w8a16=weight_bits == 8,
-
+            w13_scale=layer.w13_scales,
             w2_scale=layer.w2_scales,
-
+            w13_zp=layer.w13_qzeros if has_zp else None,
             w2_zp=layer.w2_qzeros if has_zp else None,
             block_shape=[0, layer.group_size],
         )
+        return self.runner.run(dispatch_output, quant_info)
 
     @staticmethod
     def get_weight_loader(layer, weight_loader):
```