sglang-0.5.2rc2-py3-none-any.whl → sglang-0.5.3rc2-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in those registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/modelopt_quant.py

```diff
@@ -10,10 +10,14 @@ from torch.nn.parameter import Parameter
 from sglang.srt.distributed import get_tp_group
 from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
 from sglang.srt.layers.moe import (
+    MoeRunner,
+    MoeRunnerBackend,
+    MoeRunnerConfig,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -35,12 +39,15 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_cuda, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-    from sglang.srt.layers.moe.
-
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
+    from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
 
 if is_cuda():
     from sgl_kernel import scaled_fp4_quant
@@ -68,6 +75,17 @@ except ImportError:
 # Initialize logger for the module
 logger = logging.getLogger(__name__)
 
+CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
+    "SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true"
+)
+USE_CUTLASS_BACKEND_FOR_FP4_GEMM = get_bool_env_var(
+    "SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM"
+)
+# TODO make it true by default when the DeepEP PR is merged
+CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
+    "SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH", "false"
+)
+
 # Supported activation schemes for the current configuration
 ACTIVATION_SCHEMES = ["static"]
 
```
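The three switches above are read once at import time from boolean environment variables. As a minimal sketch of how such a flag is consumed (the helper body below is an assumed simplification, not a copy of `sglang.srt.utils.get_bool_env_var`):

```python
import os

def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Assumed semantics: "true"/"1" (case-insensitive) turn the flag on.
    return os.getenv(name, default).strip().lower() in ("true", "1")

# The flag names come from the hunk above, e.g. opting into the CUTLASS FP4 GEMM
# path before the server starts:
#   SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM=1 python -m sglang.launch_server ...
USE_CUTLASS_BACKEND_FOR_FP4_GEMM = get_bool_env_var("SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM")
```

Per the defaults shown in the diff, `SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE` is on unless explicitly disabled, while `SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH` stays off until the DeepEP change mentioned in the TODO lands.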
```diff
@@ -322,7 +340,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -338,7 +356,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
 
         w13_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts,
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
@@ -348,7 +369,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
 
         w2_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts,
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
@@ -414,28 +438,28 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         max_w13_scales = layer.w13_weight_scale.max(dim=1).values
 
         # Requantize each expert's weights using the combined scale
-        # w13_weight has shape (num_experts, 2 *
-        # where the first
-
+        # w13_weight has shape (num_experts, 2 * intermediate_size_per_partition, hidden_size)
+        # where the first intermediate_size_per_partition rows are w1, the next are w3
+        intermediate_size_per_partition = layer.w13_weight.shape[1] // 2
         for expert_id in range(layer.w13_weight.shape[0]):
             start = 0
             for shard_id in range(2):  # w1 and w3
                 # Dequantize using the original scale for this shard
                 dq_weight = per_tensor_dequantize(
                     layer.w13_weight[expert_id][
-                        start : start +
+                        start : start + intermediate_size_per_partition, :
                     ],
                     layer.w13_weight_scale[expert_id][shard_id],
                 )
                 # Requantize using the combined max scale
                 (
                     layer.w13_weight[expert_id][
-                        start : start +
+                        start : start + intermediate_size_per_partition, :
                     ],
                     _,
                 ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
 
-                start +=
+                start += intermediate_size_per_partition
 
         # Update the scale parameter to be per-expert instead of per-shard
         layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
```
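The requantization loop above collapses the two per-shard FP8 scales of `w13_weight` (one for w1, one for w3) into a single per-expert scale: each shard is dequantized with its original scale and re-quantized with the shared maximum. A self-contained sketch of that idea, with simplified `quant`/`dequant` helpers standing in for `scaled_fp8_quant`/`per_tensor_dequantize` and made-up shapes:

```python
import torch

FP8_MAX = 448.0  # largest magnitude representable in float8_e4m3fn

def quant(w: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Per-tensor quantization: store w / scale as FP8, keep the scale on the side.
    return (w / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

def dequant(w_fp8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return w_fp8.to(torch.float32) * scale

# Two shards (w1, w3) originally quantized with their own scales.
w1, w3 = torch.randn(16, 32), torch.randn(16, 32) * 3.0
s1, s3 = w1.abs().max() / FP8_MAX, w3.abs().max() / FP8_MAX
q1, q3 = quant(w1, s1), quant(w3, s3)

# Fold both shards onto one scale, as the hunk does per expert:
# dequantize with the old scale, requantize with max(s1, s3).
s_max = torch.max(s1, s3)
q1_new, q3_new = quant(dequant(q1, s1), s_max), quant(dequant(q3, s3), s_max)

# The shard quantized with the smaller original scale loses some precision,
# which is the price of a single per-expert scale.
print((dequant(q1_new, s_max) - w1).abs().max())
```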
```diff
@@ -457,29 +481,31 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w2_input_scale.max(), requires_grad=False
         )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-
-
-
-
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_fp8_w8a8=True,
-            per_channel_quant=False,
-
+            per_channel_quant=False,
+            w13_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
-
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )
 
+        return self.runner.run(dispatch_output, quant_info)
+
 
 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""
```
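This is the pattern the 0.5.3 MoE refactor applies across the quantization methods in this file (and in moe_wna16.py further down): `apply()` no longer calls `fused_experts` itself; the method builds a `MoeRunner` once in `create_moe_runner()` and hands it a `TritonMoeQuantInfo` describing the weights. A condensed sketch of that shape (the class name `SketchFp8MoEMethod` is illustrative; the imports and field names mirror the hunk above):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo

if TYPE_CHECKING:
    from sglang.srt.layers.moe.token_dispatcher import CombineInput, StandardDispatchOutput


class SketchFp8MoEMethod:  # illustrative stand-in for a FusedMoEMethodBase subclass
    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        # Bind a runner to the Triton backend once, at layer setup time.
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

    def apply(
        self, layer: torch.nn.Module, dispatch_output: StandardDispatchOutput
    ) -> CombineInput:
        # Describe the quantized weights; the runner decides how to execute them.
        quant_info = TritonMoeQuantInfo(
            w13_weight=layer.w13_weight,
            w2_weight=layer.w2_weight,
            use_fp8_w8a8=True,
            w13_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
        )
        return self.runner.run(dispatch_output, quant_info)
```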
```diff
@@ -628,16 +654,21 @@ class ModelOptFp4Config(QuantizationConfig):
     def is_layer_excluded(self, prefix: str, exclude_modules: list):
         import regex as re
 
+        fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
+        prefix_split = prefix.split(".")
         for pattern in exclude_modules:
             regex_str = pattern.replace(".", r"\.").replace("*", r".*")
+            pattern_split = pattern.split(".")
             if re.fullmatch(regex_str, prefix):
                 return True
-
-
-
-
-
-
+            elif (
+                pattern_split[-1] in fused_patterns
+                and pattern_split[-1] in prefix_split[-1]
+            ):
+                # Check if the last part of the excluded pattern is contained in the last part of the prefix
+                # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
+                # e.g., model.layers.{i}.self_attn.{fused_weight_name}
+                assert len(prefix_split) == 5 and len(pattern_split) == 5
                 return True
         return False
 
```
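A standalone approximation of the exclusion check above, using stdlib `re` instead of the `regex` package and dropping the path-depth assertion; the function and the test inputs are illustrative, not taken from the package:

```python
import re

FUSED_PATTERNS = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]

def is_layer_excluded(prefix: str, exclude_modules: list) -> bool:
    prefix_last = prefix.split(".")[-1]
    for pattern in exclude_modules:
        # Glob-style pattern -> regex, matched against the full module path.
        regex_str = pattern.replace(".", r"\.").replace("*", r".*")
        if re.fullmatch(regex_str, prefix):
            return True
        pattern_last = pattern.split(".")[-1]
        # Fused modules (e.g. fused_qkv_a_proj_with_mqa) still carry the excluded
        # sub-projection's name inside their last path component.
        if pattern_last in FUSED_PATTERNS and pattern_last in prefix_last:
            return True
    return False

# Exact glob match on the full path.
assert is_layer_excluded("model.layers.0.mlp.gate", ["model.layers.*.mlp.gate"])
# Fused-module case: excluding kv_a_proj_with_mqa also excludes the fused projection.
assert is_layer_excluded(
    "model.layers.0.self_attn.fused_qkv_a_proj_with_mqa",
    ["model.layers.0.self_attn.kv_a_proj_with_mqa"],
)
```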
```diff
@@ -821,14 +852,25 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         if enable_flashinfer_fp4_gemm:
             w = layer.weight.T
             w_scale_interleaved = layer.weight_scale_interleaved.T
-
-
-
-
-
-
-
-
+            if USE_CUTLASS_BACKEND_FOR_FP4_GEMM:
+                out = fp4_gemm(
+                    x_fp4,
+                    w,
+                    x_scale_interleaved,
+                    w_scale_interleaved,
+                    layer.alpha,
+                    output_dtype,
+                    backend="cutlass",
+                )
+            else:
+                out = fp4_gemm(
+                    x_fp4,
+                    w,
+                    x_scale_interleaved,
+                    w_scale_interleaved,
+                    layer.alpha,
+                    output_dtype,
+                )
             if bias is not None:
                 out = out + bias
             return out.view(*output_shape)
@@ -859,6 +901,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         """Access the global enable_flashinfer_cutlass_moe setting."""
         return get_moe_runner_backend().is_flashinfer_cutlass()
 
+    @property
+    def enable_flashinfer_cutedsl_moe(self) -> bool:
+        from sglang.srt.layers.moe import get_moe_runner_backend
+
+        """Access the global enable_flashinfer_cutedsl_moe setting."""
+        return get_moe_runner_backend().is_flashinfer_cutedsl()
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -970,15 +1019,17 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
 
         w13_input_scale = PerTensorScaleParameter(
-            data=torch.empty(layer.
+            data=torch.empty(layer.num_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
+        w13_input_scale._sglang_require_global_experts = True
         layer.register_parameter("w13_input_scale", w13_input_scale)
 
         w2_input_scale = PerTensorScaleParameter(
-            data=torch.empty(layer.
+            data=torch.empty(layer.num_experts, dtype=torch.float32),
            weight_loader=weight_loader,
         )
+        w2_input_scale._sglang_require_global_experts = True
         layer.register_parameter("w2_input_scale", w2_input_scale)
 
     def swizzle_blockscale(self, scale: torch.Tensor):
@@ -1161,6 +1212,37 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         if self.enable_flashinfer_cutlass_moe or self.enable_flashinfer_trtllm_moe:
             w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
             w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
+        elif self.enable_flashinfer_cutedsl_moe:
+            # All-expert-one-input-scale is mathematically different from default per-expert-input-scale
+            # Thus we allow users to switch the flag to do thorough testing
+            if CUTEDSL_MOE_SCALAR_INPUT_SCALE:
+                w13_input_scale = (
+                    layer.w13_input_scale.max()
+                    .to(torch.float32)
+                    .repeat(layer.w13_input_scale.shape[0])
+                )
+            else:
+                w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
+                    torch.float32
+                )
+
+            w2_input_scale = layer.w2_input_scale
+
+            def _slice_scale(w):
+                assert w.shape == (layer.num_experts,)
+                assert layer.moe_ep_size * layer.num_local_experts == layer.num_experts
+                return w[
+                    layer.moe_ep_rank
+                    * layer.num_local_experts : (layer.moe_ep_rank + 1)
+                    * layer.num_local_experts
+                ]
+
+            w13_input_scale = _slice_scale(w13_input_scale)
+            w2_input_scale = _slice_scale(w2_input_scale)
+
+            if CUTEDSL_MOE_NVFP4_DISPATCH:
+                assert torch.all(w13_input_scale == w13_input_scale[0])
+                w13_input_scale = w13_input_scale[0]
         else:
             w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
             w2_input_scale = layer.w2_input_scale
```
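The `_slice_scale` helper above keeps only this rank's slice of the per-expert input scales under expert parallelism. A tiny numeric illustration (all sizes are made up):

```python
import torch

# Hypothetical sizes: 8 global experts split over expert-parallel size 4.
num_experts, moe_ep_size = 8, 4
num_local_experts = num_experts // moe_ep_size

w13_input_scale = torch.arange(num_experts, dtype=torch.float32)  # one scale per global expert

def slice_scale(w: torch.Tensor, ep_rank: int) -> torch.Tensor:
    # Keep only the scales belonging to this EP rank's local experts,
    # mirroring _slice_scale in the hunk above.
    assert w.shape == (num_experts,)
    return w[ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts]

print(slice_scale(w13_input_scale, ep_rank=2))  # tensor([4., 5.])
```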
```diff
@@ -1243,8 +1325,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 layer.w13_weight_scale,
             )
 
-            logger.info_once("Applied flashinfer weight processing for both w13 and w2")
-
         else:
             # CUTLASS processing - handle w13 and w2 separately
 
@@ -1261,7 +1341,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
 
         # Both flashinfer cutlass and regular cutlass use same processing for w2
-        logger.info_once("Applied weight processing for both w13 and w2")
 
         # Set up CUTLASS MoE parameters
         device = layer.w13_weight.device
@@ -1278,21 +1357,32 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
         return self.enable_flashinfer_cutlass_moe
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: FusedMoE,
-
-
-
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
+        moe_runner_config = self.moe_runner_config
+
         # Check if this is a FlashInferFP4MoE layer that should handle its own forward
         if hasattr(layer, "gemm1_weights_fp4_shuffled"):
             # This layer was processed with flashinfer TRTLLM - delegate to its own forward
-            return layer.forward(x, topk_output)
+            return StandardCombineInput(hidden_states=layer.forward(x, topk_output))
 
         if self.enable_flashinfer_cutlass_moe:
             assert (
@@ -1345,13 +1435,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 tp_rank=layer.moe_tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
             )[0]
-            # Scale by routed_scaling_factor is fused into select_experts.
             if should_use_flashinfer_cutlass_moe_fp4_allgather():
                 output, global_output = get_local_dp_buffer(), output
                 get_tp_group().reduce_scatterv(
                     global_output, output=output, sizes=get_dp_global_num_tokens()
                 )
-            return output
+            return StandardCombineInput(hidden_states=output)
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 
@@ -1372,4 +1461,50 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
         # Scale by routed_scaling_factor is fused into select_experts.
-        return output
+        return StandardCombineInput(hidden_states=output)
+
+    def apply_without_routing_weights(
+        self,
+        layer: FusedMoE,
+        x: torch.Tensor,
+        masked_m: torch.Tensor,
+        moe_runner_config: MoeRunnerConfig,
+        down_gemm_overlap_args: Optional["DownGemmOverlapArgs"],
+    ) -> torch.Tensor:
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
+
+        assert self.enable_flashinfer_cutedsl_moe, "only support flashinfer cutedsl moe"
+        assert (
+            not moe_runner_config.apply_router_weight_on_input
+        ), "apply_router_weight_on_input is not supported for Flashinfer"
+
+        from sglang.srt.layers.moe.flashinfer_cutedsl_moe import (
+            flashinfer_cutedsl_moe_masked,
+        )
+
+        out = flashinfer_cutedsl_moe_masked(
+            hidden_states=x,
+            input_global_scale=(
+                None if CUTEDSL_MOE_NVFP4_DISPATCH else layer.w13_input_scale_quant
+            ),
+            w1=layer.w13_weight,
+            w1_blockscale=layer.w13_blockscale_swizzled,
+            w1_alpha=layer.g1_alphas,
+            w2=layer.w2_weight,
+            a2_global_scale=layer.w2_input_scale_quant,
+            w2_blockscale=layer.w2_blockscale_swizzled,
+            w2_alpha=layer.g2_alphas,
+            masked_m=masked_m,
+            **(
+                dict(
+                    down_sm_count=down_gemm_overlap_args.num_sms,
+                    down_signals=down_gemm_overlap_args.signal,
+                    down_start_event=down_gemm_overlap_args.start_event,
+                )
+                if down_gemm_overlap_args is not None
+                else {}
+            ),
+        )
+        return out
```
sglang/srt/layers/quantization/moe_wna16.py

```diff
@@ -9,6 +9,8 @@ import torch
 
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import get_tp_group
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -22,8 +24,10 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.
-
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 
 def get_weight_perm(num_bits: int):
@@ -349,37 +353,36 @@ class MoeWNA16Method(FusedMoEMethodBase):
             layer.register_parameter(key, param)
             set_weight_attrs(param, extra_weight_attrs)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-
-
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        # avoid circular import
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
         weight_bits = self.quant_config.weight_bits
         has_zp = self.quant_config.has_zp
 
-
-
-        layer.
-        layer.w2_qweight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_qweight,
+            w2_weight=layer.w2_qweight,
             use_int4_w4a16=weight_bits == 4,
             use_int8_w8a16=weight_bits == 8,
-
+            w13_scale=layer.w13_scales,
             w2_scale=layer.w2_scales,
-
+            w13_zp=layer.w13_qzeros if has_zp else None,
             w2_zp=layer.w2_qzeros if has_zp else None,
             block_shape=[0, layer.group_size],
         )
+        return self.runner.run(dispatch_output, quant_info)
 
     @staticmethod
     def get_weight_loader(layer, weight_loader):
```