sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py

@@ -28,8 +28,10 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.
+from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
+    get_int_env_var,
     is_flashinfer_available,
     is_sm100_supported,
     next_power_of_2,
@@ -39,11 +41,13 @@ if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

+
 if is_flashinfer_available():
     from flashinfer import (
         BatchDecodeWithPagedKVCacheWrapper,
         BatchPrefillWithPagedKVCacheWrapper,
         BatchPrefillWithRaggedKVCacheWrapper,
+        fast_decode_plan,
     )
     from flashinfer.cascade import merge_state
     from flashinfer.decode import _get_range_buf, get_seq_lens
@@ -122,12 +126,33 @@ class FlashInferAttnBackend(AttentionBackend):
         ):
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024

+        # When deterministic inference is enabled, tensor cores should be used for decode
+        # Also set split tile sizes for prefill and decode from environment variables, and disable kv split for cuda graph
+        # More information can be found here: https://github.com/flashinfer-ai/flashinfer/pull/1675
+        self.enable_deterministic = (
+            model_runner.server_args.enable_deterministic_inference
+        )
+        self.prefill_split_tile_size = None
+        self.decode_split_tile_size = None
+        self.disable_cuda_graph_kv_split = False
+        if self.enable_deterministic:
+            self.decode_use_tensor_cores = True
+            self.prefill_split_tile_size = get_int_env_var(
+                "SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096
+            )
+            self.decode_split_tile_size = get_int_env_var(
+                "SGLANG_FLASHINFER_DECODE_SPLIT_TILE_SIZE", 2048
+            )
+            self.disable_cuda_graph_kv_split = True
+            global_config.flashinfer_workspace_size = 2048 * 1024 * 1024
+
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
             # different from flashinfer zero_init_global_workspace_buffer
+            global_workspace_size = global_config.flashinfer_workspace_size
             global_workspace_buffer = torch.empty(
-
+                global_workspace_size,
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
@@ -218,6 +243,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 decode_wrappers=self.decode_wrappers,
                 encoder_lens=forward_batch.encoder_lens,
                 spec_info=forward_batch.spec_info,
+                fixed_split_size=self.decode_split_tile_size,
+                disable_split_kv=False,
             )
             self.forward_metadata = DecodeMetadata(self.decode_wrappers)
         elif forward_batch.forward_mode.is_draft_extend():
@@ -257,7 +284,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 use_ragged = False
                 extend_no_prefix = False
             else:
-                use_ragged =
+                use_ragged = not self.enable_deterministic
                 extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

             self.indices_updater_prefill.update(
@@ -270,6 +297,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 use_ragged=use_ragged,
                 encoder_lens=forward_batch.encoder_lens,
                 spec_info=None,
+                fixed_split_size=self.prefill_split_tile_size,
             )
             self.forward_metadata = PrefillMetadata(
                 self.prefill_wrappers_paged, use_ragged, extend_no_prefix
@@ -317,7 +345,7 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
@@ -344,6 +372,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 decode_wrappers=decode_wrappers,
                 encoder_lens=encoder_lens,
                 spec_info=spec_info,
+                fixed_split_size=None,
+                disable_split_kv=self.disable_cuda_graph_kv_split,
             )
             self.decode_cuda_graph_metadata[bs] = decode_wrappers
             self.forward_metadata = DecodeMetadata(decode_wrappers)
@@ -422,7 +452,7 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
@@ -434,6 +464,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 decode_wrappers=self.decode_cuda_graph_metadata[bs],
                 encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
                 spec_info=spec_info,
+                fixed_split_size=None,
+                disable_split_kv=self.disable_cuda_graph_kv_split,
             )
         elif forward_mode.is_target_verify():
             self.indices_updater_prefill.update(
@@ -501,8 +533,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
-
-
+                # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
+                k_scale=layer.k_scale_float,
+                v_scale=layer.v_scale_float,
             )
         else:
             causal = True
