sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py
@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """

+import logging
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
@@ -16,11 +17,11 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
 import torch

 if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
-    import logging
-
     torch._logging.set_logs(dynamo=logging.ERROR)
     torch._dynamo.config.suppress_errors = True

+logger = logging.getLogger(__name__)
+
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
@@ -28,8 +29,10 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
+    get_int_env_var,
     is_flashinfer_available,
     is_sm100_supported,
     next_power_of_2,
@@ -39,11 +42,13 @@ if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

+
 if is_flashinfer_available():
     from flashinfer import (
         BatchDecodeWithPagedKVCacheWrapper,
         BatchPrefillWithPagedKVCacheWrapper,
         BatchPrefillWithRaggedKVCacheWrapper,
+        fast_decode_plan,
     )
     from flashinfer.cascade import merge_state
     from flashinfer.decode import _get_range_buf, get_seq_lens
@@ -54,6 +59,36 @@ class WrapperDispatch(Enum):
     CROSS_ATTENTION = auto()


+@dataclass
+class MultiItemScoringParams:
+    """Parameters for multi-item scoring in attention computation.
+
+    Used when processing sequences with multiple items separated by delimiters,
+    where each item needs specific attention patterns that respect item boundaries.
+
+    Attributes:
+        prefix_len_ptr: A uint32 1D tensor indicating the prefix length of each prompt.
+            The tensor size is equal to the batch size.
+        token_pos_in_items_ptr: A uint16 1D tensor indicating the token position of each item
+            starting from 0 (delimiter) for each item. For batch size > 1,
+            sequences are concatenated with zero padding to ensure same length.
+        token_pos_in_items_len: Zero padding length for token_pos_in_items_ptr to handle
+            batch_size > 1 case. Defines the padded length for each sequence.
+        max_item_len_ptr: A uint16 tensor containing the max token length of all items
+            for each prompt in the batch.
+
+    """
+
+    prefix_len_ptr: Optional[torch.Tensor] = None
+    token_pos_in_items_ptr: Optional[torch.Tensor] = None
+    token_pos_in_items_len: int = 0
+    max_item_len_ptr: Optional[torch.Tensor] = None
+
+    def is_enabled(self) -> bool:
+        """Check if multi-item scoring is enabled."""
+        return self.prefix_len_ptr is not None
+
+
 @dataclass
 class DecodeMetadata:
     decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
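
As a reading aid (not part of the diff), the four fields of the new dataclass map directly onto the docstring's Case 1; the snippet below simply materializes those documented values with plain torch tensors:

    import torch

    # Values transcribed from the docstring's Case 1 above: 7 query tokens, then three
    # single-token candidate items, each preceded/followed by a delimiter. Requires a
    # PyTorch build with uint16/uint32 tensor support, as the backend itself assumes.
    prefix_len_ptr = torch.tensor([7], dtype=torch.uint32)
    token_pos_in_items_ptr = torch.tensor([0, 1, 0, 1, 0, 1, 0], dtype=torch.uint16)
    token_pos_in_items_len = 7
    max_item_len_ptr = torch.tensor([1], dtype=torch.uint16)

    # MultiItemScoringParams(prefix_len_ptr, token_pos_in_items_ptr,
    #                        token_pos_in_items_len, max_item_len_ptr).is_enabled() -> True
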
@@ -64,6 +99,7 @@ class PrefillMetadata:
     prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
     use_ragged: bool
     extend_no_prefix: bool
+    multi_item_params: Optional[MultiItemScoringParams] = None


 # Reuse this workspace buffer across all flashinfer wrappers
@@ -86,6 +122,11 @@ class FlashInferAttnBackend(AttentionBackend):
     ):
         super().__init__()

+        # Store multi-item scoring delimiter for efficient access
+        self.multi_item_scoring_delimiter = (
+            model_runner.server_args.multi_item_scoring_delimiter
+        )
+
         # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -122,12 +163,33 @@ class FlashInferAttnBackend(AttentionBackend):
         ):
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024

+        # When deterministic inference is enabled, tensor cores should be used for decode
+        # Also set split tile sizes for prefill and decode from environment variables, and disable kv split for cuda graph
+        # More information can be found here: https://github.com/flashinfer-ai/flashinfer/pull/1675
+        self.enable_deterministic = (
+            model_runner.server_args.enable_deterministic_inference
+        )
+        self.prefill_split_tile_size = None
+        self.decode_split_tile_size = None
+        self.disable_cuda_graph_kv_split = False
+        if self.enable_deterministic:
+            self.decode_use_tensor_cores = True
+            self.prefill_split_tile_size = get_int_env_var(
+                "SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096
+            )
+            self.decode_split_tile_size = get_int_env_var(
+                "SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE", 2048
+            )
+            self.disable_cuda_graph_kv_split = True
+            global_config.flashinfer_workspace_size = 2048 * 1024 * 1024
+
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
             # different from flashinfer zero_init_global_workspace_buffer
+            global_workspace_size = global_config.flashinfer_workspace_size
             global_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
+                global_workspace_size,
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
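
The two SGLANG_FLASHINFER_*_SPLIT_TILE_SIZE variables are only read when deterministic inference is enabled; a small, hypothetical launcher-side sketch (the CLI flag name is assumed from the enable_deterministic_inference field, and the defaults mirror the code above):

    import os

    # Optional tuning before launching the server with deterministic inference
    # (presumably --enable-deterministic-inference); if unset, the backend falls
    # back to 4096 (prefill) and 2048 (decode) as coded above.
    os.environ.setdefault("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", "4096")
    os.environ.setdefault("SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE", "2048")
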
@@ -204,10 +266,133 @@ class FlashInferAttnBackend(AttentionBackend):

         # Other metadata
         self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
+
         self.decode_cuda_graph_metadata = {}
         self.prefill_cuda_graph_metadata = {}  # For verify
         self.draft_extend_cuda_graph_metadata = {}  # For draft extend

+    def _process_multi_item_scoring(
+        self, forward_batch: ForwardBatch
+    ) -> MultiItemScoringParams:
+        """Process multi-item scoring tensors for FlashInfer attention.
+
+        This method handles sequences containing multiple "items" separated by delimiter tokens,
+        where each item needs specific attention patterns that respect item boundaries.
+
+        The method produces four key tensors for FlashInfer:
+        - prefix_len_ptr: uint32 tensor with prefix length for each prompt in batch
+        - token_pos_in_items_ptr: uint16 tensor with token positions starting from 0 at delimiters
+        - token_pos_in_items_len: padding length for batch processing
+        - max_item_len_ptr: uint16 tensor with max item length for each prompt
+
+        Args:
+            forward_batch: The forward batch containing input sequences and delimiter info
+
+        Returns:
+            MultiItemScoringParams: The processed multi-item scoring parameters
+
+        Examples:
+            Following FlashInfer definition: for 3 items of length 3, 2, 4 respectively:
+            token_pos_in_items_ptr = [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]
+
+            Case 1: Single sequence
+            Text: "What is the capital of France? <delim> London <delim> Paris <delim> Berlin <delim>"
+            Tokens: [What, is, the, capital, of, France, ?, <delim>, London, <delim>, Paris, <delim>, Berlin, <delim>]
+            Indices: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
+            - prefix_len_ptr: [7] (query length before first delimiter)
+            - token_pos_in_items_ptr: [0, 1, 0, 1, 0, 1, 0] (delim=0, London=1, delim=0, Paris=1, delim=0, Berlin=1, delim=0)
+            - token_pos_in_items_len: 7 (actual length)
+            - max_item_len_ptr: [1] (max item length is 1 token - all options are single tokens)
+
+            Case 2: Batch processing (batch_size=2)
+            Sequence 1: 2 items of length 2, 1 → [0, 1, 2, 0, 1, 0] (6 elements)
+            Sequence 2: 3 items of length 1, 3, 2 → [0, 1, 0, 1, 2, 3, 0, 1, 2, 0] (10 elements)
+            After padding both to length 10:
+            - token_pos_in_items_ptr: [0, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0]
+            - token_pos_in_items_len: 10 (padded length for batch processing)
+            - max_item_len_ptr: [2, 3] (max lengths per sequence)
+        """
+
+        delimiter = self.multi_item_scoring_delimiter
+        if delimiter is None or forward_batch.forward_mode == ForwardMode.DECODE:
+            return MultiItemScoringParams()
+
+        delimiter_mask = forward_batch.input_ids == delimiter
+        prefix_cache_lens = getattr(forward_batch, "extend_prefix_lens", None)
+        extend_seq_lens = getattr(forward_batch, "extend_seq_lens", None)
+        prefix_len_ptr, token_pos_in_items_ptr = [], []
+        token_pos_in_items_len = 0
+
+        # If no extend_seq_lens, treat whole batch as one sequence
+        if extend_seq_lens is None or len(extend_seq_lens) <= 1:
+            extend_seq_lens = [forward_batch.input_ids.size(0)]
+
+        seq_start = 0
+        for i, seq_len in enumerate(extend_seq_lens):
+            seq_end = seq_start + seq_len
+            mask = delimiter_mask[seq_start:seq_end]
+            pos = forward_batch.positions[seq_start:seq_end]
+            delimiter_indices = torch.nonzero(mask, as_tuple=True)[0]
+
+            if len(delimiter_indices) > 0:
+                first_delim = delimiter_indices[0]
+                # Prefix length: store as scalar
+                prefix_len = first_delim + (
+                    prefix_cache_lens[i] if prefix_cache_lens is not None else 0
+                )
+                prefix_len_ptr.append(
+                    prefix_len.item() if torch.is_tensor(prefix_len) else prefix_len
+                )
+
+                # Compute relative positions within items after delimiters
+                diff = pos[first_delim:] - torch.cummax(mask[first_delim:], 0)[1]
+                token_pos = (diff - pos[first_delim]).to(torch.uint16)
+                token_pos_in_items_ptr.append(token_pos)
+
+                # Update forward_batch positions in-place
+                pos[first_delim:] = diff - 1
+                forward_batch.positions[seq_start:seq_end] = pos
+
+            seq_start = seq_end
+
+        # Pad token_pos_in_items_ptr for batch processing
+        if token_pos_in_items_ptr:
+            token_pos_in_items_len = max(t.numel() for t in token_pos_in_items_ptr)
+            device = forward_batch.input_ids.device
+            token_pos_in_items_ptr = [
+                torch.cat(
+                    [
+                        t,
+                        torch.zeros(
+                            token_pos_in_items_len - t.numel(),
+                            dtype=torch.uint16,
+                            device=device,
+                        ),
+                    ]
+                )
+                for t in token_pos_in_items_ptr
+            ]
+
+        if not prefix_len_ptr or not token_pos_in_items_ptr:
+            return MultiItemScoringParams()
+
+        # Build final params
+        device = forward_batch.input_ids.device
+        return MultiItemScoringParams(
+            prefix_len_ptr=torch.tensor(
+                prefix_len_ptr, dtype=torch.uint32, device=device
+            ),
+            token_pos_in_items_ptr=torch.cat(token_pos_in_items_ptr, dim=0),
+            token_pos_in_items_len=token_pos_in_items_len & 0xFFFFFFFF,
+            max_item_len_ptr=torch.stack(
+                [
+                    t.to(torch.int32).max().to(torch.uint16)
+                    for t in token_pos_in_items_ptr
+                ],
+                dim=0,
+            ),
+        )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
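
The following standalone snippet is an illustration rather than the backend's code path: it reproduces the delimiter-relative position arithmetic on the docstring's Case 1, where subtracting the position of the most recent delimiter yields the documented [0, 1, 0, 1, 0, 1, 0] pattern:

    import torch

    # Docstring Case 1: 7 query tokens, then <delim> London <delim> Paris <delim> Berlin <delim>
    delim_mask = torch.tensor(
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1], dtype=torch.bool
    )
    pos = torch.arange(delim_mask.numel())

    first_delim = int(torch.nonzero(delim_mask)[0])  # 7 -> the prefix_len_ptr entry
    # For every token from the first delimiter on, find the most recent delimiter position,
    # then measure each token's offset from it.
    last_delim = torch.cummax(
        torch.where(
            delim_mask[first_delim:],
            pos[first_delim:],
            torch.zeros_like(pos[first_delim:]),
        ),
        dim=0,
    ).values
    token_pos_in_items = pos[first_delim:] - last_delim

    print(first_delim, token_pos_in_items.tolist())  # 7 [0, 1, 0, 1, 0, 1, 0]
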
@@ -218,6 +403,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 decode_wrappers=self.decode_wrappers,
                 encoder_lens=forward_batch.encoder_lens,
                 spec_info=forward_batch.spec_info,
+                fixed_split_size=self.decode_split_tile_size,
+                disable_split_kv=False,
             )
             self.forward_metadata = DecodeMetadata(self.decode_wrappers)
         elif forward_batch.forward_mode.is_draft_extend():
@@ -253,13 +440,26 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             prefix_lens = forward_batch.extend_prefix_lens

-            if self.is_multimodal:
+            # Disable ragged wrapper and ensure prefix handling for multimodal and multi-item scoring
+            if self.is_multimodal or self.multi_item_scoring_delimiter is not None:
+                # use_ragged = False: Multi-item scoring requires the paged wrapper because:
+                # 1. Ragged wrapper doesn't support the specialized multi-item parameters
+                #    (prefix_len_ptr, token_pos_in_items_ptr, etc.)
+                # 2. Paged wrapper provides better control over attention masking needed
+                #    for respecting item boundaries in multi-item sequences
+                # 3. Custom masking logic conflicts with ragged wrapper's assumptions
                 use_ragged = False
                 extend_no_prefix = False
             else:
-                use_ragged = True
+                use_ragged = not self.enable_deterministic
                 extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

+            # Process multi-item scoring in attention backend instead of ForwardBatch
+            multi_item_params = MultiItemScoringParams()
+            if self.multi_item_scoring_delimiter is not None:
+                # Use new backend-specific implementation
+                multi_item_params = self._process_multi_item_scoring(forward_batch)
+
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
@@ -270,9 +470,14 @@ class FlashInferAttnBackend(AttentionBackend):
                 use_ragged=use_ragged,
                 encoder_lens=forward_batch.encoder_lens,
                 spec_info=None,
+                fixed_split_size=self.prefill_split_tile_size,
+                multi_item_params=multi_item_params,
             )
             self.forward_metadata = PrefillMetadata(
-                self.prefill_wrappers_paged, use_ragged, extend_no_prefix
+                self.prefill_wrappers_paged,
+                use_ragged,
+                extend_no_prefix,
+                multi_item_params,
             )

     def init_cuda_graph_state(
@@ -317,7 +522,7 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
@@ -344,6 +549,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 decode_wrappers=decode_wrappers,
                 encoder_lens=encoder_lens,
                 spec_info=spec_info,
+                fixed_split_size=None,
+                disable_split_kv=self.disable_cuda_graph_kv_split,
             )
             self.decode_cuda_graph_metadata[bs] = decode_wrappers
             self.forward_metadata = DecodeMetadata(decode_wrappers)
@@ -422,7 +629,7 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
@@ -434,6 +641,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 decode_wrappers=self.decode_cuda_graph_metadata[bs],
                 encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
                 spec_info=spec_info,
+                fixed_split_size=None,
+                disable_split_kv=self.disable_cuda_graph_kv_split,
             )
         elif forward_mode.is_target_verify():
             self.indices_updater_prefill.update(
@@ -499,10 +708,24 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                 causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
-                window_left=layer.sliding_window_size,
+                # Disable sliding window attention for multi-item scoring:
+                # - Sliding window could cut across item boundaries, breaking semantic coherence
+                # - Multi-item sequences need full attention to properly handle delimiter tokens
+                # - Specialized multi-item parameters (prefix_len_ptr, token_pos_in_items_ptr)
+                #   provide more precise attention control than simple sliding windows
+                # - Item-aware masking takes precedence over window-based masking
+                window_left=(
+                    layer.sliding_window_size
+                    if not (
+                        self.forward_metadata.multi_item_params
+                        and self.forward_metadata.multi_item_params.is_enabled()
+                    )
+                    else -1
+                ),
                 logits_soft_cap=logits_soft_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
+                # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
+                k_scale=layer.k_scale_float,
+                v_scale=layer.v_scale_float,
             )
         else:
             causal = True
@@ -580,8 +803,9 @@ class FlashInferAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
-            k_scale=layer.k_scale,
-            v_scale=layer.v_scale,
+            # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
+            k_scale=layer.k_scale_float,
+            v_scale=layer.v_scale_float,
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -636,7 +860,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -649,7 +875,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(
@@ -661,6 +889,8 @@ class FlashInferIndicesUpdaterDecode:
             None,
             spec_info,
             seq_lens_cpu,
+            fixed_split_size=fixed_split_size,
+            disable_split_kv=disable_split_kv,
         )

     def update_sliding_window(
def update_sliding_window(
|
@@ -671,7 +901,9 @@ class FlashInferIndicesUpdaterDecode:
|
|
671
901
|
seq_lens_sum: int,
|
672
902
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
673
903
|
encoder_lens: Optional[torch.Tensor],
|
674
|
-
spec_info: Optional[
|
904
|
+
spec_info: Optional[SpecInput],
|
905
|
+
fixed_split_size: Optional[int] = None,
|
906
|
+
disable_split_kv: Optional[bool] = None,
|
675
907
|
):
|
676
908
|
assert self.sliding_window_size is not None
|
677
909
|
for wrapper_id in range(2):
|
@@ -719,7 +951,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -751,9 +985,11 @@ class FlashInferIndicesUpdaterDecode:
         paged_kernel_lens_sum: int,
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
         use_sliding_window_kv_pool: bool = False,
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         if spec_info is None:
             bs = len(req_pool_indices)
@@ -797,19 +1033,51 @@ class FlashInferIndicesUpdaterDecode:
             global_override_indptr_cpu[0] = 0
             global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)

-        wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            self.kv_last_page_len[:bs],
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            data_type=self.data_type,
-            q_data_type=self.q_data_type,
-            non_blocking=True,
+        # Check if this specific wrapper's begin_forward has been replaced with fast_decode_plan
+        # by checking if it's a partial function with fast_decode_plan as the func
+        wrapper_uses_fast_decode_plan = (
+            hasattr(wrapper.begin_forward, "func")
+            and wrapper.begin_forward.func == fast_decode_plan
         )

+        if wrapper_uses_fast_decode_plan:
+            # When begin_forward is replaced with fast_decode_plan, pass global_override_indptr_cpu
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+                fixed_split_size=fixed_split_size,
+                disable_split_kv=(
+                    disable_split_kv if disable_split_kv is not None else False
+                ),
+                global_override_indptr_cpu=global_override_indptr_cpu,
+            )
+        else:
+            # When using original begin_forward, don't pass global_override_indptr_cpu
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+                fixed_split_size=fixed_split_size,
+                disable_split_kv=(
+                    disable_split_kv if disable_split_kv is not None else False
+                ),
+            )
+
         if locally_override:
             global_override_indptr_cpu = None

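
The wrapper_uses_fast_decode_plan check above leans on the fact that functools.partial objects expose the wrapped callable as .func, while plain functions do not; a minimal standalone illustration (the function names are stand-ins, not sglang objects):

    import functools

    def original_plan(*args, **kwargs): ...
    def fast_decode_plan(*args, **kwargs): ...

    begin_forward = functools.partial(fast_decode_plan)  # monkey-patched case
    assert hasattr(begin_forward, "func") and begin_forward.func is fast_decode_plan

    begin_forward = original_plan  # plain function: no .func attribute
    assert not hasattr(begin_forward, "func")
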
@@ -856,7 +1124,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -871,7 +1140,9 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         if use_ragged:
             # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
@@ -895,6 +1166,8 @@ class FlashInferIndicesUpdaterPrefill:
             self.qo_indptr[0],
             use_ragged,
             spec_info,
+            fixed_split_size=fixed_split_size,
+            multi_item_params=multi_item_params,
         )

     def update_sliding_window(
@@ -907,7 +1180,9 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -941,6 +1216,7 @@ class FlashInferIndicesUpdaterPrefill:
                 use_ragged,
                 spec_info,
                 use_sliding_window_kv_pool=use_sliding_window_kv_pool,
+                multi_item_params=multi_item_params,
             )

     def update_cross_attention(
@@ -953,7 +1229,9 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -980,6 +1258,7 @@ class FlashInferIndicesUpdaterPrefill:
                 self.qo_indptr[wrapper_id],
                 use_ragged,
                 spec_info,
+                multi_item_params=multi_item_params,
             )

     def call_begin_forward(
@@ -995,8 +1274,10 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        spec_info: Optional[SpecInput],
         use_sliding_window_kv_pool: bool = False,
+        fixed_split_size: Optional[int] = None,
+        multi_item_params: Optional[MultiItemScoringParams] = None,
     ):
         bs = len(seq_lens)
         if spec_info is None:
@@ -1022,9 +1303,7 @@ class FlashInferIndicesUpdaterPrefill:
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
         else:
-            assert isinstance(spec_info, EagleDraftInput) or isinstance(
-                spec_info, EagleVerifyInput
-            )
+            assert isinstance(spec_info, SpecInput)
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
                     req_pool_indices,
@@ -1054,6 +1333,22 @@ class FlashInferIndicesUpdaterPrefill:
         )

         # cached part
+        # Conditionally set multi-item parameters
+        if multi_item_params is not None and multi_item_params.is_enabled():
+            # Multi-item scoring is active - use specialized parameters and disable generic custom_mask
+            use_custom_mask = None
+            prefix_len_ptr = multi_item_params.prefix_len_ptr
+            token_pos_in_items_ptr = multi_item_params.token_pos_in_items_ptr
+            token_pos_in_items_len = multi_item_params.token_pos_in_items_len
+            max_item_len_ptr = multi_item_params.max_item_len_ptr
+        else:
+            # No multi-item scoring - use standard parameters
+            use_custom_mask = custom_mask
+            prefix_len_ptr = None
+            token_pos_in_items_ptr = None
+            token_pos_in_items_len = 0
+            max_item_len_ptr = None
+
         wrapper_paged.begin_forward(
             qo_indptr,
             kv_indptr,
@@ -1065,8 +1360,13 @@ class FlashInferIndicesUpdaterPrefill:
             1,
             q_data_type=self.q_data_type,
             kv_data_type=self.data_type,
-            custom_mask=custom_mask,
+            custom_mask=use_custom_mask,
             non_blocking=True,
+            fixed_split_size=fixed_split_size,
+            prefix_len_ptr=prefix_len_ptr,
+            token_pos_in_items_ptr=token_pos_in_items_ptr,
+            token_pos_in_items_len=token_pos_in_items_len,
+            max_item_len_ptr=max_item_len_ptr,
         )


@@ -1082,7 +1382,7 @@ class FlashInferMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices

         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
@@ -1146,7 +1446,7 @@ class FlashInferMultiStepDraftBackend:
         )

         assert forward_batch.spec_info is not None
-        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info.is_draft_input()

         # Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
         indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
@@ -1274,166 +1574,3 @@ def should_use_tensor_core(
         return gqa_group_size >= 4
     else:
         return False
-
-
-# Use as a fast path to override the indptr in flashinfer's plan function
-# This is used to remove some host-to-device copy overhead.
-global_override_indptr_cpu = None
-
-
-def fast_decode_plan(
-    self,
-    indptr: torch.Tensor,
-    indices: torch.Tensor,
-    last_page_len: torch.Tensor,
-    num_qo_heads: int,
-    num_kv_heads: int,
-    head_dim: int,
-    page_size: int,
-    pos_encoding_mode: str = "NONE",
-    window_left: int = -1,
-    logits_soft_cap: Optional[float] = None,
-    q_data_type: Optional[Union[str, torch.dtype]] = None,
-    kv_data_type: Optional[Union[str, torch.dtype]] = None,
-    data_type: Optional[Union[str, torch.dtype]] = None,
-    sm_scale: Optional[float] = None,
-    rope_scale: Optional[float] = None,
-    rope_theta: Optional[float] = None,
-    non_blocking: bool = True,
-) -> None:
-    """
-    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
-    Modifications:
-    - Remove unnecessary device-to-device copy for the cuda graph buffers.
-    - Remove unnecessary host-to-device copy for the metadata buffers.
-    """
-    batch_size = len(last_page_len)
-    if logits_soft_cap is None:
-        logits_soft_cap = 0.0
-
-    # Handle data types consistently
-    if data_type is not None:
-        if q_data_type is None:
-            q_data_type = data_type
-        if kv_data_type is None:
-            kv_data_type = data_type
-    elif q_data_type is None:
-        q_data_type = "float16"
-
-    if kv_data_type is None:
-        kv_data_type = q_data_type
-
-    if self.use_tensor_cores:
-        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
-
-    if self.is_cuda_graph_enabled:
-        if batch_size != self._fixed_batch_size:
-            raise ValueError(
-                "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
-                " mismatches the batch size set during initialization {}".format(
-                    batch_size, self._fixed_batch_size
-                )
-            )
-        if len(indices) > len(self._paged_kv_indices_buf):
-            raise ValueError(
-                "The size of indices should be less than or equal to the allocated buffer"
-            )
-    else:
-        self._paged_kv_indptr_buf = indptr
-        self._paged_kv_indices_buf = indices
-        self._paged_kv_last_page_len_buf = last_page_len
-        if self.use_tensor_cores:
-            self._qo_indptr_buf = qo_indptr_host.to(
-                self.device, non_blocking=non_blocking
-            )
-
-    # Create empty tensors for dtype info if needed
-    empty_q_data = torch.empty(
-        0,
-        dtype=(
-            getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
-        ),
-        device=self.device,
-    )
-
-    empty_kv_cache = torch.empty(
-        0,
-        dtype=(
-            getattr(torch, kv_data_type)
-            if isinstance(kv_data_type, str)
-            else kv_data_type
-        ),
-        device=self.device,
-    )
-
-    indptr_host = (
-        global_override_indptr_cpu
-        if global_override_indptr_cpu is not None
-        else indptr.cpu()
-    )
-
-    with torch.cuda.device(self.device):
-
-        if self.use_tensor_cores:
-            # ALSO convert last_page_len to CPU
-            if page_size == 1:
-                # When page size is 1, last_page_len is always 1.
-                # Directly construct the host tensor rather than executing a device-to-host copy.
-                last_page_len_host = torch.ones(
-                    (batch_size,), dtype=torch.int32, device="cpu"
-                )
-            else:
-                last_page_len_host = last_page_len.cpu()
-
-            kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
-
-            try:
-                # Make sure we pass exactly 15 arguments for tensor core version
-                self._plan_info = self._cached_module.plan(
-                    self._float_workspace_buffer,
-                    self._int_workspace_buffer,
-                    self._pin_memory_int_workspace_buffer,
-                    qo_indptr_host,
-                    indptr_host,
-                    kv_lens_arr_host,
-                    batch_size,  # total_num_rows
-                    batch_size,
-                    num_qo_heads,
-                    num_kv_heads,
-                    page_size,
-                    self.is_cuda_graph_enabled,
-                    head_dim,
-                    head_dim,
-                    False,  # causal
-                )
-            except Exception as e:
-                raise RuntimeError(f"Error in standard plan: {e}")
-        else:
-            try:
-                # Make sure we pass exactly 15 arguments for standard version
-                self._plan_info = self._cached_module.plan(
-                    self._float_workspace_buffer,
-                    self._int_workspace_buffer,
-                    self._pin_memory_int_workspace_buffer,
-                    indptr_host,
-                    batch_size,
-                    num_qo_heads,
-                    num_kv_heads,
-                    page_size,
-                    self.is_cuda_graph_enabled,
-                    window_left,
-                    logits_soft_cap,
-                    head_dim,
-                    head_dim,
-                    empty_q_data,
-                    empty_kv_cache,
-                )
-            except Exception as e:
-                raise RuntimeError(f"Error in standard plan: {e}")
-
-    self._pos_encoding_mode = pos_encoding_mode
-    self._window_left = window_left
-    self._logits_soft_cap = logits_soft_cap
-    self._sm_scale = sm_scale
-    self._rope_scale = rope_scale
-    self._rope_theta = rope_theta