sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
 - sglang/bench_one_batch_server.py +41 -25
 - sglang/bench_serving.py +378 -160
 - sglang/check_env.py +1 -1
 - sglang/compile_deep_gemm.py +6 -2
 - sglang/global_config.py +1 -25
 - sglang/lang/api.py +6 -0
 - sglang/lang/interpreter.py +1 -0
 - sglang/lang/ir.py +13 -0
 - sglang/launch_server.py +10 -15
 - sglang/profiler.py +18 -1
 - sglang/srt/_custom_ops.py +1 -1
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
 - sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
 - sglang/srt/compilation/backend.py +437 -0
 - sglang/srt/compilation/compilation_config.py +20 -0
 - sglang/srt/compilation/compilation_counter.py +47 -0
 - sglang/srt/compilation/compile.py +210 -0
 - sglang/srt/compilation/compiler_interface.py +503 -0
 - sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
 - sglang/srt/compilation/fix_functionalization.py +134 -0
 - sglang/srt/compilation/fx_utils.py +83 -0
 - sglang/srt/compilation/inductor_pass.py +140 -0
 - sglang/srt/compilation/pass_manager.py +66 -0
 - sglang/srt/compilation/piecewise_context_manager.py +40 -0
 - sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
 - sglang/srt/configs/__init__.py +4 -0
 - sglang/srt/configs/deepseek_ocr.py +262 -0
 - sglang/srt/configs/deepseekvl2.py +194 -96
 - sglang/srt/configs/dots_vlm.py +2 -7
 - sglang/srt/configs/falcon_h1.py +13 -64
 - sglang/srt/configs/load_config.py +25 -2
 - sglang/srt/configs/mamba_utils.py +117 -0
 - sglang/srt/configs/model_config.py +136 -25
 - sglang/srt/configs/modelopt_config.py +30 -0
 - sglang/srt/configs/nemotron_h.py +286 -0
 - sglang/srt/configs/olmo3.py +105 -0
 - sglang/srt/configs/points_v15_chat.py +29 -0
 - sglang/srt/configs/qwen3_next.py +11 -47
 - sglang/srt/configs/qwen3_omni.py +613 -0
 - sglang/srt/configs/qwen3_vl.py +0 -10
 - sglang/srt/connector/remote_instance.py +1 -1
 - sglang/srt/constrained/base_grammar_backend.py +5 -1
 - sglang/srt/constrained/llguidance_backend.py +5 -0
 - sglang/srt/constrained/outlines_backend.py +1 -1
 - sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
 - sglang/srt/constrained/utils.py +12 -0
 - sglang/srt/constrained/xgrammar_backend.py +20 -11
 - sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
 - sglang/srt/disaggregation/base/conn.py +17 -4
 - sglang/srt/disaggregation/common/conn.py +4 -2
 - sglang/srt/disaggregation/decode.py +123 -31
 - sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
 - sglang/srt/disaggregation/fake/conn.py +11 -3
 - sglang/srt/disaggregation/mooncake/conn.py +157 -19
 - sglang/srt/disaggregation/nixl/conn.py +69 -24
 - sglang/srt/disaggregation/prefill.py +96 -270
 - sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
 - sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
 - sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
 - sglang/srt/distributed/device_communicators/pynccl.py +24 -12
 - sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
 - sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
 - sglang/srt/distributed/naive_distributed.py +5 -4
 - sglang/srt/distributed/parallel_state.py +63 -19
 - sglang/srt/elastic_ep/elastic_ep.py +74 -0
 - sglang/srt/entrypoints/context.py +3 -2
 - sglang/srt/entrypoints/engine.py +83 -80
 - sglang/srt/entrypoints/grpc_server.py +430 -234
 - sglang/srt/entrypoints/harmony_utils.py +2 -2
 - sglang/srt/entrypoints/http_server.py +195 -102
 - sglang/srt/entrypoints/http_server_engine.py +1 -7
 - sglang/srt/entrypoints/openai/protocol.py +225 -37
 - sglang/srt/entrypoints/openai/serving_base.py +49 -2
 - sglang/srt/entrypoints/openai/serving_chat.py +29 -74
 - sglang/srt/entrypoints/openai/serving_classify.py +204 -0
 - sglang/srt/entrypoints/openai/serving_completions.py +15 -1
 - sglang/srt/entrypoints/openai/serving_responses.py +5 -2
 - sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
 - sglang/srt/environ.py +58 -6
 - sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
 - sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
 - sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
 - sglang/srt/eplb/expert_distribution.py +33 -4
 - sglang/srt/eplb/expert_location_dispatch.py +2 -2
 - sglang/srt/eplb/expert_location_updater.py +2 -2
 - sglang/srt/function_call/base_format_detector.py +17 -18
 - sglang/srt/function_call/function_call_parser.py +20 -14
 - sglang/srt/function_call/glm4_moe_detector.py +1 -5
 - sglang/srt/function_call/gpt_oss_detector.py +1 -1
 - sglang/srt/function_call/json_array_parser.py +0 -2
 - sglang/srt/function_call/minimax_m2.py +367 -0
 - sglang/srt/function_call/utils.py +2 -2
 - sglang/srt/grpc/compile_proto.py +3 -3
 - sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
 - sglang/srt/grpc/health_servicer.py +189 -0
 - sglang/srt/grpc/scheduler_launcher.py +181 -0
 - sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
 - sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
 - sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
 - sglang/srt/layers/activation.py +10 -1
 - sglang/srt/layers/attention/aiter_backend.py +3 -3
 - sglang/srt/layers/attention/ascend_backend.py +17 -1
 - sglang/srt/layers/attention/attention_registry.py +43 -23
 - sglang/srt/layers/attention/base_attn_backend.py +20 -1
 - sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
 - sglang/srt/layers/attention/fla/chunk.py +0 -1
 - sglang/srt/layers/attention/fla/chunk_o.py +1 -1
 - sglang/srt/layers/attention/fla/index.py +0 -2
 - sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
 - sglang/srt/layers/attention/fla/utils.py +0 -3
 - sglang/srt/layers/attention/fla/wy_fast.py +0 -2
 - sglang/srt/layers/attention/flashattention_backend.py +24 -10
 - sglang/srt/layers/attention/flashinfer_backend.py +258 -22
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
 - sglang/srt/layers/attention/flashmla_backend.py +2 -2
 - sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
 - sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
 - sglang/srt/layers/attention/intel_amx_backend.py +1 -1
 - sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
 - sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
 - sglang/srt/layers/attention/mamba/mamba.py +189 -241
 - sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
 - sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
 - sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
 - sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
 - sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
 - sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
 - sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
 - sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
 - sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
 - sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
 - sglang/srt/layers/attention/nsa/utils.py +0 -1
 - sglang/srt/layers/attention/nsa_backend.py +404 -90
 - sglang/srt/layers/attention/triton_backend.py +208 -34
 - sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
 - sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
 - sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
 - sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
 - sglang/srt/layers/attention/utils.py +89 -7
 - sglang/srt/layers/attention/vision.py +3 -3
 - sglang/srt/layers/attention/xpu_backend.py +1028 -0
 - sglang/srt/layers/communicator.py +12 -7
 - sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
 - sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
 - sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
 - sglang/srt/layers/dp_attention.py +17 -0
 - sglang/srt/layers/layernorm.py +64 -19
 - sglang/srt/layers/linear.py +9 -1
 - sglang/srt/layers/logits_processor.py +152 -17
 - sglang/srt/layers/modelopt_utils.py +11 -0
 - sglang/srt/layers/moe/cutlass_moe.py +0 -2
 - sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
 - sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
 - sglang/srt/layers/moe/ep_moe/layer.py +154 -625
 - sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
 - sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
 - sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
 - sglang/srt/layers/moe/moe_runner/runner.py +6 -0
 - sglang/srt/layers/moe/moe_runner/triton.py +3 -1
 - sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
 - sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
 - sglang/srt/layers/moe/router.py +51 -15
 - sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
 - sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
 - sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
 - sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
 - sglang/srt/layers/moe/topk.py +7 -6
 - sglang/srt/layers/moe/utils.py +20 -5
 - sglang/srt/layers/quantization/__init__.py +5 -58
 - sglang/srt/layers/quantization/awq.py +183 -9
 - sglang/srt/layers/quantization/awq_triton.py +29 -0
 - sglang/srt/layers/quantization/base_config.py +27 -1
 - sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
 - sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
 - sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
 - sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
 - sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
 - sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
 - sglang/srt/layers/quantization/fp8.py +152 -81
 - sglang/srt/layers/quantization/fp8_kernel.py +55 -10
 - sglang/srt/layers/quantization/fp8_utils.py +42 -14
 - sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
 - sglang/srt/layers/quantization/gguf.py +566 -0
 - sglang/srt/layers/quantization/gptq.py +0 -1
 - sglang/srt/layers/quantization/int8_kernel.py +18 -2
 - sglang/srt/layers/quantization/marlin_utils.py +12 -0
 - sglang/srt/layers/quantization/modelopt_quant.py +125 -100
 - sglang/srt/layers/quantization/mxfp4.py +35 -68
 - sglang/srt/layers/quantization/petit.py +1 -1
 - sglang/srt/layers/quantization/quark/quark.py +3 -1
 - sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
 - sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
 - sglang/srt/layers/quantization/unquant.py +23 -48
 - sglang/srt/layers/quantization/utils.py +0 -1
 - sglang/srt/layers/quantization/w4afp8.py +87 -20
 - sglang/srt/layers/quantization/w8a8_int8.py +30 -24
 - sglang/srt/layers/radix_attention.py +62 -9
 - sglang/srt/layers/rotary_embedding.py +686 -17
 - sglang/srt/layers/sampler.py +47 -16
 - sglang/srt/layers/sparse_pooler.py +98 -0
 - sglang/srt/layers/utils.py +0 -1
 - sglang/srt/layers/vocab_parallel_embedding.py +4 -1
 - sglang/srt/lora/backend/triton_backend.py +0 -1
 - sglang/srt/lora/eviction_policy.py +139 -0
 - sglang/srt/lora/lora_manager.py +24 -9
 - sglang/srt/lora/lora_registry.py +1 -1
 - sglang/srt/lora/mem_pool.py +40 -16
 - sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
 - sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
 - sglang/srt/managers/cache_controller.py +48 -17
 - sglang/srt/managers/data_parallel_controller.py +146 -42
 - sglang/srt/managers/detokenizer_manager.py +40 -13
 - sglang/srt/managers/io_struct.py +69 -16
 - sglang/srt/managers/mm_utils.py +20 -18
 - sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
 - sglang/srt/managers/overlap_utils.py +96 -19
 - sglang/srt/managers/schedule_batch.py +241 -511
 - sglang/srt/managers/schedule_policy.py +15 -2
 - sglang/srt/managers/scheduler.py +420 -514
 - sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
 - sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
 - sglang/srt/managers/scheduler_pp_mixin.py +341 -0
 - sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
 - sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
 - sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
 - sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
 - sglang/srt/managers/tokenizer_manager.py +375 -95
 - sglang/srt/managers/tp_worker.py +212 -161
 - sglang/srt/managers/utils.py +78 -2
 - sglang/srt/mem_cache/allocator.py +7 -2
 - sglang/srt/mem_cache/allocator_ascend.py +2 -2
 - sglang/srt/mem_cache/base_prefix_cache.py +2 -2
 - sglang/srt/mem_cache/chunk_cache.py +13 -2
 - sglang/srt/mem_cache/common.py +480 -0
 - sglang/srt/mem_cache/evict_policy.py +16 -1
 - sglang/srt/mem_cache/hicache_storage.py +11 -2
 - sglang/srt/mem_cache/hiradix_cache.py +16 -3
 - sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
 - sglang/srt/mem_cache/memory_pool.py +517 -219
 - sglang/srt/mem_cache/memory_pool_host.py +0 -1
 - sglang/srt/mem_cache/multimodal_cache.py +0 -1
 - sglang/srt/mem_cache/radix_cache.py +53 -19
 - sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
 - sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
 - sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
 - sglang/srt/mem_cache/storage/backend_factory.py +2 -2
 - sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
 - sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
 - sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
 - sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
 - sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
 - sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
 - sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
 - sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
 - sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
 - sglang/srt/mem_cache/swa_radix_cache.py +92 -26
 - sglang/srt/metrics/collector.py +31 -0
 - sglang/srt/metrics/func_timer.py +1 -1
 - sglang/srt/model_executor/cuda_graph_runner.py +43 -5
 - sglang/srt/model_executor/forward_batch_info.py +71 -25
 - sglang/srt/model_executor/model_runner.py +362 -270
 - sglang/srt/model_executor/npu_graph_runner.py +2 -3
 - sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
 - sglang/srt/model_loader/__init__.py +1 -1
 - sglang/srt/model_loader/loader.py +424 -27
 - sglang/srt/model_loader/utils.py +0 -1
 - sglang/srt/model_loader/weight_utils.py +47 -28
 - sglang/srt/models/apertus.py +2 -3
 - sglang/srt/models/arcee.py +2 -2
 - sglang/srt/models/bailing_moe.py +13 -52
 - sglang/srt/models/bailing_moe_nextn.py +3 -4
 - sglang/srt/models/bert.py +1 -1
 - sglang/srt/models/deepseek_nextn.py +19 -3
 - sglang/srt/models/deepseek_ocr.py +1516 -0
 - sglang/srt/models/deepseek_v2.py +418 -140
 - sglang/srt/models/dots_ocr.py +0 -2
 - sglang/srt/models/dots_vlm.py +0 -1
 - sglang/srt/models/dots_vlm_vit.py +1 -1
 - sglang/srt/models/falcon_h1.py +13 -19
 - sglang/srt/models/gemma3_mm.py +16 -0
 - sglang/srt/models/gemma3n_mm.py +1 -2
 - sglang/srt/models/glm4_moe.py +327 -382
 - sglang/srt/models/glm4_moe_nextn.py +6 -16
 - sglang/srt/models/glm4v.py +2 -1
 - sglang/srt/models/glm4v_moe.py +32 -199
 - sglang/srt/models/gpt_oss.py +5 -5
 - sglang/srt/models/grok.py +10 -23
 - sglang/srt/models/hunyuan.py +2 -7
 - sglang/srt/models/interns1.py +0 -1
 - sglang/srt/models/kimi_vl.py +1 -7
 - sglang/srt/models/kimi_vl_moonvit.py +3 -1
 - sglang/srt/models/llama.py +2 -2
 - sglang/srt/models/llama_eagle3.py +1 -1
 - sglang/srt/models/longcat_flash.py +5 -22
 - sglang/srt/models/longcat_flash_nextn.py +3 -14
 - sglang/srt/models/mimo.py +2 -13
 - sglang/srt/models/mimo_mtp.py +1 -2
 - sglang/srt/models/minicpmo.py +7 -5
 - sglang/srt/models/minimax_m2.py +922 -0
 - sglang/srt/models/mixtral.py +1 -4
 - sglang/srt/models/mllama.py +1 -1
 - sglang/srt/models/mllama4.py +13 -3
 - sglang/srt/models/nemotron_h.py +511 -0
 - sglang/srt/models/nvila.py +355 -0
 - sglang/srt/models/nvila_lite.py +184 -0
 - sglang/srt/models/olmo2.py +31 -4
 - sglang/srt/models/opt.py +5 -5
 - sglang/srt/models/phi.py +1 -1
 - sglang/srt/models/phi4mm.py +1 -1
 - sglang/srt/models/phimoe.py +0 -1
 - sglang/srt/models/pixtral.py +0 -3
 - sglang/srt/models/points_v15_chat.py +186 -0
 - sglang/srt/models/qwen.py +0 -1
 - sglang/srt/models/qwen2.py +22 -1
 - sglang/srt/models/qwen2_5_vl.py +3 -3
 - sglang/srt/models/qwen2_audio.py +2 -15
 - sglang/srt/models/qwen2_moe.py +15 -12
 - sglang/srt/models/qwen2_vl.py +5 -2
 - sglang/srt/models/qwen3.py +34 -4
 - sglang/srt/models/qwen3_moe.py +19 -37
 - sglang/srt/models/qwen3_next.py +7 -12
 - sglang/srt/models/qwen3_next_mtp.py +3 -4
 - sglang/srt/models/qwen3_omni_moe.py +661 -0
 - sglang/srt/models/qwen3_vl.py +37 -33
 - sglang/srt/models/qwen3_vl_moe.py +57 -185
 - sglang/srt/models/roberta.py +55 -3
 - sglang/srt/models/sarashina2_vision.py +0 -1
 - sglang/srt/models/step3_vl.py +3 -5
 - sglang/srt/models/utils.py +11 -1
 - sglang/srt/multimodal/processors/base_processor.py +7 -2
 - sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
 - sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
 - sglang/srt/multimodal/processors/dots_vlm.py +0 -1
 - sglang/srt/multimodal/processors/glm4v.py +2 -6
 - sglang/srt/multimodal/processors/internvl.py +0 -2
 - sglang/srt/multimodal/processors/janus_pro.py +0 -1
 - sglang/srt/multimodal/processors/mllama4.py +0 -8
 - sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
 - sglang/srt/multimodal/processors/phi4mm.py +0 -1
 - sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
 - sglang/srt/multimodal/processors/qwen_vl.py +75 -16
 - sglang/srt/multimodal/processors/step3_vl.py +1 -1
 - sglang/srt/parser/conversation.py +41 -0
 - sglang/srt/parser/reasoning_parser.py +28 -2
 - sglang/srt/sampling/custom_logit_processor.py +77 -2
 - sglang/srt/sampling/sampling_batch_info.py +17 -22
 - sglang/srt/sampling/sampling_params.py +70 -2
 - sglang/srt/server_args.py +846 -163
 - sglang/srt/server_args_config_parser.py +1 -1
 - sglang/srt/single_batch_overlap.py +36 -31
 - sglang/srt/speculative/base_spec_worker.py +34 -0
 - sglang/srt/speculative/draft_utils.py +226 -0
 - sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
 - sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
 - sglang/srt/speculative/eagle_info.py +57 -18
 - sglang/srt/speculative/eagle_info_v2.py +458 -0
 - sglang/srt/speculative/eagle_utils.py +138 -0
 - sglang/srt/speculative/eagle_worker.py +83 -280
 - sglang/srt/speculative/eagle_worker_v2.py +702 -0
 - sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
 - sglang/srt/speculative/ngram_worker.py +12 -11
 - sglang/srt/speculative/spec_info.py +2 -0
 - sglang/srt/speculative/spec_utils.py +38 -3
 - sglang/srt/speculative/standalone_worker.py +4 -14
 - sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
 - sglang/srt/two_batch_overlap.py +28 -14
 - sglang/srt/utils/__init__.py +1 -1
 - sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
 - sglang/srt/utils/common.py +272 -82
 - sglang/srt/utils/hf_transformers_utils.py +44 -17
 - sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
 - sglang/srt/{offloader.py → utils/offloader.py} +4 -4
 - sglang/srt/utils/profile_merger.py +199 -0
 - sglang/test/attention/test_flashattn_backend.py +1 -1
 - sglang/test/attention/test_flashattn_mla_backend.py +0 -1
 - sglang/test/attention/test_prefix_chunk_info.py +0 -2
 - sglang/test/attention/test_trtllm_mla_backend.py +221 -53
 - sglang/test/few_shot_gsm8k_engine.py +2 -4
 - sglang/test/kit_matched_stop.py +157 -0
 - sglang/test/longbench_v2/__init__.py +1 -0
 - sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
 - sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
 - sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
 - sglang/test/run_eval.py +41 -0
 - sglang/test/runners.py +2 -0
 - sglang/test/send_one.py +42 -7
 - sglang/test/simple_eval_common.py +3 -0
 - sglang/test/simple_eval_gpqa.py +0 -1
 - sglang/test/simple_eval_humaneval.py +0 -3
 - sglang/test/simple_eval_longbench_v2.py +344 -0
 - sglang/test/test_block_fp8.py +1 -2
 - sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
 - sglang/test/test_cutlass_moe.py +1 -2
 - sglang/test/test_cutlass_w4a8_moe.py +10 -20
 - sglang/test/test_deterministic.py +463 -107
 - sglang/test/test_deterministic_utils.py +74 -0
 - sglang/test/test_disaggregation_utils.py +81 -0
 - sglang/test/test_marlin_moe.py +0 -1
 - sglang/test/test_utils.py +85 -20
 - sglang/version.py +1 -1
 - {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
 - {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
 - sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
 - sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
 - sglang/srt/models/vila.py +0 -306
 - sglang/srt/speculative/build_eagle_tree.py +0 -427
 - sglang/test/test_block_fp8_ep.py +0 -358
 - /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
 - /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
 - /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
 - {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
 - {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
 
sglang/srt/layers/attention/flashinfer_backend.py:

```diff
@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """
 
+import logging
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
@@ -15,20 +16,13 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
 
 import torch
 
-
-    import logging
-
-    torch._logging.set_logs(dynamo=logging.ERROR)
-    torch._dynamo.config.suppress_errors = True
-
-from sglang.global_config import global_config
+from sglang.srt.environ import envs
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
 from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
     get_int_env_var,
@@ -41,6 +35,12 @@ if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+logger = logging.getLogger(__name__)
+
+if envs.SGLANG_ENABLE_TORCH_COMPILE.get():
+    torch._logging.set_logs(dynamo=logging.ERROR)
+    torch._dynamo.config.suppress_errors = True
+
 
 if is_flashinfer_available():
     from flashinfer import (
@@ -50,7 +50,6 @@ if is_flashinfer_available():
         fast_decode_plan,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.decode import _get_range_buf, get_seq_lens
 
 
 class WrapperDispatch(Enum):
@@ -58,6 +57,36 @@ class WrapperDispatch(Enum):
     CROSS_ATTENTION = auto()
 
 
+@dataclass
+class MultiItemScoringParams:
+    """Parameters for multi-item scoring in attention computation.
+
+    Used when processing sequences with multiple items separated by delimiters,
+    where each item needs specific attention patterns that respect item boundaries.
+
+    Attributes:
+        prefix_len_ptr: A uint32 1D tensor indicating the prefix length of each prompt.
+                        The tensor size is equal to the batch size.
+        token_pos_in_items_ptr: A uint16 1D tensor indicating the token position of each item
+                                starting from 0 (delimiter) for each item. For batch size > 1,
+                                sequences are concatenated with zero padding to ensure same length.
+        token_pos_in_items_len: Zero padding length for token_pos_in_items_ptr to handle
+                                batch_size > 1 case. Defines the padded length for each sequence.
+        max_item_len_ptr: A uint16 tensor containing the max token length of all items
+                          for each prompt in the batch.
+    """
+
+    prefix_len_ptr: Optional[torch.Tensor] = None
+    token_pos_in_items_ptr: Optional[torch.Tensor] = None
+    token_pos_in_items_len: int = 0
+    max_item_len_ptr: Optional[torch.Tensor] = None
+
+    def is_enabled(self) -> bool:
+        """Check if multi-item scoring is enabled."""
+        return self.prefix_len_ptr is not None
+
+
 @dataclass
 class DecodeMetadata:
     decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
```
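The `token_pos_in_items_ptr` layout defined above (0 at every delimiter, then 1…n inside each item, and one trailing 0 for the final delimiter) is easy to misread, so here is a small self-contained sketch that reproduces it. The helper name `token_pos_in_items` is hypothetical, not part of sglang or FlashInfer:

```python
import torch

def token_pos_in_items(item_lens: list[int]) -> torch.Tensor:
    """Build the per-token position tensor for delimiter-separated items.

    Layout per the docstring above: 0 marks each delimiter, then 1..n for
    the n tokens of the item that follows it, with a trailing 0 at the end.
    """
    out: list[int] = []
    for n in item_lens:
        out.extend(range(n + 1))  # delimiter slot (0), then 1..n
    out.append(0)                 # trailing delimiter
    return torch.tensor(out, dtype=torch.uint16)

# Three items of length 3, 2, 4:
print(token_pos_in_items([3, 2, 4]).tolist())
# [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]
```

The printed result matches the FlashInfer example quoted later in the docstring of `_process_multi_item_scoring`.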
```diff
@@ -68,6 +97,7 @@ class PrefillMetadata:
     prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
     use_ragged: bool
     extend_no_prefix: bool
+    multi_item_params: Optional[MultiItemScoringParams] = None
 
 
 # Reuse this workspace buffer across all flashinfer wrappers
@@ -87,9 +117,15 @@ class FlashInferAttnBackend(AttentionBackend):
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
         kv_last_page_len_buf: Optional[torch.Tensor] = None,
+        init_new_workspace: bool = False,
     ):
         super().__init__()
 
+        # Store multi-item scoring delimiter for efficient access
+        self.multi_item_scoring_delimiter = (
+            model_runner.server_args.multi_item_scoring_delimiter
+        )
+
         # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -124,7 +160,7 @@ class FlashInferAttnBackend(AttentionBackend):
             or "Qwen3ForCausalLM" in model_runner.model_config.hf_config.architectures
             or "MiMoForCausalLM" in model_runner.model_config.hf_config.architectures
         ):
-            global_config.flashinfer_workspace_size = 512 * 1024 * 1024
+            envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.set(512 * 1024 * 1024)
 
         # When deterministic inference is enabled, tensor cores should be used for decode
         # Also set split tile sizes for prefill and decode from environment variables, and disable kv split for cuda graph
@@ -144,19 +180,26 @@ class FlashInferAttnBackend(AttentionBackend):
                 "SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE", 2048
             )
             self.disable_cuda_graph_kv_split = True
-            global_config.flashinfer_workspace_size = 2048 * 1024 * 1024
+            envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.set(2048 * 1024 * 1024)
 
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
             # different from flashinfer zero_init_global_workspace_buffer
-            global_workspace_size = global_config.flashinfer_workspace_size
+            global_workspace_size = envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get()
             global_workspace_buffer = torch.empty(
                 global_workspace_size,
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
-        self.workspace_buffer = global_workspace_buffer
+        if init_new_workspace:
+            self.workspace_buffer = torch.empty(
+                envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        else:
+            self.workspace_buffer = global_workspace_buffer
         max_bs = model_runner.req_to_token_pool.size
         if kv_indptr_buf is None:
             self.kv_indptr = [
@@ -187,7 +230,16 @@ class FlashInferAttnBackend(AttentionBackend):
 
         fmha_backend = "auto"
         if is_sm100_supported():
-            fmha_backend = "cutlass"
+            # Disable CUTLASS backend when piecewise cuda graph is enabled
+            # due to TMA descriptor initialization issues on B200
+            if model_runner.server_args.enable_piecewise_cuda_graph:
+                logger.warning(
+                    "CUTLASS backend is disabled when piecewise cuda graph is enabled "
+                    "due to TMA descriptor initialization issues on B200. "
+                    "Using auto backend instead for stability."
+                )
+            else:
+                fmha_backend = "cutlass"
         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
            self.workspace_buffer, "NHD", backend=fmha_backend
         )
```
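In the constructor changes above, all FlashInfer wrappers normally share one process-wide workspace buffer, and the new `init_new_workspace` flag lets an instance allocate its own buffer instead (with the size now flowing through `envs.SGLANG_FLASHINFER_WORKSPACE_SIZE` rather than `global_config`). A minimal sketch of that shared-versus-private pattern in plain PyTorch; `get_workspace` and `_GLOBAL_WORKSPACE` are illustrative names, not sglang's API:

```python
from typing import Optional

import torch

_GLOBAL_WORKSPACE: Optional[torch.Tensor] = None

def get_workspace(nbytes: int, device: str, private: bool = False) -> torch.Tensor:
    # private=True mirrors init_new_workspace=True: the caller gets a fresh
    # allocation; otherwise one shared buffer is lazily created and reused.
    global _GLOBAL_WORKSPACE
    if private:
        return torch.empty(nbytes, dtype=torch.uint8, device=device)
    if _GLOBAL_WORKSPACE is None:
        _GLOBAL_WORKSPACE = torch.empty(nbytes, dtype=torch.uint8, device=device)
    return _GLOBAL_WORKSPACE

a = get_workspace(4 << 20, "cpu")
b = get_workspace(4 << 20, "cpu")
assert a.data_ptr() == b.data_ptr()            # shared buffer reused
c = get_workspace(4 << 20, "cpu", private=True)
assert c.data_ptr() != a.data_ptr()            # private allocation
```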
         @@ -229,10 +281,133 @@ class FlashInferAttnBackend(AttentionBackend): 
     | 
|
| 
       229 
281 
     | 
    
         | 
| 
       230 
282 
     | 
    
         
             
                    # Other metadata
         
     | 
| 
       231 
283 
     | 
    
         
             
                    self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
         
     | 
| 
      
 284 
     | 
    
         
            +
             
     | 
| 
       232 
285 
     | 
    
         
             
                    self.decode_cuda_graph_metadata = {}
         
     | 
| 
       233 
286 
     | 
    
         
             
                    self.prefill_cuda_graph_metadata = {}  # For verify
         
     | 
| 
       234 
287 
     | 
    
         
             
                    self.draft_extend_cuda_graph_metadata = {}  # For draft extend
         
     | 
| 
       235 
288 
     | 
    
         | 
| 
      
 289 
     | 
    
         
            +
                def _process_multi_item_scoring(
         
     | 
| 
      
 290 
     | 
    
         
            +
                    self, forward_batch: ForwardBatch
         
     | 
| 
      
 291 
     | 
    
         
            +
                ) -> MultiItemScoringParams:
         
     | 
| 
      
 292 
     | 
    
         
            +
                    """Process multi-item scoring tensors for FlashInfer attention.
         
     | 
| 
      
 293 
     | 
    
         
            +
             
     | 
| 
      
 294 
     | 
    
         
            +
                    This method handles sequences containing multiple "items" separated by delimiter tokens,
         
     | 
| 
      
 295 
     | 
    
         
            +
                    where each item needs specific attention patterns that respect item boundaries.
         
     | 
| 
      
 296 
     | 
    
         
            +
             
     | 
| 
      
 297 
     | 
    
         
            +
                    The method produces four key tensors for FlashInfer:
         
     | 
| 
      
 298 
     | 
    
         
            +
                    - prefix_len_ptr: uint32 tensor with prefix length for each prompt in batch
         
     | 
| 
      
 299 
     | 
    
         
            +
                    - token_pos_in_items_ptr: uint16 tensor with token positions starting from 0 at delimiters
         
     | 
| 
      
 300 
     | 
    
         
            +
                    - token_pos_in_items_len: padding length for batch processing
         
     | 
| 
      
 301 
     | 
    
         
            +
                    - max_item_len_ptr: uint16 tensor with max item length for each prompt
         
     | 
| 
      
 302 
     | 
    
         
            +
             
     | 
| 
      
 303 
     | 
    
         
            +
                    Args:
         
     | 
| 
      
 304 
     | 
    
         
            +
                        forward_batch: The forward batch containing input sequences and delimiter info
         
     | 
| 
      
 305 
     | 
    
         
            +
             
     | 
| 
      
 306 
     | 
    
         
            +
                    Returns:
         
     | 
| 
      
 307 
     | 
    
         
            +
                        MultiItemScoringParams: The processed multi-item scoring parameters
         
     | 
| 
      
 308 
     | 
    
         
            +
             
     | 
| 
      
 309 
     | 
    
         
            +
                    Examples:
         
     | 
| 
      
 310 
     | 
    
         
            +
                        Following FlashInfer definition: for 3 items of length 3, 2, 4 respectively:
         
     | 
| 
      
 311 
     | 
    
         
            +
                        token_pos_in_items_ptr = [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]
         
     | 
| 
      
 312 
     | 
    
         
            +
             
     | 
| 
      
 313 
     | 
    
         
            +
                        Case 1: Single sequence
         
     | 
| 
      
 314 
     | 
    
         
            +
                        Text: "What is the capital of France? <delim> London <delim> Paris <delim> Berlin <delim>"
         
     | 
| 
      
 315 
     | 
    
         
            +
                        Tokens: [What, is, the, capital, of, France, ?, <delim>, London, <delim>, Paris, <delim>, Berlin, <delim>]
         
     | 
| 
      
 316 
     | 
    
         
            +
                        Indices: [ 0,   1,  2,   3,      4,  5,     6,   7,     8,      9,     10,    11,    12,     13]
         
     | 
| 
      
 317 
     | 
    
         
            +
                        - prefix_len_ptr: [7] (query length before first delimiter)
         
     | 
| 
      
 318 
     | 
    
         
            +
                        - token_pos_in_items_ptr: [0, 1, 0, 1, 0, 1, 0] (delim=0, London=1, delim=0, Paris=1, delim=0, Berlin=1, delim=0)
         
     | 
| 
      
 319 
     | 
    
         
            +
                        - token_pos_in_items_len: 7 (actual length)
         
     | 
| 
      
 320 
     | 
    
         
            +
                        - max_item_len_ptr: [1] (max item length is 1 token - all options are single tokens)
         
     | 
| 
      
 321 
     | 
    
         
            +
             
     | 
| 
      
 322 
     | 
    
         
            +
                        Case 2: Batch processing (batch_size=2)
         
     | 
| 
      
 323 
     | 
    
         
            +
                        Sequence 1: 2 items of length 2, 1 → [0, 1, 2, 0, 1, 0] (6 elements)
         
     | 
| 
      
 324 
     | 
    
         
            +
                        Sequence 2: 3 items of length 1, 3, 2 → [0, 1, 0, 1, 2, 3, 0, 1, 2, 0] (10 elements)
         
     | 
| 
      
 325 
     | 
    
         
            +
                        After padding both to length 10:
         
     | 
| 
      
 326 
     | 
    
         
            +
                        - token_pos_in_items_ptr: [0, 1, 2, 0, 1, 0, 0, 0, 0, 0,    0, 1, 0, 1, 2, 3, 0, 1, 2, 0]
         
     | 
| 
      
 327 
     | 
    
         
            +
                        - token_pos_in_items_len: 10 (padded length for batch processing)
         
     | 
| 
      
 328 
     | 
    
         
            +
                        - max_item_len_ptr: [2, 3] (max lengths per sequence)
         
     | 
| 
      
 329 
     | 
    
         
            +
                    """
         
     | 
| 
      
 330 
     | 
    
         
            +
             
     | 
| 
      
 331 
     | 
    
         
            +
                    delimiter = self.multi_item_scoring_delimiter
         
     | 
| 
      
 332 
     | 
    
         
            +
                    if delimiter is None or forward_batch.forward_mode == ForwardMode.DECODE:
         
     | 
| 
      
 333 
     | 
    
         
            +
                        return MultiItemScoringParams()
         
     | 
| 
      
 334 
     | 
    
         
            +
             
     | 
| 
      
 335 
     | 
    
         
            +
                    delimiter_mask = forward_batch.input_ids == delimiter
         
     | 
| 
      
 336 
     | 
    
         
            +
                    prefix_cache_lens = getattr(forward_batch, "extend_prefix_lens", None)
         
     | 
| 
      
 337 
     | 
    
         
            +
                    extend_seq_lens = getattr(forward_batch, "extend_seq_lens", None)
         
     | 
| 
      
 338 
     | 
    
         
            +
                    prefix_len_ptr, token_pos_in_items_ptr = [], []
         
     | 
| 
      
 339 
     | 
    
         
            +
                    token_pos_in_items_len = 0
         
     | 
| 
      
 340 
     | 
    
         
            +
             
     | 
| 
      
 341 
     | 
    
         
            +
                    # If no extend_seq_lens, treat whole batch as one sequence
         
     | 
| 
      
 342 
     | 
    
         
            +
                    if extend_seq_lens is None or len(extend_seq_lens) <= 1:
         
     | 
| 
      
 343 
     | 
    
         
            +
                        extend_seq_lens = [forward_batch.input_ids.size(0)]
         
     | 
| 
      
 344 
     | 
    
         
            +
             
     | 
| 
      
 345 
     | 
    
         
            +
                    seq_start = 0
         
     | 
| 
      
 346 
     | 
    
         
            +
                    for i, seq_len in enumerate(extend_seq_lens):
         
     | 
| 
      
 347 
     | 
    
         
            +
                        seq_end = seq_start + seq_len
         
     | 
| 
      
 348 
     | 
    
         
            +
                        mask = delimiter_mask[seq_start:seq_end]
         
     | 
| 
      
 349 
     | 
    
         
            +
                        pos = forward_batch.positions[seq_start:seq_end]
         
     | 
| 
      
 350 
     | 
    
         
            +
                        delimiter_indices = torch.nonzero(mask, as_tuple=True)[0]
         
     | 
| 
      
 351 
     | 
    
         
            +
             
     | 
| 
      
 352 
     | 
    
         
            +
                        if len(delimiter_indices) > 0:
         
     | 
| 
      
 353 
     | 
    
         
            +
                            first_delim = delimiter_indices[0]
         
     | 
| 
      
 354 
     | 
    
         
            +
                            # Prefix length: store as scalar
         
     | 
| 
      
 355 
     | 
    
         
            +
                            prefix_len = first_delim + (
         
     | 
| 
      
 356 
     | 
    
         
            +
                                prefix_cache_lens[i] if prefix_cache_lens is not None else 0
         
     | 
| 
      
 357 
     | 
    
         
            +
                            )
         
     | 
| 
      
 358 
     | 
    
         
            +
                            prefix_len_ptr.append(
         
     | 
| 
      
 359 
     | 
    
         
            +
                                prefix_len.item() if torch.is_tensor(prefix_len) else prefix_len
         
     | 
| 
      
 360 
     | 
    
         
            +
                            )
         
     | 
| 
      
 361 
     | 
    
         
            +
             
     | 
| 
      
 362 
     | 
    
         
            +
                            # Compute relative positions within items after delimiters
         
     | 
| 
      
 363 
     | 
    
         
            +
                            diff = pos[first_delim:] - torch.cummax(mask[first_delim:], 0)[1]
         
     | 
| 
      
 364 
     | 
    
         
            +
                            token_pos = (diff - pos[first_delim]).to(torch.uint16)
         
     | 
| 
      
 365 
     | 
    
         
            +
+                token_pos_in_items_ptr.append(token_pos)
+
+                # Update forward_batch positions in-place
+                pos[first_delim:] = diff - 1
+                forward_batch.positions[seq_start:seq_end] = pos
+
+            seq_start = seq_end
+
+        # Pad token_pos_in_items_ptr for batch processing
+        if token_pos_in_items_ptr:
+            token_pos_in_items_len = max(t.numel() for t in token_pos_in_items_ptr)
+            device = forward_batch.input_ids.device
+            token_pos_in_items_ptr = [
+                torch.cat(
+                    [
+                        t,
+                        torch.zeros(
+                            token_pos_in_items_len - t.numel(),
+                            dtype=torch.uint16,
+                            device=device,
+                        ),
+                    ]
+                )
+                for t in token_pos_in_items_ptr
+            ]
+
+        if not prefix_len_ptr or not token_pos_in_items_ptr:
+            return MultiItemScoringParams()
+
+        # Build final params
+        device = forward_batch.input_ids.device
+        return MultiItemScoringParams(
+            prefix_len_ptr=torch.tensor(
+                prefix_len_ptr, dtype=torch.uint32, device=device
+            ),
+            token_pos_in_items_ptr=torch.cat(token_pos_in_items_ptr, dim=0),
+            token_pos_in_items_len=token_pos_in_items_len & 0xFFFFFFFF,
+            max_item_len_ptr=torch.stack(
+                [
+                    t.to(torch.int32).max().to(torch.uint16)
+                    for t in token_pos_in_items_ptr
+                ],
+                dim=0,
+            ),
+        )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
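The hunk above zero-pads each request's per-token item positions to a common length and flattens them into the batch-level tensors FlashInfer consumes. A standalone sketch of the same padding/packing scheme (illustrative values only, not the backend code; assumes a PyTorch build with `torch.uint16` support, i.e. 2.3+):

```python
import torch

# Two requests whose per-token item-position tensors have different lengths.
token_pos_in_items = [
    torch.tensor([1, 2, 0, 1, 2, 3], dtype=torch.uint16),  # request 0
    torch.tensor([1, 2, 3, 0], dtype=torch.uint16),        # request 1
]

# Pad every tensor to the batch maximum with zeros, then flatten.
token_pos_in_items_len = max(t.numel() for t in token_pos_in_items)  # 6
padded = [
    torch.cat(
        [t, torch.zeros(token_pos_in_items_len - t.numel(), dtype=torch.uint16)]
    )
    for t in token_pos_in_items
]
flat = torch.cat(padded, dim=0)  # one fixed-width row per request

# Longest item per request, computed via int32 to sidestep limited uint16 ops.
max_item_len = torch.stack([t.to(torch.int32).max().to(torch.uint16) for t in padded])

print(flat.view(2, -1))  # [[1, 2, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]
print(max_item_len)      # tensor([3, 3], dtype=torch.uint16)
```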
@@ -280,13 +455,26 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             prefix_lens = forward_batch.extend_prefix_lens
 
-            if self.is_multimodal:
+            # Disable ragged wrapper and ensure prefix handling for multimodal and multi-item scoring
+            if self.is_multimodal or self.multi_item_scoring_delimiter is not None:
+                # use_ragged = False: Multi-item scoring requires the paged wrapper because:
+                # 1. Ragged wrapper doesn't support the specialized multi-item parameters
+                #    (prefix_len_ptr, token_pos_in_items_ptr, etc.)
+                # 2. Paged wrapper provides better control over attention masking needed
+                #    for respecting item boundaries in multi-item sequences
+                # 3. Custom masking logic conflicts with ragged wrapper's assumptions
                 use_ragged = False
                 extend_no_prefix = False
             else:
                 use_ragged = not self.enable_deterministic
                 extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
 
+            # Process multi-item scoring in attention backend instead of ForwardBatch
+            multi_item_params = MultiItemScoringParams()
+            if self.multi_item_scoring_delimiter is not None:
+                # Use new backend-specific implementation
+                multi_item_params = self._process_multi_item_scoring(forward_batch)
+
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
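`MultiItemScoringParams()` above is a disabled default that downstream code tests via `is_enabled()`. A plausible minimal shape, inferred from how the fields are consumed later in this diff (a sketch, not the actual class definition):

```python
from dataclasses import dataclass
from typing import Optional

import torch


# Hypothetical reconstruction: field names match the diff; is_enabled() is
# assumed to key off prefix_len_ptr, which every enabled path populates.
@dataclass
class MultiItemScoringParams:
    prefix_len_ptr: Optional[torch.Tensor] = None
    token_pos_in_items_ptr: Optional[torch.Tensor] = None
    token_pos_in_items_len: int = 0
    max_item_len_ptr: Optional[torch.Tensor] = None

    def is_enabled(self) -> bool:
        return self.prefix_len_ptr is not None
```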
@@ -298,9 +486,13 @@ class FlashInferAttnBackend(AttentionBackend):
                 encoder_lens=forward_batch.encoder_lens,
                 spec_info=None,
                 fixed_split_size=self.prefill_split_tile_size,
+                multi_item_params=multi_item_params,
             )
             self.forward_metadata = PrefillMetadata(
-                self.prefill_wrappers_paged, use_ragged, extend_no_prefix
+                self.prefill_wrappers_paged,
+                use_ragged,
+                extend_no_prefix,
+                multi_item_params,
             )
 
     def init_cuda_graph_state(
@@ -531,7 +723,20 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                 causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
-                window_left=layer.sliding_window_size,
+                # Disable sliding window attention for multi-item scoring:
+                # - Sliding window could cut across item boundaries, breaking semantic coherence
+                # - Multi-item sequences need full attention to properly handle delimiter tokens
+                # - Specialized multi-item parameters (prefix_len_ptr, token_pos_in_items_ptr)
+                #   provide more precise attention control than simple sliding windows
+                # - Item-aware masking takes precedence over window-based masking
+                window_left=(
+                    layer.sliding_window_size
+                    if not (
+                        self.forward_metadata.multi_item_params
+                        and self.forward_metadata.multi_item_params.is_enabled()
+                    )
+                    else -1
+                ),
                 logits_soft_cap=logits_soft_cap,
                 # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
                 k_scale=layer.k_scale_float,
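FlashInfer's `window_left` convention is that a non-negative value caps the left context of sliding-window attention, while `-1` means unrestricted (full) attention. A minimal sketch of the selection above, with `multi_item_params` standing in for `self.forward_metadata.multi_item_params`:

```python
def pick_window_left(sliding_window_size: int, multi_item_params) -> int:
    # Item-aware masking supersedes window-based masking when active.
    multi_item_active = (
        multi_item_params is not None and multi_item_params.is_enabled()
    )
    return -1 if multi_item_active else sliding_window_size
```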
@@ -539,9 +744,13 @@ class FlashInferAttnBackend(AttentionBackend):
             )
         else:
             causal = True
-            if layer.attn_type == AttentionType.ENCODER_ONLY:
-                save_kv_cache = False
+            if (
+                layer.is_cross_attention
+                or layer.attn_type == AttentionType.ENCODER_ONLY
+            ):
                 causal = False
+            if save_kv_cache and layer.attn_type == AttentionType.ENCODER_ONLY:
+                save_kv_cache = False
 
         if self.forward_metadata.extend_no_prefix:
             # NOTE: FlashInfer currently has limitations with head_dim = 32 or other dimensions
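The reshaped branch separates two concerns: bidirectionality (cross-attention and encoder-only layers are non-causal) and cache writes (encoder-only KV entries are never revisited, so writing them is skipped, but only when the caller asked to save at all). A condensed sketch of that decision table, standalone with plain strings in place of `AttentionType`:

```python
def attn_flags(attn_type: str, is_cross_attention: bool, save_kv_cache: bool):
    causal = True
    if is_cross_attention or attn_type == "ENCODER_ONLY":
        causal = False  # bidirectional attention
    if save_kv_cache and attn_type == "ENCODER_ONLY":
        save_kv_cache = False  # encoder KV is never reused by later steps
    return causal, save_kv_cache


assert attn_flags("ENCODER_ONLY", False, True) == (False, False)
assert attn_flags("DECODER", True, True) == (False, True)
```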
@@ -952,6 +1161,7 @@ class FlashInferIndicesUpdaterPrefill:
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         if use_ragged:
             # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
@@ -976,6 +1186,7 @@ class FlashInferIndicesUpdaterPrefill:
             use_ragged,
             spec_info,
             fixed_split_size=fixed_split_size,
+            multi_item_params=multi_item_params,
         )
 
     def update_sliding_window(
@@ -990,6 +1201,7 @@ class FlashInferIndicesUpdaterPrefill:
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -1023,6 +1235,7 @@ class FlashInferIndicesUpdaterPrefill:
                 use_ragged,
                 spec_info,
                 use_sliding_window_kv_pool=use_sliding_window_kv_pool,
+                multi_item_params=multi_item_params,
             )
 
     def update_cross_attention(
@@ -1037,6 +1250,7 @@ class FlashInferIndicesUpdaterPrefill:
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInput],
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -1063,6 +1277,7 @@ class FlashInferIndicesUpdaterPrefill:
                 self.qo_indptr[wrapper_id],
                 use_ragged,
                 spec_info,
+                multi_item_params=multi_item_params,
             )
 
     def call_begin_forward(
@@ -1081,6 +1296,7 @@ class FlashInferIndicesUpdaterPrefill:
         spec_info: Optional[SpecInput],
         use_sliding_window_kv_pool: bool = False,
         fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         bs = len(seq_lens)
         if spec_info is None:
@@ -1136,6 +1352,22 @@ class FlashInferIndicesUpdaterPrefill:
         )
 
         # cached part
+        # Conditionally set multi-item parameters
+        if multi_item_params is not None and multi_item_params.is_enabled():
+            # Multi-item scoring is active - use specialized parameters and disable generic custom_mask
+            use_custom_mask = None
+            prefix_len_ptr = multi_item_params.prefix_len_ptr
+            token_pos_in_items_ptr = multi_item_params.token_pos_in_items_ptr
+            token_pos_in_items_len = multi_item_params.token_pos_in_items_len
+            max_item_len_ptr = multi_item_params.max_item_len_ptr
+        else:
+            # No multi-item scoring - use standard parameters
+            use_custom_mask = custom_mask
+            prefix_len_ptr = None
+            token_pos_in_items_ptr = None
+            token_pos_in_items_len = 0
+            max_item_len_ptr = None
+
         wrapper_paged.begin_forward(
             qo_indptr,
             kv_indptr,
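Multi-item scoring and the generic `custom_mask` are mutually exclusive inputs to `begin_forward()`; the branch above nulls out whichever one is inactive. Equivalently, as a small sketch reusing the same field names:

```python
def select_mask_inputs(multi_item_params, custom_mask):
    """Return the kwargs that call_begin_forward passes through (a sketch)."""
    if multi_item_params is not None and multi_item_params.is_enabled():
        # Specialized item-boundary masking replaces the generic mask.
        return dict(
            custom_mask=None,
            prefix_len_ptr=multi_item_params.prefix_len_ptr,
            token_pos_in_items_ptr=multi_item_params.token_pos_in_items_ptr,
            token_pos_in_items_len=multi_item_params.token_pos_in_items_len,
            max_item_len_ptr=multi_item_params.max_item_len_ptr,
        )
    return dict(
        custom_mask=custom_mask,
        prefix_len_ptr=None,
        token_pos_in_items_ptr=None,
        token_pos_in_items_len=0,
        max_item_len_ptr=None,
    )
```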
@@ -1147,9 +1379,13 @@ class FlashInferIndicesUpdaterPrefill:
             1,
             q_data_type=self.q_data_type,
             kv_data_type=self.data_type,
-            custom_mask=custom_mask,
+            custom_mask=use_custom_mask,
             non_blocking=True,
             fixed_split_size=fixed_split_size,
+            prefix_len_ptr=prefix_len_ptr,
+            token_pos_in_items_ptr=token_pos_in_items_ptr,
+            token_pos_in_items_len=token_pos_in_items_len,
+            max_item_len_ptr=max_item_len_ptr,
         )
 
 
@@ -1185,7 +1421,7 @@ class FlashInferMultiStepDraftBackend:
             (max_bs,), dtype=torch.int32, device=model_runner.device
         )
         self.attn_backends: List[FlashInferAttnBackend] = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashInferAttnBackend(
                     model_runner,
@@ -1273,7 +1509,7 @@ class FlashInferMultiStepDraftBackend:
             device="cuda",
         )
 
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
            self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
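Both loops above (and the matching ones in the MLA draft backend further below) now size their per-step resources to `speculative_num_steps - 1` rather than all steps, which suggests the final draft step no longer needs a dedicated attention backend or CUDA-graph buffer. A trivial sketch of the sizing change (illustrative numbers):

```python
speculative_num_steps = 4

# before: one attention backend per draft step
backends_before = [f"step{i}" for i in range(speculative_num_steps)]       # 4 backends
# after: the last draft step is served without its own backend
backends_after = [f"step{i}" for i in range(speculative_num_steps - 1)]    # 3 backends

assert len(backends_after) == speculative_num_steps - 1
```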
sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -9,27 +9,20 @@ and uses BatchMLAPaged wrapper for decoding.
 More details can be found in https://docs.flashinfer.ai/api/mla.html
 """
 
-import os
 from dataclasses import dataclass
 from functools import partial
 from typing import TYPE_CHECKING, Callable, Optional, Union
 
 import torch
 
-if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
-    import logging
-
-    torch._logging.set_logs(dynamo=logging.ERROR)
-    torch._dynamo.config.suppress_errors = True
-
-from sglang.global_config import global_config
+from sglang.srt.environ import envs
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
     is_flashinfer_available,
@@ -38,10 +31,19 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.attention.flashinfer_mla_backend import (
+        FlashInferMlaAttnBackend,
+    )
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.speculative.spec_info import SpecInput
 
+if envs.SGLANG_ENABLE_TORCH_COMPILE.get():
+    import logging
+
+    torch._logging.set_logs(dynamo=logging.ERROR)
+    torch._dynamo.config.suppress_errors = True
+
 if is_flashinfer_available():
     from flashinfer import (
         BatchMLAPagedAttentionWrapper,
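`envs.SGLANG_ENABLE_TORCH_COMPILE` replaces the raw `os.environ` probe while keeping the same effect: quiet dynamo logging and fall back to eager on compile failures. An equivalent standalone guard using plain environment access (the `envs` accessor itself is sglang-internal, so this sketch substitutes `os.environ`):

```python
import logging
import os

import torch

if os.environ.get("SGLANG_ENABLE_TORCH_COMPILE", "0") == "1":
    # Silence torch.compile/dynamo chatter and don't hard-fail on graph breaks.
    torch._logging.set_logs(dynamo=logging.ERROR)
    torch._dynamo.config.suppress_errors = True
```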
@@ -66,7 +68,7 @@ global_workspace_buffer = None
 
 class FlashInferMhaChunkKVRunner:
     def __init__(
-        self, model_runner: ModelRunner, attn_backend: "FlashInferMLAAttnBackend"
+        self, model_runner: ModelRunner, attn_backend: FlashInferMlaAttnBackend
     ):
         # Parse Constants
         self.num_local_heads = (
@@ -80,6 +82,7 @@ class FlashInferMhaChunkKVRunner:
 
         # Buffers and wrappers
         self.qo_indptr = attn_backend.qo_indptr
+        self.kv_indptr = attn_backend.kv_indptr
         self.workspace_buffer = attn_backend.workspace_buffer
         self.fmha_backend = attn_backend.fmha_backend
 
@@ -130,9 +133,14 @@ class FlashInferMhaChunkKVRunner:
             )
         # ragged prefill
         if not disable_flashinfer_ragged:
+            kv_indptr = (
+                qo_indptr
+                if not forward_batch.mha_one_shot
+                else self.kv_indptr[: bs + 1]
+            )
             self.ragged_wrapper.begin_forward(
                 qo_indptr=qo_indptr,
-                kv_indptr=qo_indptr,
+                kv_indptr=kv_indptr,
                 num_qo_heads=self.num_local_heads,
                 num_kv_heads=self.num_local_heads,
                 head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
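The new `kv_indptr` selection distinguishes ragged-only prefill from "one-shot" MHA over the full KV: without a prefix, each request's queries attend only to its own new tokens, so the KV layout mirrors the query layout; one-shot MHA attends over the whole sequence including the cached prefix. A toy sketch of the two CSR-style offset arrays (values illustrative):

```python
import torch

extend_lens = torch.tensor([3, 2])  # newly extended tokens per request
seq_lens = torch.tensor([7, 5])     # prefix + new tokens per request

zero = torch.zeros(1, dtype=torch.long)
# Pure ragged prefill: KV offsets equal query offsets.
qo_indptr = torch.cat([zero, extend_lens.cumsum(0)])        # [0, 3, 5]
# One-shot MHA: KV offsets grow with the full sequence lengths.
kv_indptr_one_shot = torch.cat([zero, seq_lens.cumsum(0)])  # [0, 7, 12]
```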
@@ -154,7 +162,7 @@ class FlashInferMhaChunkKVRunner:
             chunk_idx = forward_batch.prefix_chunk_idx
             assert chunk_idx >= 0
             wrapper = self.chunk_ragged_wrappers[chunk_idx]
-            o1, s1 = wrapper.forward_return_lse(
+            o = wrapper.forward_return_lse(
                 q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
                 v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
| 
         @@ -163,7 +171,12 @@ class FlashInferMhaChunkKVRunner: 
     | 
|
| 
       163 
171 
     | 
    
         
             
                            logits_soft_cap=logits_soft_cap,
         
     | 
| 
       164 
172 
     | 
    
         
             
                        )
         
     | 
| 
       165 
173 
     | 
    
         
             
                    else:
         
     | 
| 
       166 
     | 
    
         
            -
                         
     | 
| 
      
 174 
     | 
    
         
            +
                        forward = (
         
     | 
| 
      
 175 
     | 
    
         
            +
                            self.ragged_wrapper.forward_return_lse
         
     | 
| 
      
 176 
     | 
    
         
            +
                            if forward_batch.mha_return_lse
         
     | 
| 
      
 177 
     | 
    
         
            +
                            else self.ragged_wrapper.forward
         
     | 
| 
      
 178 
     | 
    
         
            +
                        )
         
     | 
| 
      
 179 
     | 
    
         
            +
                        o = forward(
         
     | 
| 
       167 
180 
     | 
    
         
             
                            q.view(-1, layer.tp_q_head_num, layer.head_dim),
         
     | 
| 
       168 
181 
     | 
    
         
             
                            k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
         
     | 
| 
       169 
182 
     | 
    
         
             
                            v.view(-1, layer.tp_v_head_num, layer.v_head_dim).to(q.dtype),
         
     | 
| 
@@ -171,8 +184,7 @@ class FlashInferMhaChunkKVRunner:
                 sm_scale=layer.scaling,
                 logits_soft_cap=logits_soft_cap,
             )
-
-        return o1, s1
+        return o
 
 
 class FlashInferMLAAttnBackend(AttentionBackend):
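`forward_return_lse` additionally returns the per-row log-sum-exp, which is what allows partial attention outputs from separate KV chunks to be merged later; when no merge is needed, the plain `forward` suffices and the runner now returns a single value either way. A sketch of the dispatch (`wrapper` stands in for the FlashInfer ragged wrapper):

```python
def run_ragged_prefill(wrapper, q, k, v, return_lse: bool):
    # Returns (out, lse) when return_lse is set, else just out.
    fn = wrapper.forward_return_lse if return_lse else wrapper.forward
    return fn(q, k, v)
```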
@@ -193,9 +205,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.skip_prefill = skip_prefill
         self.enable_chunk_kv = (
             not skip_prefill
-            and global_server_args_dict["disaggregation_mode"] != "decode"
-            and not global_server_args_dict["disable_chunked_prefix_cache"]
-            and not global_server_args_dict["flashinfer_mla_disable_ragged"]
+            and get_global_server_args().disaggregation_mode != "decode"
+            and not get_global_server_args().disable_chunked_prefix_cache
+            and not get_global_server_args().flashinfer_mla_disable_ragged
         )
         self.page_size = model_runner.page_size
 
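This file's config reads migrate from the `global_server_args_dict` mapping to the typed `get_global_server_args()` accessor, with dictionary keys becoming attributes. A before/after sketch (attribute names as used above):

```python
# before (0.5.3rc2 style):
#   global_server_args_dict["flashinfer_mla_disable_ragged"]
# after (0.5.4 style):
from sglang.srt.server_args import get_global_server_args

args = get_global_server_args()
ragged_disabled = args.flashinfer_mla_disable_ragged
```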
@@ -204,7 +216,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         if global_workspace_buffer is None:
             # different from flashinfer zero_init_global_workspace_buffer
             global_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
+                envs.SGLANG_FLASHINFER_WORKSPACE_SIZE.get(),
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
@@ -306,7 +318,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             prefix_lens = forward_batch.extend_prefix_lens
             extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
             use_ragged = (
-                not global_server_args_dict["flashinfer_mla_disable_ragged"]
+                not get_global_server_args().flashinfer_mla_disable_ragged
                 and extend_no_prefix
             )
 
@@ -510,15 +522,13 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
     ):
-        if (
-            forward_batch.attn_attend_prefix_cache is not None
-            and forward_batch.mha_return_lse
+        if forward_batch.attn_attend_prefix_cache is not None and any(
+            forward_batch.extend_prefix_lens_cpu
         ):  # MHA Chunk
             assert self.enable_chunk_kv
             assert q_rope is None
             assert k_rope is None
-            o1, s1 = self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)
-            return o1, s1
+            return self.mha_chunk_kv_cache.forward(q, k, v, layer, forward_batch)
 
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
@@ -916,7 +926,7 @@ class FlashInferMLAMultiStepDraftBackend:
         )
 
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashInferMLAAttnBackend(
                     model_runner,
@@ -998,7 +1008,7 @@ class FlashInferMLAMultiStepDraftBackend:
             device="cuda",
         )
 
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
@@ -1060,7 +1070,7 @@ def fast_mla_decode_plan(
 
     try:
         # Standard version with just the required arguments (no use_profiler)
-        self._cached_module.plan
+        self._cached_module.plan(
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._pin_memory_int_workspace_buffer,