@@ -580,8 +613,9 @@ class FlashInferAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
-
-
+            # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
+            k_scale=layer.k_scale_float,
+            v_scale=layer.v_scale_float,
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -636,7 +670,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -649,7 +685,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(
@@ -661,6 +699,8 @@ class FlashInferIndicesUpdaterDecode:
             None,
             spec_info,
             seq_lens_cpu,
+            fixed_split_size=fixed_split_size,
+            disable_split_kv=disable_split_kv,
         )

     def update_sliding_window(
@@ -671,7 +711,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         assert self.sliding_window_size is not None
         for wrapper_id in range(2):
@@ -719,7 +761,9 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -751,9 +795,11 @@ class FlashInferIndicesUpdaterDecode:
         paged_kernel_lens_sum: int,
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
         use_sliding_window_kv_pool: bool = False,
+        fixed_split_size: Optional[int] = None,
+        disable_split_kv: Optional[bool] = None,
     ):
         if spec_info is None:
             bs = len(req_pool_indices)
@@ -797,19 +843,51 @@ class FlashInferIndicesUpdaterDecode:
             global_override_indptr_cpu[0] = 0
             global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)

-        wrapper
-
-
-
-
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            data_type=self.data_type,
-            q_data_type=self.q_data_type,
-            non_blocking=True,
+        # Check if this specific wrapper's begin_forward has been replaced with fast_decode_plan
+        # by checking if it's a partial function with fast_decode_plan as the func
+        wrapper_uses_fast_decode_plan = (
+            hasattr(wrapper.begin_forward, "func")
+            and wrapper.begin_forward.func == fast_decode_plan
         )

+        if wrapper_uses_fast_decode_plan:
+            # When begin_forward is replaced with fast_decode_plan, pass global_override_indptr_cpu
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+                fixed_split_size=fixed_split_size,
+                disable_split_kv=(
+                    disable_split_kv if disable_split_kv is not None else False
+                ),
+                global_override_indptr_cpu=global_override_indptr_cpu,
+            )
+        else:
+            # When using original begin_forward, don't pass global_override_indptr_cpu
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+                fixed_split_size=fixed_split_size,
+                disable_split_kv=(
+                    disable_split_kv if disable_split_kv is not None else False
+                ),
+            )
+
         if locally_override:
             global_override_indptr_cpu = None

@@ -856,7 +934,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -871,7 +950,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
     ):
         if use_ragged:
             # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
@@ -895,6 +975,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.qo_indptr[0],
             use_ragged,
             spec_info,
+            fixed_split_size=fixed_split_size,
         )

     def update_sliding_window(
@@ -907,7 +988,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -953,7 +1035,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
+        fixed_split_size: Optional[int] = None,
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -995,8 +1078,9 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         use_sliding_window_kv_pool: bool = False,
+        fixed_split_size: Optional[int] = None,
     ):
         bs = len(seq_lens)
         if spec_info is None:
@@ -1022,9 +1106,7 @@ class FlashInferIndicesUpdaterPrefill:
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
         else:
-            assert isinstance(spec_info,
-                spec_info, EagleVerifyInput
-            )
+            assert isinstance(spec_info, SpecInput)
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
                     req_pool_indices,
@@ -1067,6 +1149,7 @@ class FlashInferIndicesUpdaterPrefill:
             kv_data_type=self.data_type,
             custom_mask=custom_mask,
             non_blocking=True,
+            fixed_split_size=fixed_split_size,
         )


@@ -1082,7 +1165,7 @@ class FlashInferMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices

         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
@@ -1146,7 +1229,7 @@ class FlashInferMultiStepDraftBackend:
         )

         assert forward_batch.spec_info is not None
-        assert
+        assert forward_batch.spec_info.is_draft_input()

         # Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
         indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
