sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
--- sglang/srt/layers/attention/flashinfer_backend.py (0.5.3rc2)
+++ sglang/srt/layers/attention/flashinfer_backend.py (0.5.4)
@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """

+import logging
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
@@ -15,20 +16,13 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union

 import torch

-
-import logging
-
-torch._logging.set_logs(dynamo=logging.ERROR)
-torch._dynamo.config.suppress_errors = True
-
-from sglang.global_config import global_config
+from sglang.srt.environ import envs
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
 from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
     get_int_env_var,
@@ -41,6 +35,12 @@ if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

+logger = logging.getLogger(__name__)
+
+if envs.SGLANG_ENABLE_TORCH_COMPILE.get():
+    torch._logging.set_logs(dynamo=logging.ERROR)
+    torch._dynamo.config.suppress_errors = True
+

 if is_flashinfer_available():
     from flashinfer import (
@@ -50,7 +50,6 @@ if is_flashinfer_available():
         fast_decode_plan,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.decode import _get_range_buf, get_seq_lens


 class WrapperDispatch(Enum):
@@ -58,6 +57,36 @@ class WrapperDispatch(Enum):
     CROSS_ATTENTION = auto()


+@dataclass
+class MultiItemScoringParams:
+    """Parameters for multi-item scoring in attention computation.
+
+    Used when processing sequences with multiple items separated by delimiters,
+    where each item needs specific attention patterns that respect item boundaries.
+
+    Attributes:
+        prefix_len_ptr: A uint32 1D tensor indicating the prefix length of each prompt.
+            The tensor size is equal to the batch size.
+        token_pos_in_items_ptr: A uint16 1D tensor indicating the token position of each item
+            starting from 0 (delimiter) for each item. For batch size > 1,
+            sequences are concatenated with zero padding to ensure same length.
+        token_pos_in_items_len: Zero padding length for token_pos_in_items_ptr to handle
+            batch_size > 1 case. Defines the padded length for each sequence.
+        max_item_len_ptr: A uint16 tensor containing the max token length of all items
+            for each prompt in the batch.
+
+    """
+
+    prefix_len_ptr: Optional[torch.Tensor] = None
+    token_pos_in_items_ptr: Optional[torch.Tensor] = None
+    token_pos_in_items_len: int = 0
+    max_item_len_ptr: Optional[torch.Tensor] = None
+
+    def is_enabled(self) -> bool:
+        """Check if multi-item scoring is enabled."""
+        return self.prefix_len_ptr is not None
+
+
 @dataclass
 class DecodeMetadata:
     decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
@@ -68,6 +97,7 @@ class PrefillMetadata:
     prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
     use_ragged: bool
     extend_no_prefix: bool
+    multi_item_params: Optional[MultiItemScoringParams] = None


 # Reuse this workspace buffer across all flashinfer wrappers
@@ -87,9 +117,15 @@ class FlashInferAttnBackend(AttentionBackend):
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
         kv_last_page_len_buf: Optional[torch.Tensor] = None,
+        init_new_workspace: bool = False,
     ):
         super().__init__()

+        # Store multi-item scoring delimiter for efficient access
+        self.multi_item_scoring_delimiter = (
+            model_runner.server_args.multi_item_scoring_delimiter
+        )
+
         # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -124,7 +160,7 @@ class FlashInferAttnBackend(AttentionBackend):
             or "Qwen3ForCausalLM" in model_runner.model_config.hf_config.architectures
             or "MiMoForCausalLM" in model_runner.model_config.hf_config.architectures
         ):
-
+            envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.set(512 * 1024 * 1024)

         # When deterministic inference is enabled, tensor cores should be used for decode
         # Also set split tile sizes for prefill and decode from environment variables, and disable kv split for cuda graph
@@ -144,19 +180,26 @@ class FlashInferAttnBackend(AttentionBackend):
                 "SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE", 2048
             )
             self.disable_cuda_graph_kv_split = True
-
+            envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.set(2048 * 1024 * 1024)

         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
             # different from flashinfer zero_init_global_workspace_buffer
-            global_workspace_size =
+            global_workspace_size = envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get()
             global_workspace_buffer = torch.empty(
                 global_workspace_size,
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
-
+        if init_new_workspace:
+            self.workspace_buffer = torch.empty(
+                envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        else:
+            self.workspace_buffer = global_workspace_buffer
         max_bs = model_runner.req_to_token_pool.size
         if kv_indptr_buf is None:
             self.kv_indptr = [
@@ -229,10 +272,133 @@ class FlashInferAttnBackend(AttentionBackend):

         # Other metadata
         self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
+
         self.decode_cuda_graph_metadata = {}
         self.prefill_cuda_graph_metadata = {}  # For verify
         self.draft_extend_cuda_graph_metadata = {}  # For draft extend

+    def _process_multi_item_scoring(
+        self, forward_batch: ForwardBatch
+    ) -> MultiItemScoringParams:
+        """Process multi-item scoring tensors for FlashInfer attention.
+
+        This method handles sequences containing multiple "items" separated by delimiter tokens,
+        where each item needs specific attention patterns that respect item boundaries.
+
+        The method produces four key tensors for FlashInfer:
+        - prefix_len_ptr: uint32 tensor with prefix length for each prompt in batch
+        - token_pos_in_items_ptr: uint16 tensor with token positions starting from 0 at delimiters
+        - token_pos_in_items_len: padding length for batch processing
+        - max_item_len_ptr: uint16 tensor with max item length for each prompt
+
+        Args:
+            forward_batch: The forward batch containing input sequences and delimiter info
+
+        Returns:
+            MultiItemScoringParams: The processed multi-item scoring parameters
+
+        Examples:
+            Following FlashInfer definition: for 3 items of length 3, 2, 4 respectively:
+            token_pos_in_items_ptr = [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]
+
+            Case 1: Single sequence
+            Text: "What is the capital of France? <delim> London <delim> Paris <delim> Berlin <delim>"
+            Tokens: [What, is, the, capital, of, France, ?, <delim>, London, <delim>, Paris, <delim>, Berlin, <delim>]
+            Indices: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
+            - prefix_len_ptr: [7] (query length before first delimiter)
+            - token_pos_in_items_ptr: [0, 1, 0, 1, 0, 1, 0] (delim=0, London=1, delim=0, Paris=1, delim=0, Berlin=1, delim=0)
+            - token_pos_in_items_len: 7 (actual length)
+            - max_item_len_ptr: [1] (max item length is 1 token - all options are single tokens)
+
+            Case 2: Batch processing (batch_size=2)
+            Sequence 1: 2 items of length 2, 1 → [0, 1, 2, 0, 1, 0] (6 elements)
+            Sequence 2: 3 items of length 1, 3, 2 → [0, 1, 0, 1, 2, 3, 0, 1, 2, 0] (10 elements)
+            After padding both to length 10:
+            - token_pos_in_items_ptr: [0, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0]
+            - token_pos_in_items_len: 10 (padded length for batch processing)
+            - max_item_len_ptr: [2, 3] (max lengths per sequence)
+        """
+
+        delimiter = self.multi_item_scoring_delimiter
+        if delimiter is None or forward_batch.forward_mode == ForwardMode.DECODE:
+            return MultiItemScoringParams()
+
+        delimiter_mask = forward_batch.input_ids == delimiter
+        prefix_cache_lens = getattr(forward_batch, "extend_prefix_lens", None)
+        extend_seq_lens = getattr(forward_batch, "extend_seq_lens", None)
+        prefix_len_ptr, token_pos_in_items_ptr = [], []
+        token_pos_in_items_len = 0
+
+        # If no extend_seq_lens, treat whole batch as one sequence
+        if extend_seq_lens is None or len(extend_seq_lens) <= 1:
+            extend_seq_lens = [forward_batch.input_ids.size(0)]
+
+        seq_start = 0
+        for i, seq_len in enumerate(extend_seq_lens):
+            seq_end = seq_start + seq_len
+            mask = delimiter_mask[seq_start:seq_end]
+            pos = forward_batch.positions[seq_start:seq_end]
+            delimiter_indices = torch.nonzero(mask, as_tuple=True)[0]
+
+            if len(delimiter_indices) > 0:
+                first_delim = delimiter_indices[0]
+                # Prefix length: store as scalar
+                prefix_len = first_delim + (
+                    prefix_cache_lens[i] if prefix_cache_lens is not None else 0
+                )
+                prefix_len_ptr.append(
+                    prefix_len.item() if torch.is_tensor(prefix_len) else prefix_len
+                )
+
+                # Compute relative positions within items after delimiters
+                diff = pos[first_delim:] - torch.cummax(mask[first_delim:], 0)[1]
+                token_pos = (diff - pos[first_delim]).to(torch.uint16)
+                token_pos_in_items_ptr.append(token_pos)
+
+                # Update forward_batch positions in-place
+                pos[first_delim:] = diff - 1
+                forward_batch.positions[seq_start:seq_end] = pos
+
+            seq_start = seq_end
+
+        # Pad token_pos_in_items_ptr for batch processing
+        if token_pos_in_items_ptr:
+            token_pos_in_items_len = max(t.numel() for t in token_pos_in_items_ptr)
+            device = forward_batch.input_ids.device
+            token_pos_in_items_ptr = [
+                torch.cat(
+                    [
+                        t,
+                        torch.zeros(
+                            token_pos_in_items_len - t.numel(),
+                            dtype=torch.uint16,
+                            device=device,
+                        ),
+                    ]
+                )
+                for t in token_pos_in_items_ptr
+            ]
+
+        if not prefix_len_ptr or not token_pos_in_items_ptr:
+            return MultiItemScoringParams()
+
+        # Build final params
+        device = forward_batch.input_ids.device
+        return MultiItemScoringParams(
+            prefix_len_ptr=torch.tensor(
+                prefix_len_ptr, dtype=torch.uint32, device=device
+            ),
+            token_pos_in_items_ptr=torch.cat(token_pos_in_items_ptr, dim=0),
+            token_pos_in_items_len=token_pos_in_items_len & 0xFFFFFFFF,
+            max_item_len_ptr=torch.stack(
+                [
+                    t.to(torch.int32).max().to(torch.uint16)
+                    for t in token_pos_in_items_ptr
+                ],
+                dim=0,
+            ),
+        )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
@@ -280,13 +446,26 @@ class FlashInferAttnBackend(AttentionBackend):
             else:
                 prefix_lens = forward_batch.extend_prefix_lens

-
+            # Disable ragged wrapper and ensure prefix handling for multimodal and multi-item scoring
+            if self.is_multimodal or self.multi_item_scoring_delimiter is not None:
+                # use_ragged = False: Multi-item scoring requires the paged wrapper because:
+                # 1. Ragged wrapper doesn't support the specialized multi-item parameters
+                #    (prefix_len_ptr, token_pos_in_items_ptr, etc.)
+                # 2. Paged wrapper provides better control over attention masking needed
+                #    for respecting item boundaries in multi-item sequences
+                # 3. Custom masking logic conflicts with ragged wrapper's assumptions
                 use_ragged = False
                 extend_no_prefix = False
             else:
                 use_ragged = not self.enable_deterministic
                 extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

+            # Process multi-item scoring in attention backend instead of ForwardBatch
+            multi_item_params = MultiItemScoringParams()
+            if self.multi_item_scoring_delimiter is not None:
+                # Use new backend-specific implementation
+                multi_item_params = self._process_multi_item_scoring(forward_batch)
+
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
@@ -298,9 +477,13 @@ class FlashInferAttnBackend(AttentionBackend):
                 encoder_lens=forward_batch.encoder_lens,
                 spec_info=None,
                 fixed_split_size=self.prefill_split_tile_size,
+                multi_item_params=multi_item_params,
             )
             self.forward_metadata = PrefillMetadata(
-                self.prefill_wrappers_paged,
+                self.prefill_wrappers_paged,
+                use_ragged,
+                extend_no_prefix,
+                multi_item_params,
             )

     def init_cuda_graph_state(
@@ -531,7 +714,20 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                 causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
-
+                # Disable sliding window attention for multi-item scoring:
+                # - Sliding window could cut across item boundaries, breaking semantic coherence
+                # - Multi-item sequences need full attention to properly handle delimiter tokens
+                # - Specialized multi-item parameters (prefix_len_ptr, token_pos_in_items_ptr)
+                #   provide more precise attention control than simple sliding windows
+                # - Item-aware masking takes precedence over window-based masking
+                window_left=(
+                    layer.sliding_window_size
+                    if not (
+                        self.forward_metadata.multi_item_params
+                        and self.forward_metadata.multi_item_params.is_enabled()
+                    )
+                    else -1
+                ),
                 logits_soft_cap=logits_soft_cap,
                 # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
                 k_scale=layer.k_scale_float,
@@ -539,9 +735,13 @@ class FlashInferAttnBackend(AttentionBackend):
             )
         else:
             causal = True
-            if
-
+            if (
+                layer.is_cross_attention
+                or layer.attn_type == AttentionType.ENCODER_ONLY
+            ):
                 causal = False
+            if save_kv_cache and layer.attn_type == AttentionType.ENCODER_ONLY:
+                save_kv_cache = False

         if self.forward_metadata.extend_no_prefix:
             # NOTE: FlashInfer currently has limitations with head_dim = 32 or other dimensions
@@ -952,6 +1152,7 @@ class FlashInferIndicesUpdaterPrefill:
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         if use_ragged:
             # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
@@ -976,6 +1177,7 @@ class FlashInferIndicesUpdaterPrefill:
             use_ragged,
             spec_info,
             fixed_split_size=fixed_split_size,
+            multi_item_params=multi_item_params,
         )

     def update_sliding_window(
@@ -990,6 +1192,7 @@ class FlashInferIndicesUpdaterPrefill:
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -1023,6 +1226,7 @@ class FlashInferIndicesUpdaterPrefill:
                 use_ragged,
                 spec_info,
                 use_sliding_window_kv_pool=use_sliding_window_kv_pool,
+                multi_item_params=multi_item_params,
             )

     def update_cross_attention(
@@ -1037,6 +1241,7 @@ class FlashInferIndicesUpdaterPrefill:
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -1063,6 +1268,7 @@ class FlashInferIndicesUpdaterPrefill:
                 self.qo_indptr[wrapper_id],
                 use_ragged,
                 spec_info,
+                multi_item_params=multi_item_params,
             )

     def call_begin_forward(
@@ -1081,6 +1287,7 @@ class FlashInferIndicesUpdaterPrefill:
         spec_info: Optional[SpecInput],
         use_sliding_window_kv_pool: bool = False,
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         bs = len(seq_lens)
         if spec_info is None:
@@ -1136,6 +1343,22 @@ class FlashInferIndicesUpdaterPrefill:
             )

         # cached part
+        # Conditionally set multi-item parameters
+        if multi_item_params is not None and multi_item_params.is_enabled():
+            # Multi-item scoring is active - use specialized parameters and disable generic custom_mask
+            use_custom_mask = None
+            prefix_len_ptr = multi_item_params.prefix_len_ptr
+            token_pos_in_items_ptr = multi_item_params.token_pos_in_items_ptr
+            token_pos_in_items_len = multi_item_params.token_pos_in_items_len
+            max_item_len_ptr = multi_item_params.max_item_len_ptr
+        else:
+            # No multi-item scoring - use standard parameters
+            use_custom_mask = custom_mask
+            prefix_len_ptr = None
+            token_pos_in_items_ptr = None
+            token_pos_in_items_len = 0
+            max_item_len_ptr = None
+
         wrapper_paged.begin_forward(
             qo_indptr,
             kv_indptr,
@@ -1147,9 +1370,13 @@ class FlashInferIndicesUpdaterPrefill:
             1,
             q_data_type=self.q_data_type,
             kv_data_type=self.data_type,
-            custom_mask=
+            custom_mask=use_custom_mask,
             non_blocking=True,
             fixed_split_size=fixed_split_size,
+            prefix_len_ptr=prefix_len_ptr,
+            token_pos_in_items_ptr=token_pos_in_items_ptr,
+            token_pos_in_items_len=token_pos_in_items_len,
+            max_item_len_ptr=max_item_len_ptr,
         )


@@ -1185,7 +1412,7 @@ class FlashInferMultiStepDraftBackend:
             (max_bs,), dtype=torch.int32, device=model_runner.device
         )
         self.attn_backends: List[FlashInferAttnBackend] = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashInferAttnBackend(
                     model_runner,
@@ -1273,7 +1500,7 @@ class FlashInferMultiStepDraftBackend:
             device="cuda",
         )

-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
            self.attn_backends[i].init_cuda_graph_state(
                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )
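The multi-item scoring hunks above are easier to follow with a concrete trace. The sketch below is illustrative only (it is not code from the package; the token ids and the `DELIM` value are invented), and it reproduces with plain Python lists the `token_pos_in_items` bookkeeping that the new `_process_multi_item_scoring` docstring works through for its single-sequence example.

```python
# Hypothetical token ids; DELIM stands in for the token configured via the
# multi_item_scoring_delimiter server argument referenced in the diff above.
DELIM = -1
# "What is the capital of France ?" followed by three single-token
# candidate answers, each followed by a delimiter.
tokens = [11, 12, 13, 14, 15, 16, 17, DELIM, 201, DELIM, 202, DELIM, 203, DELIM]

first_delim = tokens.index(DELIM)
prefix_len = first_delim  # 7 query tokens before the first delimiter

# Position within the current item, restarting at 0 on every delimiter token.
token_pos_in_items = []
pos_in_item = 0
for tok in tokens[first_delim:]:
    if tok == DELIM:
        pos_in_item = 0
    token_pos_in_items.append(pos_in_item)
    pos_in_item += 1

max_item_len = max(token_pos_in_items)

print(prefix_len)          # 7                      -> prefix_len_ptr = [7]
print(token_pos_in_items)  # [0, 1, 0, 1, 0, 1, 0]  -> token_pos_in_items_ptr
print(max_item_len)        # 1                      -> max_item_len_ptr = [1]
```

The backend computes the same quantities batched and on-device with tensor ops (a cumulative max over the delimiter mask), pads the per-sequence arrays to a common `token_pos_in_items_len`, and passes them to `wrapper_paged.begin_forward`.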
--- sglang/srt/layers/attention/flashinfer_mla_backend.py (0.5.3rc2)
+++ sglang/srt/layers/attention/flashinfer_mla_backend.py (0.5.4)
@@ -9,27 +9,20 @@ and uses BatchMLAPaged wrapper for decoding.
 More details can be found in https://docs.flashinfer.ai/api/mla.html
 """

-import os
 from dataclasses import dataclass
 from functools import partial
 from typing import TYPE_CHECKING, Callable, Optional, Union

 import torch

-
-import logging
-
-torch._logging.set_logs(dynamo=logging.ERROR)
-torch._dynamo.config.suppress_errors = True
-
-from sglang.global_config import global_config
+from sglang.srt.environ import envs
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
     is_flashinfer_available,
@@ -38,10 +31,19 @@ from sglang.srt.utils import (
 )

 if TYPE_CHECKING:
+    from sglang.srt.layers.attention.flashinfer_mla_backend import (
+        FlashInferMlaAttnBackend,
+    )
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.speculative.spec_info import SpecInput

+if envs.SGLANG_ENABLE_TORCH_COMPILE.get():
+    import logging
+
+    torch._logging.set_logs(dynamo=logging.ERROR)
+    torch._dynamo.config.suppress_errors = True
+
 if is_flashinfer_available():
     from flashinfer import (
         BatchMLAPagedAttentionWrapper,
@@ -66,7 +68,7 @@ global_workspace_buffer = None

 class FlashInferMhaChunkKVRunner:
     def __init__(
-        self, model_runner: ModelRunner, attn_backend:
+        self, model_runner: ModelRunner, attn_backend: FlashInferMlaAttnBackend
     ):
         # Parse Constants
         self.num_local_heads = (
@@ -193,9 +195,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.skip_prefill = skip_prefill
         self.enable_chunk_kv = (
             not skip_prefill
-            and
-            and not
-            and not
+            and get_global_server_args().disaggregation_mode != "decode"
+            and not get_global_server_args().disable_chunked_prefix_cache
+            and not get_global_server_args().flashinfer_mla_disable_ragged
         )
         self.page_size = model_runner.page_size

@@ -204,7 +206,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         if global_workspace_buffer is None:
             # different from flashinfer zero_init_global_workspace_buffer
             global_workspace_buffer = torch.empty(
-
+                envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
@@ -306,7 +308,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             prefix_lens = forward_batch.extend_prefix_lens
             extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
             use_ragged = (
-                not
+                not get_global_server_args().flashinfer_mla_disable_ragged
                 and extend_no_prefix
             )

@@ -916,7 +918,7 @@ class FlashInferMLAMultiStepDraftBackend:
         )

         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashInferMLAAttnBackend(
                     model_runner,
@@ -998,7 +1000,7 @@ class FlashInferMLAMultiStepDraftBackend:
             device="cuda",
         )

-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
@@ -1060,7 +1062,7 @@ def fast_mla_decode_plan(

     try:
         # Standard version with just the required arguments (no use_profiler)
-        self._cached_module.plan
+        self._cached_module.plan(
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._pin_memory_int_workspace_buffer,
--- sglang/srt/layers/attention/flashmla_backend.py (0.5.3rc2)
+++ sglang/srt/layers/attention/flashmla_backend.py (0.5.4)
@@ -478,7 +478,7 @@ class FlashMLAMultiStepDraftBackend:
         )

         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashMLABackend(
                     model_runner,
@@ -506,7 +506,7 @@ class FlashMLAMultiStepDraftBackend:
         self.common_template(forward_batch, call_fn)

     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, block_kv_indices=None
             )