sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
|
@@ -1,17 +1,32 @@
|
|
|
1
|
+
# Copyright 2025 SGLang Team
|
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
3
|
+
# you may not use this file except in compliance with the License.
|
|
4
|
+
# You may obtain a copy of the License at
|
|
5
|
+
#
|
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
7
|
+
#
|
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
11
|
+
# See the License for the specific language governing permissions and
|
|
12
|
+
# limitations under the License.
|
|
13
|
+
# ==============================================================================
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
1
17
|
from dataclasses import dataclass
|
|
2
18
|
from typing import TYPE_CHECKING, Any, Callable, Optional
|
|
3
19
|
|
|
4
20
|
import torch
|
|
5
21
|
|
|
22
|
+
from sglang.srt.layers import deep_gemm_wrapper
|
|
6
23
|
from sglang.srt.layers.moe import get_moe_runner_backend
|
|
24
|
+
from sglang.srt.layers.moe.topk import TopKOutput
|
|
7
25
|
from sglang.srt.layers.moe.utils import is_sbo_enabled
|
|
8
|
-
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
|
9
|
-
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
10
|
-
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|
11
26
|
from sglang.srt.utils import get_int_env_var
|
|
12
27
|
|
|
13
28
|
if TYPE_CHECKING:
|
|
14
|
-
from sglang.srt.layers.moe.
|
|
29
|
+
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
|
15
30
|
|
|
16
31
|
|
|
17
32
|
class SboFlags:
|
|
@@ -43,7 +58,7 @@ class CombineOverlapArgs:
|
|
|
43
58
|
wait_event: torch.cuda.Event
|
|
44
59
|
num_sms: int
|
|
45
60
|
signal: Optional[torch.Tensor] = None
|
|
46
|
-
threshold: int =
|
|
61
|
+
threshold: int = 0
|
|
47
62
|
|
|
48
63
|
|
|
49
64
|
@dataclass
|
|
@@ -55,57 +70,52 @@ class DownGemmOverlapArgs:
|
|
|
55
70
|
|
|
56
71
|
def execute_sbo(
|
|
57
72
|
forward_shared_experts: Callable[[], Any],
|
|
58
|
-
experts:
|
|
73
|
+
experts: FusedMoE,
|
|
59
74
|
hidden_states: torch.Tensor,
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
alt_stream: Optional = None,
|
|
75
|
+
topk_output: TopKOutput,
|
|
76
|
+
alt_stream: Optional[torch.cuda.Stream] = None,
|
|
77
|
+
disable_sbo: bool = False,
|
|
64
78
|
):
|
|
65
|
-
shared_output = None
|
|
66
79
|
|
|
67
|
-
dispatch_output = experts.dispatch(
|
|
68
|
-
hidden_states,
|
|
80
|
+
dispatch_output = experts.dispatcher.dispatch(
|
|
81
|
+
hidden_states=hidden_states, topk_output=topk_output
|
|
69
82
|
)
|
|
70
83
|
|
|
71
84
|
combine_overlap_args, down_gemm_overlap_args, meta_overlap_args = (
|
|
72
|
-
_compute_overlap_args(dispatch_output, alt_stream)
|
|
85
|
+
_compute_overlap_args(dispatch_output, alt_stream, disable_sbo=disable_sbo)
|
|
73
86
|
)
|
|
74
87
|
|
|
75
|
-
hidden_states = experts.
|
|
88
|
+
hidden_states = experts.run_moe_core(
|
|
76
89
|
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
|
|
77
90
|
)
|
|
78
91
|
if (e := meta_overlap_args.get("record_event_after_down")) is not None:
|
|
79
92
|
e.record()
|
|
80
93
|
|
|
81
|
-
if SboFlags.enable_combine_shared_two_stream_overlap():
|
|
94
|
+
if (not disable_sbo) and SboFlags.enable_combine_shared_two_stream_overlap():
|
|
82
95
|
# TODO reduce sm for non-deepgemm
|
|
83
96
|
with deep_gemm_wrapper.configure_deep_gemm_num_sms(
|
|
84
97
|
meta_overlap_args["compute_num_sms"]
|
|
85
98
|
):
|
|
86
|
-
|
|
99
|
+
forward_shared_experts()
|
|
87
100
|
|
|
88
|
-
hidden_states = experts.combine(
|
|
89
|
-
hidden_states,
|
|
90
|
-
dispatch_output.
|
|
91
|
-
dispatch_output.topk_weights,
|
|
92
|
-
forward_batch,
|
|
101
|
+
hidden_states = experts.dispatcher.combine(
|
|
102
|
+
hidden_states=hidden_states,
|
|
103
|
+
topk_ids=dispatch_output.topk_ids,
|
|
104
|
+
topk_weights=dispatch_output.topk_weights,
|
|
93
105
|
overlap_args=combine_overlap_args,
|
|
94
106
|
)
|
|
95
107
|
|
|
96
|
-
return hidden_states
|
|
108
|
+
return hidden_states
|
|
97
109
|
|
|
98
110
|
|
|
99
|
-
def _compute_overlap_args(dispatch_output, alt_stream):
|
|
100
|
-
if not (
|
|
111
|
+
def _compute_overlap_args(dispatch_output, alt_stream, disable_sbo):
|
|
112
|
+
if disable_sbo or not (
|
|
101
113
|
SboFlags.enable_combine_down_gemm_two_stream_overlap()
|
|
102
114
|
or SboFlags.enable_combine_shared_two_stream_overlap()
|
|
103
115
|
):
|
|
104
116
|
return None, None, {}
|
|
105
117
|
|
|
106
|
-
hidden_states = dispatch_output.
|
|
107
|
-
if isinstance(hidden_states, tuple):
|
|
108
|
-
hidden_states = hidden_states[0]
|
|
118
|
+
hidden_states = dispatch_output.hidden_states
|
|
109
119
|
|
|
110
120
|
num_local_experts, num_tokens_static, hidden_dim = hidden_states.shape
|
|
111
121
|
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
from sglang.srt.managers.tp_worker import TpModelWorker
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BaseDraftWorker(ABC):
|
|
11
|
+
@abstractmethod
|
|
12
|
+
def draft():
|
|
13
|
+
pass
|
|
14
|
+
|
|
15
|
+
@abstractmethod
|
|
16
|
+
def draft_extend():
|
|
17
|
+
pass
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class BaseSpecWorker(ABC):
|
|
21
|
+
@property
|
|
22
|
+
@abstractmethod
|
|
23
|
+
def target_worker(self) -> TpModelWorker:
|
|
24
|
+
pass
|
|
25
|
+
|
|
26
|
+
@property
|
|
27
|
+
@abstractmethod
|
|
28
|
+
def draft_worker(self) -> BaseDraftWorker:
|
|
29
|
+
pass
|
|
30
|
+
|
|
31
|
+
@abstractmethod
|
|
32
|
+
def clear_cache_pool(self):
|
|
33
|
+
# TODO: move this abstract method to BaseTpWorker and call through self.model_runner
|
|
34
|
+
pass
|
|
@@ -0,0 +1,226 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from sglang.srt.server_args import ServerArgs, get_global_server_args
|
|
4
|
+
from sglang.srt.utils.common import is_blackwell
|
|
5
|
+
|
|
6
|
+
logger = logging.getLogger(__name__)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class DraftBackendFactory:
|
|
10
|
+
def __init__(
|
|
11
|
+
self,
|
|
12
|
+
server_args: ServerArgs,
|
|
13
|
+
draft_model_runner,
|
|
14
|
+
topk: int,
|
|
15
|
+
speculative_num_steps: int,
|
|
16
|
+
):
|
|
17
|
+
self.server_args = server_args
|
|
18
|
+
self.draft_model_runner = draft_model_runner
|
|
19
|
+
self.topk = topk
|
|
20
|
+
self.speculative_num_steps = speculative_num_steps
|
|
21
|
+
|
|
22
|
+
def _create_backend(
|
|
23
|
+
self, backend_name: str, backend_map: dict, error_template: str
|
|
24
|
+
):
|
|
25
|
+
backend_type = getattr(self.server_args, backend_name)
|
|
26
|
+
if backend_type is None:
|
|
27
|
+
backend_type = self.server_args.attention_backend
|
|
28
|
+
|
|
29
|
+
if backend_type not in backend_map:
|
|
30
|
+
raise ValueError(error_template.format(backend_type=backend_type))
|
|
31
|
+
|
|
32
|
+
return backend_map[backend_type]()
|
|
33
|
+
|
|
34
|
+
def create_decode_backend(self):
|
|
35
|
+
if self.speculative_num_steps == 1:
|
|
36
|
+
return None
|
|
37
|
+
|
|
38
|
+
backend_map = {
|
|
39
|
+
"flashinfer": self._create_flashinfer_decode_backend,
|
|
40
|
+
"triton": self._create_triton_decode_backend,
|
|
41
|
+
"aiter": self._create_aiter_decode_backend,
|
|
42
|
+
"fa3": self._create_fa3_decode_backend,
|
|
43
|
+
"hybrid_linear_attn": (
|
|
44
|
+
self._create_fa3_decode_backend
|
|
45
|
+
if not is_blackwell()
|
|
46
|
+
else self._create_triton_decode_backend
|
|
47
|
+
),
|
|
48
|
+
"flashmla": self._create_flashmla_decode_backend,
|
|
49
|
+
"trtllm_mha": self._create_trtllm_mha_decode_backend,
|
|
50
|
+
"trtllm_mla": self._create_trtllm_mla_decode_backend,
|
|
51
|
+
"nsa": self._create_nsa_decode_backend,
|
|
52
|
+
}
|
|
53
|
+
|
|
54
|
+
return self._create_backend(
|
|
55
|
+
"decode_attention_backend",
|
|
56
|
+
backend_map,
|
|
57
|
+
"EAGLE is not supported in decode attention backend {backend_type}",
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
def create_draft_extend_backend(self):
|
|
61
|
+
backend_map = {
|
|
62
|
+
"flashinfer": self._create_flashinfer_prefill_backend,
|
|
63
|
+
"triton": self._create_triton_prefill_backend,
|
|
64
|
+
"aiter": self._create_aiter_prefill_backend,
|
|
65
|
+
"fa3": self._create_fa3_prefill_backend,
|
|
66
|
+
"hybrid_linear_attn": (
|
|
67
|
+
self._create_fa3_prefill_backend
|
|
68
|
+
if not is_blackwell()
|
|
69
|
+
else self._create_triton_prefill_backend
|
|
70
|
+
),
|
|
71
|
+
"flashmla": self._create_flashmla_prefill_backend,
|
|
72
|
+
"trtllm_mha": self._create_trtllm_mha_prefill_backend,
|
|
73
|
+
"trtllm_mla": self._create_trtllm_mla_prefill_backend,
|
|
74
|
+
"nsa": self._create_nsa_prefill_backend,
|
|
75
|
+
}
|
|
76
|
+
backend_name = (
|
|
77
|
+
"decode_attention_backend"
|
|
78
|
+
if self.server_args.speculative_attention_mode == "decode"
|
|
79
|
+
else "prefill_attention_backend"
|
|
80
|
+
)
|
|
81
|
+
return self._create_backend(
|
|
82
|
+
backend_name,
|
|
83
|
+
backend_map,
|
|
84
|
+
"EAGLE is not supported in attention backend {backend_type}",
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
def _create_nsa_decode_backend(self):
|
|
88
|
+
from sglang.srt.layers.attention.nsa_backend import (
|
|
89
|
+
NativeSparseAttnMultiStepBackend,
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
return NativeSparseAttnMultiStepBackend(
|
|
93
|
+
self.draft_model_runner, self.topk, self.speculative_num_steps
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
def _create_nsa_prefill_backend(self):
|
|
97
|
+
from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend
|
|
98
|
+
|
|
99
|
+
return NativeSparseAttnBackend(self.draft_model_runner, skip_prefill=False)
|
|
100
|
+
|
|
101
|
+
def _create_flashinfer_decode_backend(self):
|
|
102
|
+
if not get_global_server_args().use_mla_backend:
|
|
103
|
+
from sglang.srt.layers.attention.flashinfer_backend import (
|
|
104
|
+
FlashInferMultiStepDraftBackend,
|
|
105
|
+
)
|
|
106
|
+
|
|
107
|
+
return FlashInferMultiStepDraftBackend(
|
|
108
|
+
self.draft_model_runner, self.topk, self.speculative_num_steps
|
|
109
|
+
)
|
|
110
|
+
else:
|
|
111
|
+
from sglang.srt.layers.attention.flashinfer_mla_backend import (
|
|
112
|
+
FlashInferMLAMultiStepDraftBackend,
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
return FlashInferMLAMultiStepDraftBackend(
|
|
116
|
+
self.draft_model_runner, self.topk, self.speculative_num_steps
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
def _create_triton_decode_backend(self):
    """Build the Triton multi-step draft backend.

    Returns:
        A ``TritonMultiStepDraftBackend`` constructed from
        ``self.draft_model_runner``, ``self.topk`` and
        ``self.speculative_num_steps``.
    """
    # The import runs only when this factory is called.
    from sglang.srt.layers.attention.triton_backend import (
        TritonMultiStepDraftBackend as backend_cls,
    )

    ctor_args = (self.draft_model_runner, self.topk, self.speculative_num_steps)
    return backend_cls(*ctor_args)
|
|
127
|
+
|
|
128
|
+
def _create_aiter_decode_backend(self):
    """Build the Aiter multi-step draft backend.

    Returns:
        An ``AiterMultiStepDraftBackend`` constructed from
        ``self.draft_model_runner``, ``self.topk`` and
        ``self.speculative_num_steps``.
    """
    # The import runs only when this factory is called.
    from sglang.srt.layers.attention.aiter_backend import (
        AiterMultiStepDraftBackend as backend_cls,
    )

    ctor_args = (self.draft_model_runner, self.topk, self.speculative_num_steps)
    return backend_cls(*ctor_args)
|
|
134
|
+
|
|
135
|
+
def _create_fa3_decode_backend(self):
    """Build the FlashAttention multi-step backend.

    Returns:
        A ``FlashAttentionMultiStepBackend`` constructed from
        ``self.draft_model_runner``, ``self.topk`` and
        ``self.speculative_num_steps``.
    """
    # The import runs only when this factory is called.
    from sglang.srt.layers.attention.flashattention_backend import (
        FlashAttentionMultiStepBackend as backend_cls,
    )

    ctor_args = (self.draft_model_runner, self.topk, self.speculative_num_steps)
    return backend_cls(*ctor_args)
|
|
143
|
+
|
|
144
|
+
def _create_flashmla_decode_backend(self):
    """Build the FlashMLA multi-step draft backend.

    Returns:
        A ``FlashMLAMultiStepDraftBackend`` constructed from
        ``self.draft_model_runner``, ``self.topk`` and
        ``self.speculative_num_steps``.
    """
    # The import runs only when this factory is called.
    from sglang.srt.layers.attention.flashmla_backend import (
        FlashMLAMultiStepDraftBackend as backend_cls,
    )

    ctor_args = (self.draft_model_runner, self.topk, self.speculative_num_steps)
    return backend_cls(*ctor_args)
|
|
152
|
+
|
|
153
|
+
def _create_trtllm_mha_decode_backend(self):
    """Build the TRT-LLM MHA multi-step draft backend.

    Returns:
        A ``TRTLLMHAAttnMultiStepDraftBackend`` constructed from
        ``self.draft_model_runner``, ``self.topk`` and
        ``self.speculative_num_steps``.
    """
    # The import runs only when this factory is called.
    from sglang.srt.layers.attention.trtllm_mha_backend import (
        TRTLLMHAAttnMultiStepDraftBackend as backend_cls,
    )

    ctor_args = (self.draft_model_runner, self.topk, self.speculative_num_steps)
    return backend_cls(*ctor_args)
|
|
161
|
+
|
|
162
|
+
def _create_trtllm_mla_decode_backend(self):
    """Build the TRT-LLM MLA multi-step draft backend.

    Returns:
        A ``TRTLLMMLAMultiStepDraftBackend`` constructed from
        ``self.draft_model_runner``, ``self.topk`` and
        ``self.speculative_num_steps``.

    Raises:
        ValueError: If the global server args do not have
            ``use_mla_backend`` set.
    """
    # Validate before importing so a misconfiguration fails first.
    mla_enabled = get_global_server_args().use_mla_backend
    if not mla_enabled:
        raise ValueError(
            "trtllm_mla backend requires MLA model (use_mla_backend=True)."
        )

    from sglang.srt.layers.attention.trtllm_mla_backend import (
        TRTLLMMLAMultiStepDraftBackend as backend_cls,
    )

    ctor_args = (self.draft_model_runner, self.topk, self.speculative_num_steps)
    return backend_cls(*ctor_args)
|
|
175
|
+
|
|
176
|
+
def _create_flashinfer_prefill_backend(self):
    """Build the FlashInfer attention backend for the draft model runner.

    The MLA variant is chosen when the global server args report
    ``use_mla_backend``; otherwise the regular FlashInfer variant is used.

    Returns:
        ``FlashInferMLAAttnBackend`` or ``FlashInferAttnBackend``, created
        with ``skip_prefill=False``.
    """
    if get_global_server_args().use_mla_backend:
        # MLA path: only this module is imported.
        from sglang.srt.layers.attention.flashinfer_mla_backend import (
            FlashInferMLAAttnBackend as backend_cls,
        )

        return backend_cls(self.draft_model_runner, skip_prefill=False)

    # Non-MLA path.
    from sglang.srt.layers.attention.flashinfer_backend import (
        FlashInferAttnBackend as backend_cls,
    )

    return backend_cls(self.draft_model_runner, skip_prefill=False)
|
|
189
|
+
|
|
190
|
+
def _create_triton_prefill_backend(self):
    """Build the Triton attention backend for the draft model runner.

    Returns:
        A ``TritonAttnBackend`` created with ``skip_prefill=False``.
    """
    # The import runs only when this factory is called.
    from sglang.srt.layers.attention.triton_backend import (
        TritonAttnBackend as backend_cls,
    )

    runner = self.draft_model_runner
    return backend_cls(runner, skip_prefill=False)
|
|
194
|
+
|
|
195
|
+
def _create_aiter_prefill_backend(self):
    """Build the Aiter attention backend for the draft model runner.

    Returns:
        An ``AiterAttnBackend`` created with ``skip_prefill=False``.
    """
    # The import runs only when this factory is called.
    from sglang.srt.layers.attention.aiter_backend import (
        AiterAttnBackend as backend_cls,
    )

    runner = self.draft_model_runner
    return backend_cls(runner, skip_prefill=False)
|
|
199
|
+
|
|
200
|
+
def _create_fa3_prefill_backend(self):
    """Build the FlashAttention backend for the draft model runner.

    Returns:
        A ``FlashAttentionBackend`` created with ``skip_prefill=False``.
    """
    # The import runs only when this factory is called.
    from sglang.srt.layers.attention.flashattention_backend import (
        FlashAttentionBackend as backend_cls,
    )

    runner = self.draft_model_runner
    return backend_cls(runner, skip_prefill=False)
|
|
206
|
+
|
|
207
|
+
def _create_trtllm_mha_prefill_backend(self):
    """Build the TRT-LLM MHA attention backend for the draft model runner.

    Returns:
        A ``TRTLLMHAAttnBackend`` created with ``skip_prefill=False``.
    """
    # The import runs only when this factory is called.
    from sglang.srt.layers.attention.trtllm_mha_backend import (
        TRTLLMHAAttnBackend as backend_cls,
    )

    runner = self.draft_model_runner
    return backend_cls(runner, skip_prefill=False)
|
|
211
|
+
|
|
212
|
+
def _create_trtllm_mla_prefill_backend(self):
    """Build the TRT-LLM MLA attention backend for the draft model runner.

    Returns:
        A ``TRTLLMMLABackend`` created with ``skip_prefill=False``.

    Raises:
        ValueError: If the global server args do not have
            ``use_mla_backend`` set.
    """
    # Validate before importing so a misconfiguration fails first.
    mla_enabled = get_global_server_args().use_mla_backend
    if not mla_enabled:
        raise ValueError(
            "trtllm_mla backend requires MLA model (use_mla_backend=True)."
        )

    from sglang.srt.layers.attention.trtllm_mla_backend import (
        TRTLLMMLABackend as backend_cls,
    )

    runner = self.draft_model_runner
    return backend_cls(runner, skip_prefill=False)
|
|
221
|
+
|
|
222
|
+
def _create_flashmla_prefill_backend(self):
    """Log that flashmla has no draft-extend prefill backend.

    Returns:
        None: no backend object is created; only a warning is logged.
    """
    unsupported_msg = "flashmla prefill backend is not yet supported for draft extend."
    logger.warning(unsupported_msg)
    return None
|
|
@@ -9,10 +9,12 @@ from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
|
|
|
9
9
|
from sglang.srt.model_executor.cuda_graph_runner import (
|
|
10
10
|
CUDA_GRAPH_CAPTURE_FAILED_MSG,
|
|
11
11
|
CudaGraphRunner,
|
|
12
|
+
DeepEPCudaGraphRunnerAdapter,
|
|
12
13
|
get_batch_sizes_to_capture,
|
|
13
14
|
get_global_graph_memory_pool,
|
|
14
15
|
model_capture_mode,
|
|
15
16
|
set_global_graph_memory_pool,
|
|
17
|
+
set_is_extend_in_batch,
|
|
16
18
|
set_torch_compile_config,
|
|
17
19
|
)
|
|
18
20
|
from sglang.srt.model_executor.forward_batch_info import (
|
|
@@ -40,8 +42,11 @@ class EAGLEDraftCudaGraphRunner:
|
|
|
40
42
|
def __init__(self, eagle_worker: EAGLEWorker):
|
|
41
43
|
# Parse args
|
|
42
44
|
self.eagle_worker = eagle_worker
|
|
43
|
-
|
|
44
|
-
|
|
45
|
+
if not hasattr(eagle_worker, "model_runner"):
|
|
46
|
+
# V2: EagleDraftWorker
|
|
47
|
+
self.model_runner = model_runner = eagle_worker.draft_runner
|
|
48
|
+
else:
|
|
49
|
+
self.model_runner = model_runner = eagle_worker.model_runner
|
|
45
50
|
self.graphs = {}
|
|
46
51
|
self.output_buffers = {}
|
|
47
52
|
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
|
@@ -58,6 +63,7 @@ class EAGLEDraftCudaGraphRunner:
|
|
|
58
63
|
self.enable_profile_cuda_graph = (
|
|
59
64
|
model_runner.server_args.enable_profile_cuda_graph
|
|
60
65
|
)
|
|
66
|
+
self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
|
|
61
67
|
server_args = model_runner.server_args
|
|
62
68
|
|
|
63
69
|
# Batch sizes to capture
|
|
@@ -76,6 +82,7 @@ class EAGLEDraftCudaGraphRunner:
|
|
|
76
82
|
self.seq_lens_cpu = torch.full(
|
|
77
83
|
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
|
|
78
84
|
)
|
|
85
|
+
self.extend_seq_lens_cpu = [self.seq_len_fill_value] * self.max_bs
|
|
79
86
|
|
|
80
87
|
if self.enable_torch_compile:
|
|
81
88
|
set_torch_compile_config()
|
|
@@ -87,6 +94,7 @@ class EAGLEDraftCudaGraphRunner:
|
|
|
87
94
|
self.seq_lens = torch.full(
|
|
88
95
|
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
|
|
89
96
|
)
|
|
97
|
+
self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
|
|
90
98
|
self.out_cache_loc = torch.zeros(
|
|
91
99
|
(self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
|
|
92
100
|
)
|
|
@@ -160,6 +168,9 @@ class EAGLEDraftCudaGraphRunner:
|
|
|
160
168
|
# Graph inputs
|
|
161
169
|
req_pool_indices = self.req_pool_indices[:num_seqs]
|
|
162
170
|
seq_lens = self.seq_lens[:num_seqs]
|
|
171
|
+
seq_lens_cpu = self.seq_lens_cpu[:num_seqs]
|
|
172
|
+
extend_seq_lens = self.extend_seq_lens[:num_seqs]
|
|
173
|
+
extend_seq_lens_cpu = self.extend_seq_lens_cpu[:num_seqs]
|
|
163
174
|
out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
|
|
164
175
|
positions = self.positions[:num_tokens]
|
|
165
176
|
mrope_positions = self.mrope_positions[:, :num_tokens]
|
|
@@ -222,6 +233,9 @@ class EAGLEDraftCudaGraphRunner:
|
|
|
222
233
|
input_ids=None,
|
|
223
234
|
req_pool_indices=req_pool_indices,
|
|
224
235
|
seq_lens=seq_lens,
|
|
236
|
+
seq_lens_cpu=seq_lens_cpu,
|
|
237
|
+
extend_seq_lens=extend_seq_lens,
|
|
238
|
+
extend_seq_lens_cpu=extend_seq_lens_cpu,
|
|
225
239
|
req_to_token_pool=self.model_runner.req_to_token_pool,
|
|
226
240
|
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
|
227
241
|
out_cache_loc=out_cache_loc,
|
|
@@ -250,6 +264,7 @@ class EAGLEDraftCudaGraphRunner:
|
|
|
250
264
|
# Clean intermediate result cache for DP attention
|
|
251
265
|
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
|
|
252
266
|
set_dp_buffer_len(global_dp_buffer_len, num_tokens)
|
|
267
|
+
set_is_extend_in_batch(False)
|
|
253
268
|
|
|
254
269
|
# Backup two fields, which will be modified in-place in `draft_forward`.
|
|
255
270
|
output_cache_loc_backup = forward_batch.out_cache_loc
|
|
@@ -261,6 +276,8 @@ class EAGLEDraftCudaGraphRunner:
|
|
|
261
276
|
forward_batch.spec_info.hidden_states = hidden_states_backup
|
|
262
277
|
return ret
|
|
263
278
|
|
|
279
|
+
self.deepep_adapter.capture(is_extend_in_batch=False)
|
|
280
|
+
|
|
264
281
|
for _ in range(2):
|
|
265
282
|
torch.cuda.synchronize()
|
|
266
283
|
self.model_runner.tp_group.barrier()
|
|
@@ -276,14 +293,14 @@ class EAGLEDraftCudaGraphRunner:
|
|
|
276
293
|
return graph, out
|
|
277
294
|
|
|
278
295
|
def _postprocess_output_to_raw_bs(self, out, raw_bs):
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
parents_list = [x[:raw_bs] for x in parents_list]
|
|
283
|
-
return (score_list, token_list, parents_list)
|
|
296
|
+
# Keep the variables name for readability
|
|
297
|
+
parent_list, top_scores_index, draft_tokens = (t[:raw_bs] for t in out)
|
|
298
|
+
return parent_list, top_scores_index, draft_tokens
|
|
284
299
|
|
|
285
300
|
def replay(self, forward_batch: ForwardBatch):
|
|
286
301
|
assert forward_batch.out_cache_loc is not None
|
|
302
|
+
self.deepep_adapter.replay()
|
|
303
|
+
|
|
287
304
|
raw_bs = forward_batch.batch_size
|
|
288
305
|
raw_num_token = raw_bs * self.num_tokens_per_bs
|
|
289
306
|
|
|
@@ -9,11 +9,13 @@ from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
|
|
|
9
9
|
from sglang.srt.model_executor.cuda_graph_runner import (
|
|
10
10
|
CUDA_GRAPH_CAPTURE_FAILED_MSG,
|
|
11
11
|
CudaGraphRunner,
|
|
12
|
+
DeepEPCudaGraphRunnerAdapter,
|
|
12
13
|
LogitsProcessorOutput,
|
|
13
14
|
get_batch_sizes_to_capture,
|
|
14
15
|
get_global_graph_memory_pool,
|
|
15
16
|
model_capture_mode,
|
|
16
17
|
set_global_graph_memory_pool,
|
|
18
|
+
set_is_extend_in_batch,
|
|
17
19
|
set_torch_compile_config,
|
|
18
20
|
)
|
|
19
21
|
from sglang.srt.model_executor.forward_batch_info import (
|
|
@@ -38,7 +40,12 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|
|
38
40
|
def __init__(self, eagle_worker: EAGLEWorker):
|
|
39
41
|
# Parse args
|
|
40
42
|
self.eagle_worker = eagle_worker
|
|
41
|
-
|
|
43
|
+
if not hasattr(eagle_worker, "model_runner"):
|
|
44
|
+
# V2: EagleDraftWorker
|
|
45
|
+
self.model_runner = model_runner = eagle_worker.draft_runner
|
|
46
|
+
else:
|
|
47
|
+
self.model_runner = model_runner = eagle_worker.model_runner
|
|
48
|
+
|
|
42
49
|
self.graphs = {}
|
|
43
50
|
self.output_buffers = {}
|
|
44
51
|
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
|
@@ -56,6 +63,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|
|
56
63
|
)
|
|
57
64
|
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
|
|
58
65
|
self.padded_static_len = -1
|
|
66
|
+
self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
|
|
59
67
|
|
|
60
68
|
# Attention backend
|
|
61
69
|
self.num_tokens_per_bs = self.speculative_num_steps + 1
|
|
@@ -71,6 +79,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|
|
71
79
|
self.seq_lens_cpu = torch.full(
|
|
72
80
|
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
|
|
73
81
|
)
|
|
82
|
+
self.extend_seq_lens_cpu = [self.num_tokens_per_bs] * self.max_bs
|
|
74
83
|
|
|
75
84
|
if self.enable_torch_compile:
|
|
76
85
|
set_torch_compile_config()
|
|
@@ -189,7 +198,9 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|
|
189
198
|
input_ids = self.input_ids[:num_tokens]
|
|
190
199
|
req_pool_indices = self.req_pool_indices[:bs]
|
|
191
200
|
seq_lens = self.seq_lens[:bs]
|
|
201
|
+
seq_lens_cpu = self.seq_lens_cpu[:bs]
|
|
192
202
|
extend_seq_lens = self.extend_seq_lens[:bs]
|
|
203
|
+
extend_seq_lens_cpu = self.extend_seq_lens_cpu[:bs]
|
|
193
204
|
accept_length = self.accept_length[:bs]
|
|
194
205
|
out_cache_loc = self.out_cache_loc[:num_tokens]
|
|
195
206
|
positions = self.positions[:num_tokens]
|
|
@@ -238,6 +249,8 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|
|
238
249
|
)
|
|
239
250
|
spec_info.positions = None
|
|
240
251
|
|
|
252
|
+
self.deepep_adapter.capture(is_extend_in_batch=True)
|
|
253
|
+
|
|
241
254
|
# Forward batch
|
|
242
255
|
forward_batch = ForwardBatch(
|
|
243
256
|
forward_mode=ForwardMode.DRAFT_EXTEND,
|
|
@@ -245,6 +258,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|
|
245
258
|
input_ids=input_ids,
|
|
246
259
|
req_pool_indices=req_pool_indices,
|
|
247
260
|
seq_lens=seq_lens,
|
|
261
|
+
seq_lens_cpu=seq_lens_cpu,
|
|
248
262
|
next_token_logits_buffer=next_token_logits_buffer,
|
|
249
263
|
req_to_token_pool=self.model_runner.req_to_token_pool,
|
|
250
264
|
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
|
@@ -262,6 +276,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|
|
262
276
|
capture_hidden_mode=CaptureHiddenMode.LAST,
|
|
263
277
|
attn_backend=self.eagle_worker.draft_extend_attn_backend,
|
|
264
278
|
extend_seq_lens=extend_seq_lens,
|
|
279
|
+
extend_seq_lens_cpu=extend_seq_lens_cpu,
|
|
265
280
|
padded_static_len=self.padded_static_len,
|
|
266
281
|
)
|
|
267
282
|
|
|
@@ -280,12 +295,13 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|
|
280
295
|
# Clean intermediate result cache for DP attention
|
|
281
296
|
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
|
|
282
297
|
set_dp_buffer_len(global_dp_buffer_len, num_tokens)
|
|
298
|
+
set_is_extend_in_batch(False)
|
|
283
299
|
|
|
284
300
|
# Backup two fields, which will be modified in-place in `draft_forward`.
|
|
285
301
|
output_cache_loc_backup = forward_batch.out_cache_loc
|
|
286
302
|
hidden_states_backup = forward_batch.spec_info.hidden_states
|
|
287
303
|
|
|
288
|
-
ret = self.
|
|
304
|
+
ret = self.model_runner.model.forward(
|
|
289
305
|
forward_batch.input_ids,
|
|
290
306
|
forward_batch.positions,
|
|
291
307
|
forward_batch,
|
|
@@ -313,6 +329,8 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|
|
313
329
|
|
|
314
330
|
def replay(self, forward_batch: ForwardBatch):
|
|
315
331
|
assert forward_batch.out_cache_loc is not None
|
|
332
|
+
self.deepep_adapter.replay()
|
|
333
|
+
|
|
316
334
|
# batch_size and num_seqs can be different in case there are finished examples
|
|
317
335
|
# in the batch, which will not be counted as num_seqs
|
|
318
336
|
raw_bs = forward_batch.batch_size
|
|
@@ -362,6 +380,9 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|
|
362
380
|
self.seq_lens_cpu.fill_(self.seq_len_fill_value)
|
|
363
381
|
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
|
364
382
|
|
|
383
|
+
if forward_batch.extend_seq_lens_cpu is not None:
|
|
384
|
+
self.extend_seq_lens_cpu[:raw_bs] = forward_batch.extend_seq_lens_cpu
|
|
385
|
+
|
|
365
386
|
if bs != raw_bs:
|
|
366
387
|
forward_batch.spec_info.positions = self.positions[:num_tokens]
|
|
367
388
|
forward_batch.spec_info.accept_length = self.accept_length[:bs]
|