@@ -1274,166 +1357,3 @@ def should_use_tensor_core(
         return gqa_group_size >= 4
     else:
         return False
-
-
-# Use as a fast path to override the indptr in flashinfer's plan function
-# This is used to remove some host-to-device copy overhead.
-global_override_indptr_cpu = None
-
-
-def fast_decode_plan(
-    self,
-    indptr: torch.Tensor,
-    indices: torch.Tensor,
-    last_page_len: torch.Tensor,
-    num_qo_heads: int,
-    num_kv_heads: int,
-    head_dim: int,
-    page_size: int,
-    pos_encoding_mode: str = "NONE",
-    window_left: int = -1,
-    logits_soft_cap: Optional[float] = None,
-    q_data_type: Optional[Union[str, torch.dtype]] = None,
-    kv_data_type: Optional[Union[str, torch.dtype]] = None,
-    data_type: Optional[Union[str, torch.dtype]] = None,
-    sm_scale: Optional[float] = None,
-    rope_scale: Optional[float] = None,
-    rope_theta: Optional[float] = None,
-    non_blocking: bool = True,
-) -> None:
-    """
-    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
-    Modifications:
-    - Remove unnecessary device-to-device copy for the cuda graph buffers.
-    - Remove unnecessary host-to-device copy for the metadata buffers.
-    """
-    batch_size = len(last_page_len)
-    if logits_soft_cap is None:
-        logits_soft_cap = 0.0
-
-    # Handle data types consistently
-    if data_type is not None:
-        if q_data_type is None:
-            q_data_type = data_type
-        if kv_data_type is None:
-            kv_data_type = data_type
-    elif q_data_type is None:
-        q_data_type = "float16"
-
-    if kv_data_type is None:
-        kv_data_type = q_data_type
-
-    if self.use_tensor_cores:
-        qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
-
-    if self.is_cuda_graph_enabled:
-        if batch_size != self._fixed_batch_size:
-            raise ValueError(
-                "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
-                " mismatches the batch size set during initialization {}".format(
-                    batch_size, self._fixed_batch_size
-                )
-            )
-        if len(indices) > len(self._paged_kv_indices_buf):
-            raise ValueError(
-                "The size of indices should be less than or equal to the allocated buffer"
-            )
-    else:
-        self._paged_kv_indptr_buf = indptr
-        self._paged_kv_indices_buf = indices
-        self._paged_kv_last_page_len_buf = last_page_len
-        if self.use_tensor_cores:
-            self._qo_indptr_buf = qo_indptr_host.to(
-                self.device, non_blocking=non_blocking
-            )
-
-    # Create empty tensors for dtype info if needed
-    empty_q_data = torch.empty(
-        0,
-        dtype=(
-            getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
-        ),
-        device=self.device,
-    )
-
-    empty_kv_cache = torch.empty(
-        0,
-        dtype=(
-            getattr(torch, kv_data_type)
-            if isinstance(kv_data_type, str)
-            else kv_data_type
-        ),
-        device=self.device,
-    )
-
-    indptr_host = (
-        global_override_indptr_cpu
-        if global_override_indptr_cpu is not None
-        else indptr.cpu()
-    )
-
-    with torch.cuda.device(self.device):
-
-        if self.use_tensor_cores:
-            # ALSO convert last_page_len to CPU
-            if page_size == 1:
-                # When page size is 1, last_page_len is always 1.
-                # Directly construct the host tensor rather than executing a device-to-host copy.
-                last_page_len_host = torch.ones(
-                    (batch_size,), dtype=torch.int32, device="cpu"
-                )
-            else:
-                last_page_len_host = last_page_len.cpu()
-
-            kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
-
-            try:
-                # Make sure we pass exactly 15 arguments for tensor core version
-                self._plan_info = self._cached_module.plan(
-                    self._float_workspace_buffer,
-                    self._int_workspace_buffer,
-                    self._pin_memory_int_workspace_buffer,
-                    qo_indptr_host,
-                    indptr_host,
-                    kv_lens_arr_host,
-                    batch_size,  # total_num_rows
-                    batch_size,
-                    num_qo_heads,
-                    num_kv_heads,
-                    page_size,
-                    self.is_cuda_graph_enabled,
-                    head_dim,
-                    head_dim,
-                    False,  # causal
-                )
-            except Exception as e:
-                raise RuntimeError(f"Error in standard plan: {e}")
-        else:
-            try:
-                # Make sure we pass exactly 15 arguments for standard version
-                self._plan_info = self._cached_module.plan(
-                    self._float_workspace_buffer,
-                    self._int_workspace_buffer,
-                    self._pin_memory_int_workspace_buffer,
-                    indptr_host,
-                    batch_size,
-                    num_qo_heads,
-                    num_kv_heads,
-                    page_size,
-                    self.is_cuda_graph_enabled,
-                    window_left,
-                    logits_soft_cap,
-                    head_dim,
-                    head_dim,
-                    empty_q_data,
-                    empty_kv_cache,
-                )
-            except Exception as e:
-                raise RuntimeError(f"Error in standard plan: {e}")
-
-    self._pos_encoding_mode = pos_encoding_mode
-    self._window_left = window_left
-    self._logits_soft_cap = logits_soft_cap
-    self._sm_scale = sm_scale
-    self._rope_scale = rope_scale
-    self._rope_theta = rope_theta
@@ -30,7 +30,7 @@ from sglang.srt.layers.attention.flashinfer_backend import (
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.speculative.
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
     is_flashinfer_available,
     is_sm100_supported,
@@ -40,7 +40,7 @@ from sglang.srt.utils import (
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import
+    from sglang.srt.speculative.spec_info import SpecInput

 if is_flashinfer_available():
     from flashinfer import (
@@ -96,6 +96,7 @@ class FlashInferMhaChunkKVRunner:
     def update_wrapper(
         self,
         forward_batch: ForwardBatch,
+        disable_flashinfer_ragged: bool = False,
     ):
         assert forward_batch.num_prefix_chunks is not None
         num_prefix_chunks = forward_batch.num_prefix_chunks
@@ -128,16 +129,17 @@ class FlashInferMhaChunkKVRunner:
                 causal=False,
             )
         # ragged prefill
-
-
-
-
-
-
-
-
-
-
+        if not disable_flashinfer_ragged:
+            self.ragged_wrapper.begin_forward(
+                qo_indptr=qo_indptr,
+                kv_indptr=qo_indptr,
+                num_qo_heads=self.num_local_heads,
+                num_kv_heads=self.num_local_heads,
+                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                head_dim_vo=self.v_head_dim,
+                q_data_type=self.q_data_type,
+                causal=True,
+            )

     def forward(
         self,
@@ -359,7 +361,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrapper = BatchMLAPagedAttentionWrapper(
@@ -439,7 +441,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
@@ -491,9 +493,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 1

-    def init_mha_chunk_metadata(
+    def init_mha_chunk_metadata(
+        self, forward_batch: ForwardBatch, disable_flashinfer_ragged: bool = False
+    ):
         """Init the metadata for a forward pass."""
-        self.mha_chunk_kv_cache.update_wrapper(forward_batch)
+        self.mha_chunk_kv_cache.update_wrapper(forward_batch, disable_flashinfer_ragged)

     def forward_extend(
         self,
@@ -659,7 +663,7 @@ class FlashInferMLAIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrapper: BatchMLAPagedAttentionWrapper,
         init_metadata_replay: bool = False,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput] = None,
         **fast_decode_kwargs,
     ):
         decode_wrapper = decode_wrapper or self.decode_wrapper
@@ -684,7 +688,7 @@ class FlashInferMLAIndicesUpdaterDecode:
         q_indptr: torch.Tensor,
         kv_indptr: torch.Tensor,
         init_metadata_replay: bool = False,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput] = None,
         **fast_decode_kwargs,
     ):
         bs = len(req_pool_indices)
@@ -772,7 +776,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
         use_ragged: bool,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput] = None,
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens
@@ -807,7 +811,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput] = None,
     ):
         bs = len(seq_lens)
         sm_scale = self.scaling
@@ -834,9 +838,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
         else:
-            assert isinstance(spec_info,
-                spec_info, EagleVerifyInput
-            )
+            assert isinstance(spec_info, SpecInput)
             # TODO: Support topk > 1 with custom mask
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
@@ -890,7 +892,7 @@ class FlashInferMLAMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices

         if topk > 1:
             raise ValueError(
@@ -959,7 +961,7 @@ class FlashInferMLAMultiStepDraftBackend:
         )

         assert forward_batch.spec_info is not None
-        assert
+        assert forward_batch.spec_info.is_draft_input()

         for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
@@ -979,8 +981,6 @@ class FlashInferMLAMultiStepDraftBackend:
         )

         def call_fn(i, forward_batch):
-            assert forward_batch.spec_info is not None
-            assert isinstance(forward_batch.spec_info, EagleDraftInput)
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
             )