sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- sglang/bench_one_batch.py +54 -37
- sglang/bench_one_batch_server.py +340 -34
- sglang/bench_serving.py +340 -159
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +9 -2
- sglang/profiler.py +20 -3
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +309 -0
- sglang/srt/configs/load_config.py +33 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +284 -118
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +576 -0
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +6 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +26 -15
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +268 -98
- sglang/srt/disaggregation/decode.py +172 -39
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +203 -555
- sglang/srt/disaggregation/nixl/conn.py +217 -63
- sglang/srt/disaggregation/prefill.py +113 -270
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +203 -97
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +85 -65
- sglang/srt/entrypoints/grpc_server.py +632 -305
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +169 -17
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +327 -34
- sglang/srt/entrypoints/openai/serving_base.py +74 -8
- sglang/srt/entrypoints/openai/serving_chat.py +202 -118
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +20 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +47 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +323 -0
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +21 -16
- sglang/srt/function_call/glm4_moe_detector.py +4 -8
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +61 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +98 -7
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/grpc_request_manager.py +915 -0
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
- sglang/srt/layers/activation.py +11 -7
- sglang/srt/layers/attention/aiter_backend.py +17 -18
- sglang/srt/layers/attention/ascend_backend.py +125 -10
- sglang/srt/layers/attention/attention_registry.py +226 -0
- sglang/srt/layers/attention/base_attn_backend.py +32 -4
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +52 -15
- sglang/srt/layers/attention/flashinfer_backend.py +357 -212
- sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
- sglang/srt/layers/attention/flashmla_backend.py +9 -7
- sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
- sglang/srt/layers/attention/mamba/mamba.py +514 -1
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +23 -0
- sglang/srt/layers/attention/nsa_backend.py +1201 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +249 -42
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
- sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +61 -3
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +19 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +28 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +47 -15
- sglang/srt/layers/linear.py +30 -5
- sglang/srt/layers/logits_processor.py +161 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
- sglang/srt/layers/moe/ep_moe/layer.py +243 -448
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +27 -1
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +86 -20
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +43 -15
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +141 -81
- sglang/srt/layers/quantization/mxfp4.py +17 -34
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -24
- sglang/srt/layers/quantization/w8a8_int8.py +45 -27
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +750 -46
- sglang/srt/layers/sampler.py +84 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +23 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +9 -4
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +33 -7
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +41 -17
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +83 -152
- sglang/srt/managers/data_parallel_controller.py +156 -87
- sglang/srt/managers/detokenizer_manager.py +51 -24
- sglang/srt/managers/io_struct.py +223 -129
- sglang/srt/managers/mm_utils.py +49 -10
- sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +130 -0
- sglang/srt/managers/schedule_batch.py +340 -529
- sglang/srt/managers/schedule_policy.py +158 -18
- sglang/srt/managers/scheduler.py +665 -620
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
- sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
- sglang/srt/managers/tokenizer_manager.py +462 -226
- sglang/srt/managers/tp_worker.py +217 -156
- sglang/srt/managers/utils.py +79 -47
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +42 -28
- sglang/srt/mem_cache/base_prefix_cache.py +3 -3
- sglang/srt/mem_cache/chunk_cache.py +20 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +38 -0
- sglang/srt/mem_cache/hicache_storage.py +44 -2
- sglang/srt/mem_cache/hiradix_cache.py +134 -34
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +602 -208
- sglang/srt/mem_cache/memory_pool_host.py +134 -183
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +263 -78
- sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +115 -58
- sglang/srt/metrics/collector.py +113 -120
- sglang/srt/metrics/func_timer.py +3 -8
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +81 -36
- sglang/srt/model_executor/forward_batch_info.py +40 -50
- sglang/srt/model_executor/model_runner.py +507 -319
- sglang/srt/model_executor/npu_graph_runner.py +11 -5
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +438 -37
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +200 -27
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +40 -56
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +25 -4
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +793 -235
- sglang/srt/models/dots_ocr.py +171 -0
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +570 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -3
- sglang/srt/models/glm4_moe.py +17 -40
- sglang/srt/models/glm4_moe_nextn.py +4 -4
- sglang/srt/models/glm4v.py +3 -2
- sglang/srt/models/glm4v_moe.py +6 -6
- sglang/srt/models/gpt_oss.py +12 -35
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +4 -2
- sglang/srt/models/llama.py +6 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +6 -23
- sglang/srt/models/longcat_flash_nextn.py +4 -15
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +27 -6
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +5 -5
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +70 -4
- sglang/srt/models/qwen2_vl.py +6 -3
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +50 -38
- sglang/srt/models/qwen3_next.py +43 -21
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +791 -0
- sglang/srt/models/qwen3_vl_moe.py +343 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +268 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +61 -0
- sglang/srt/multimodal/processors/base_processor.py +21 -9
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +2 -4
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +20 -10
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +83 -17
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +36 -23
- sglang/srt/sampling/sampling_params.py +75 -0
- sglang/srt/server_args.py +1300 -338
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +161 -0
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
- sglang/srt/speculative/eagle_info.py +786 -0
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +113 -1270
- sglang/srt/speculative/eagle_worker.py +120 -285
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/ngram_info.py +433 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +49 -0
- sglang/srt/speculative/spec_utils.py +641 -0
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +35 -18
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/{utils.py → utils/common.py} +583 -113
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +120 -11
- sglang/test/runners.py +3 -1
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +8 -2
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +3 -4
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +430 -0
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +93 -1
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +432 -16
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
- sglang/srt/entrypoints/grpc_request_manager.py +0 -580
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
**`sglang/srt/disaggregation/decode.py`**

```diff
@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
 from __future__ import annotations
 
 import logging
+import time
 from collections import deque
 from dataclasses import dataclass
 from http import HTTPStatus
@@ -29,6 +30,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
 import torch
 from torch.distributed import ProcessGroup
 
+from sglang.srt.configs.mamba_utils import Mamba2CacheParams
 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
 from sglang.srt.disaggregation.utils import (
@@ -45,13 +47,19 @@ from sglang.srt.disaggregation.utils import (
     prepare_abort,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
+from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import (
-
-)
+from sglang.srt.mem_cache.memory_pool import (
+    HybridLinearKVPool,
+    HybridReqToTokenPool,
+    KVCache,
+    NSATokenToKVPool,
+    ReqToTokenPool,
+    SWAKVPool,
+)
 from sglang.srt.utils import get_int_env_var, require_mlp_sync
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
 
 logger = logging.getLogger(__name__)
 
@@ -123,6 +131,35 @@ class DecodeReqToTokenPool:
         self.free_slots = list(range(self.size + self.pre_alloc_size))
 
 
+class HybridMambaDecodeReqToTokenPool(HybridReqToTokenPool):
+
+    def __init__(
+        self,
+        size: int,
+        max_context_len: int,
+        device: str,
+        enable_memory_saver: bool,
+        cache_params: "Mamba2CacheParams",
+        speculative_num_draft_tokens: int,
+        pre_alloc_size: int,
+    ):
+        DecodeReqToTokenPool.__init__(
+            self,
+            size=size,
+            max_context_len=max_context_len,
+            device=device,
+            enable_memory_saver=enable_memory_saver,
+            pre_alloc_size=pre_alloc_size,
+        )
+        self._init_mamba_pool(
+            size + pre_alloc_size, cache_params, device, speculative_num_draft_tokens
+        )
+
+    def clear(self):
+        self.free_slots = list(range(self.size + self.pre_alloc_size))
+        self.mamba_pool.clear()
+
+
 @dataclass
 class DecodeRequest:
     req: Req
```
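`HybridMambaDecodeReqToTokenPool` calls `DecodeReqToTokenPool.__init__` explicitly instead of `super().__init__`: it keeps `HybridReqToTokenPool`'s mamba-aware `alloc` through the MRO while borrowing the decode pool's pre-allocated slot bookkeeping. A minimal runnable sketch of that pattern, with toy classes standing in for the sglang pools (all names here are illustrative):

```python
class SlotPool:  # plays the role of DecodeReqToTokenPool
    def __init__(self, size: int, pre_alloc_size: int):
        self.size = size
        self.pre_alloc_size = pre_alloc_size
        self.free_slots = list(range(size + pre_alloc_size))

    def alloc(self, n: int):
        taken, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return taken


class StatefulPool:  # plays the role of HybridReqToTokenPool
    def alloc(self, n: int, reqs=None):
        # state-aware allocation: also hand out one state slot per request
        slots = SlotPool.alloc(self, n)
        states = [self.state_free.pop() for _ in range(n)]
        return list(zip(slots, states))

    def _init_state_pool(self, total: int):
        self.state_free = list(range(total))


class HybridSlotPool(StatefulPool, SlotPool):
    def __init__(self, size: int, pre_alloc_size: int):
        # Explicit call: reuse SlotPool's slot bookkeeping for __init__...
        SlotPool.__init__(self, size, pre_alloc_size)
        # ...then attach the extra state pool, as the sglang class does.
        self._init_state_pool(size + pre_alloc_size)


pool = HybridSlotPool(size=2, pre_alloc_size=1)
print(pool.alloc(2, reqs=["a", "b"]))  # -> [(0, 2), (1, 1)] via StatefulPool.alloc
```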
```diff
@@ -216,6 +253,28 @@ class DecodePreallocQueue:
             self.metadata_buffers.get_buf_infos()
         )
 
+        if hasattr(self.token_to_kv_pool, "get_state_buf_infos"):
+            state_data_ptrs, state_data_lens, state_item_lens = (
+                self.token_to_kv_pool.get_state_buf_infos()
+            )
+            kv_args.state_data_ptrs = state_data_ptrs
+            kv_args.state_data_lens = state_data_lens
+            kv_args.state_item_lens = state_item_lens
+
+            if isinstance(self.token_to_kv_pool, SWAKVPool):
+                kv_args.state_type = "swa"
+            elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):
+                kv_args.state_type = "mamba"
+            elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
+                kv_args.state_type = "nsa"
+            else:
+                kv_args.state_type = "none"
+        else:
+            kv_args.state_data_ptrs = []
+            kv_args.state_data_lens = []
+            kv_args.state_item_lens = []
+            kv_args.state_type = "none"
+
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class: Type[BaseKVManager] = get_kv_class(
@@ -253,6 +312,7 @@ class DecodePreallocQueue:
             prefill_dp_rank=req.data_parallel_rank,
         )
 
+        req.add_latency(RequestStage.DECODE_PREPARE)
         self.queue.append(
             DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
         )
```
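The prealloc queue never hands the transfer backend a pool object here: it duck-types the pool (`hasattr(..., "get_state_buf_infos")`) and collapses the concrete class into a plain string tag (`"swa"` / `"mamba"` / `"nsa"` / `"none"`) carried in `kv_args`. A toy sketch of that tagging logic, with stand-in classes rather than the sglang pools:

```python
# Stand-in pool classes; only the tag-derivation logic mirrors the hunk above.
class SWAKVPool:
    def get_state_buf_infos(self):
        return [], [], []

class HybridLinearKVPool:
    def get_state_buf_infos(self):
        return [], [], []

class NSATokenToKVPool:
    def get_state_buf_infos(self):
        return [], [], []

class PlainKVPool:  # no extra per-request state, so no get_state_buf_infos
    pass

def state_type_of(pool) -> str:
    """Collapse the pool's concrete type into a wire-friendly string tag."""
    if not hasattr(pool, "get_state_buf_infos"):
        return "none"
    if isinstance(pool, SWAKVPool):
        return "swa"
    if isinstance(pool, HybridLinearKVPool):
        return "mamba"
    if isinstance(pool, NSATokenToKVPool):
        return "nsa"
    return "none"

for p in (SWAKVPool(), HybridLinearKVPool(), NSATokenToKVPool(), PlainKVPool()):
    print(type(p).__name__, "->", state_type_of(p))
```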
```diff
@@ -412,17 +472,62 @@ class DecodePreallocQueue:
                 .cpu()
                 .numpy()
             )
+            page_size = self.token_to_kv_pool_allocator.page_size
+
+            # Prepare extra pool indices for hybrid models
+            if isinstance(self.token_to_kv_pool, HybridLinearKVPool):
+                # Mamba hybrid model: single mamba state index
+                state_indices = [
+                    self.req_to_token_pool.req_index_to_mamba_index_mapping[
+                        decode_req.req.req_pool_idx
+                    ]
+                    .cpu()
+                    .numpy()
+                ]
+            elif isinstance(self.token_to_kv_pool, SWAKVPool):
+                # SWA hybrid model: send decode-side SWA window indices
+                seq_len = len(decode_req.req.origin_input_ids)
+                window_size = self.scheduler.sliding_window_size
+
+                window_start = max(0, seq_len - window_size)
+                window_start = (window_start // page_size) * page_size
+                window_kv_indices_full = self.req_to_token_pool.req_to_token[
+                    decode_req.req.req_pool_idx, window_start:seq_len
+                ]
+
+                # Translate to SWA pool indices
+                window_kv_indices_swa = (
+                    self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                        window_kv_indices_full
+                    )
+                )
+                state_indices = window_kv_indices_swa.cpu().numpy()
+                state_indices = kv_to_page_indices(state_indices, page_size)
+            elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
+                seq_len = len(decode_req.req.origin_input_ids)
+                kv_indices_full = self.req_to_token_pool.req_to_token[
+                    decode_req.req.req_pool_idx, :seq_len
+                ]
+                state_indices = kv_indices_full.cpu().numpy()
+                state_indices = kv_to_page_indices(state_indices, page_size)
+            else:
+                state_indices = None
 
             decode_req.metadata_buffer_index = (
                 self.req_to_metadata_buffer_idx_allocator.alloc()
             )
             assert decode_req.metadata_buffer_index is not None
-            page_indices = kv_to_page_indices(
-
+            page_indices = kv_to_page_indices(kv_indices, page_size)
+            decode_req.kv_receiver.init(
+                page_indices, decode_req.metadata_buffer_index, state_indices
             )
-            decode_req.
+            decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
             preallocated_reqs.append(decode_req)
             indices_to_remove.add(i)
+            decode_req.req.time_stats.decode_transfer_queue_entry_time = (
+                time.perf_counter()
+            )
+            decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
 
         self.queue = [
             entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
```
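`kv_to_page_indices` itself is not part of this diff; assuming the usual paged-KV convention (one page covers `page_size` contiguous token slots), the SWA branch rounds `window_start` down to a page boundary so the transferred window begins on a whole page. A small numeric sketch under that assumption (the real sglang helper may differ in detail):

```python
import numpy as np

def kv_to_page_indices(kv_indices: np.ndarray, page_size: int) -> np.ndarray:
    # Assumed behavior: one entry per page touched by the (page-aligned,
    # contiguous) token slots.
    if page_size == 1:
        return kv_indices
    return kv_indices[::page_size] // page_size

page_size, window_size = 4, 6
seq_len = 11                                             # tokens 0..10
window_start = max(0, seq_len - window_size)             # 5
window_start = (window_start // page_size) * page_size   # rounded down to 4

# token-slot indices the request occupies in the device pool (toy values)
req_to_token = np.arange(100, 100 + seq_len)
window_kv_indices = req_to_token[window_start:seq_len]   # slots for tokens 4..10
print(kv_to_page_indices(window_kv_indices, page_size))  # -> [26 27]
```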
```diff
@@ -496,7 +601,10 @@ class DecodePreallocQueue:
 
     def _pre_alloc(self, req: Req) -> torch.Tensor:
         """Pre-allocate the memory for req_to_token and token_kv_pool"""
-        req_pool_indices = self.req_to_token_pool.alloc(1)
+        if isinstance(self.req_to_token_pool, HybridMambaDecodeReqToTokenPool):
+            req_pool_indices = self.req_to_token_pool.alloc(1, [req])
+        else:
+            req_pool_indices = self.req_to_token_pool.alloc(1)
 
         assert (
             req_pool_indices is not None
@@ -516,11 +624,19 @@ class DecodePreallocQueue:
                 dtype=torch.int64,
                 device=self.token_to_kv_pool_allocator.device,
             ),
+            prefix_lens_cpu=torch.tensor(
+                [0],
+                dtype=torch.int64,
+            ),
             seq_lens=torch.tensor(
                 [num_tokens],
                 dtype=torch.int64,
                 device=self.token_to_kv_pool_allocator.device,
             ),
+            seq_lens_cpu=torch.tensor(
+                [num_tokens],
+                dtype=torch.int64,
+            ),
             last_loc=torch.tensor(
                 [-1],
                 dtype=torch.int64,
@@ -596,8 +712,8 @@ class DecodeTransferQueue:
                 self.scheduler.stream_output(
                     [decode_req.req], decode_req.req.return_logprob
                 )
-                #
-                self.tree_cache.cache_finished_req(decode_req.req)
+                # release pre-allocated kv cache, but don't insert into the tree since it's failed
+                self.tree_cache.cache_finished_req(decode_req.req, is_insert=False)
                 indices_to_remove.add(i)
                 if self.scheduler.enable_metrics:
                     self.scheduler.metrics_collector.increment_transfer_failed_reqs()
@@ -607,16 +723,23 @@ class DecodeTransferQueue:
                 idx = decode_req.metadata_buffer_index
                 (
                     output_id,
+                    cached_tokens,
                     output_token_logprobs_val,
                     output_token_logprobs_idx,
                     output_top_logprobs_val,
                     output_top_logprobs_idx,
+                    output_topk_p,
+                    output_topk_index,
                     output_hidden_states,
                 ) = self.metadata_buffers.get_buf(idx)
 
                 decode_req.req.output_ids.append(output_id[0].item())
+                decode_req.req.cached_tokens = cached_tokens[0].item()
                 if not self.spec_algorithm.is_none():
+                    decode_req.req.output_topk_p = output_topk_p
+                    decode_req.req.output_topk_index = output_topk_index
                     decode_req.req.hidden_states_tensor = output_hidden_states
+
                 if decode_req.req.return_logprob:
                     decode_req.req.output_token_logprobs_val.append(
                         output_token_logprobs_val[0].item()
@@ -637,10 +760,17 @@ class DecodeTransferQueue:
 
                 if hasattr(decode_req.kv_receiver, "clear"):
                     decode_req.kv_receiver.clear()
+                decode_req.kv_receiver = None
+
+                indices_to_remove.add(i)
+                decode_req.req.time_stats.wait_queue_entry_time = time.perf_counter()
 
                 # special handling for sampling_params.max_new_tokens == 1
                 if decode_req.req.sampling_params.max_new_tokens == 1:
                     # finish immediately
+                    decode_req.req.time_stats.forward_entry_time = (
+                        decode_req.req.time_stats.completion_time
+                    ) = time.perf_counter()
                     decode_req.req.check_finished()
                     self.scheduler.stream_output(
                         [decode_req.req], decode_req.req.return_logprob
@@ -648,8 +778,6 @@ class DecodeTransferQueue:
                     self.tree_cache.cache_finished_req(decode_req.req)
                 else:
                     transferred_reqs.append(decode_req.req)
-
-                indices_to_remove.add(i)
             elif poll in [
                 KVPoll.Bootstrapping,
                 KVPoll.WaitingForInput,
@@ -662,6 +790,7 @@ class DecodeTransferQueue:
         for i in indices_to_remove:
             idx = self.queue[i].metadata_buffer_index
             assert idx != -1
+            self.queue[i].req.add_latency(RequestStage.DECODE_TRANSFERRED)
             self.req_to_metadata_buffer_idx_allocator.free(idx)
 
         self.queue = [
```
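One construct above is easy to misread: `forward_entry_time = (completion_time) = time.perf_counter()` is a chained assignment, so the parentheses are only formatting and both timestamps receive the same single `perf_counter()` sample. A two-line demonstration:

```python
import time

class TimeStats:
    pass

ts = TimeStats()
# Chained assignment: the right-hand side is evaluated once, then bound to
# every target from left to right, so both fields hold the identical value.
ts.forward_entry_time = (ts.completion_time) = time.perf_counter()
assert ts.forward_entry_time == ts.completion_time
```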
```diff
@@ -704,23 +833,27 @@ class SchedulerDisaggregationDecodeMixin:
         elif prepare_mlp_sync_flag:
             batch, _ = self._prepare_idle_batch_and_run(None)
 
-
+        queue_size = (
             len(self.waiting_queue)
             + len(self.disagg_decode_transfer_queue.queue)
             + len(self.disagg_decode_prealloc_queue.queue)
-
-
+        )
+        if self.server_args.disaggregation_decode_enable_offload_kvcache:
+            queue_size += len(self.decode_offload_manager.ongoing_offload)
+
+        if batch is None and queue_size == 0:
             self.self_check_during_idle()
 
         self.last_batch = batch
 
     @torch.no_grad()
     def event_loop_overlap_disagg_decode(self: Scheduler):
-        result_queue = deque()
+        self.result_queue = deque()
         self.last_batch: Optional[ScheduleBatch] = None
         self.last_batch_in_queue = False  # last batch is modified in-place, so we need another variable to track if it's extend
 
         while True:
+
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
             # polling and allocating kv cache
@@ -731,6 +864,7 @@ class SchedulerDisaggregationDecodeMixin:
 
             prepare_mlp_sync_flag = require_mlp_sync(self.server_args)
 
+            batch_result = None
             if batch:
                 # Generate fake extend output.
                 if batch.forward_mode.is_extend():
@@ -739,51 +873,43 @@ class SchedulerDisaggregationDecodeMixin:
                         batch.reqs, any(req.return_logprob for req in batch.reqs)
                     )
                     if prepare_mlp_sync_flag:
-                        batch_,
+                        batch_, batch_result = self._prepare_idle_batch_and_run(
                             None, delay_process=True
                         )
                         if batch_:
-                            result_queue.append((batch_.copy(),
+                            self.result_queue.append((batch_.copy(), batch_result))
                             last_batch_in_queue = True
                 else:
                     if prepare_mlp_sync_flag:
                         self.prepare_mlp_sync_batch(batch)
-
-                    result_queue.append((batch.copy(),
-
-                    if (self.last_batch is None) or (not self.last_batch_in_queue):
-                        # Create a dummy first batch to start the pipeline for overlap schedule.
-                        # It is now used for triggering the sampling_info_done event.
-                        tmp_batch = ScheduleBatch(
-                            reqs=None,
-                            forward_mode=ForwardMode.DUMMY_FIRST,
-                            next_batch_sampling_info=self.tp_worker.cur_sampling_info,
-                        )
-                        self.set_next_batch_sampling_info_done(tmp_batch)
+                    batch_result = self.run_batch(batch)
+                    self.result_queue.append((batch.copy(), batch_result))
                     last_batch_in_queue = True
 
             elif prepare_mlp_sync_flag:
-                batch,
+                batch, batch_result = self._prepare_idle_batch_and_run(
                     None, delay_process=True
                 )
                 if batch:
-                    result_queue.append((batch.copy(),
+                    self.result_queue.append((batch.copy(), batch_result))
                     last_batch_in_queue = True
 
             # Process the results of the previous batch but skip if the last batch is extend
             if self.last_batch and self.last_batch_in_queue:
-                tmp_batch, tmp_result = result_queue.popleft()
-                tmp_batch.next_batch_sampling_info = (
-                    self.tp_worker.cur_sampling_info if batch else None
-                )
+                tmp_batch, tmp_result = self.result_queue.popleft()
                 self.process_batch_result(tmp_batch, tmp_result)
 
-
+            self.launch_batch_sample_if_needed(batch_result)
+
+            queue_size = (
                 len(self.waiting_queue)
                 + len(self.disagg_decode_transfer_queue.queue)
                 + len(self.disagg_decode_prealloc_queue.queue)
-
-
+            )
+            if self.server_args.disaggregation_decode_enable_offload_kvcache:
+                queue_size += len(self.decode_offload_manager.ongoing_offload)
+
+            if batch is None and queue_size == 0:
                 self.self_check_during_idle()
 
             self.last_batch = batch
@@ -853,6 +979,7 @@ class SchedulerDisaggregationDecodeMixin:
             # we can only add at least `num_not_used_batch` new batch to the running queue
             if i < num_not_used_batch:
                 can_run_list.append(req)
+                req.add_latency(RequestStage.DECODE_WAITING)
                 req.init_next_round_input(self.tree_cache)
             else:
                 waiting_queue.append(req)
@@ -861,6 +988,9 @@ class SchedulerDisaggregationDecodeMixin:
         if len(can_run_list) == 0:
             return None
 
+        for req in can_run_list:
+            req.time_stats.forward_entry_time = time.perf_counter()
+
         # construct a schedule batch with those requests and mark as decode
         new_batch = ScheduleBatch.init_new(
             can_run_list,
@@ -901,3 +1031,6 @@ class SchedulerDisaggregationDecodeMixin:
             self.disagg_decode_transfer_queue.pop_transferred()
         )  # the requests which kv has arrived
         self.waiting_queue.extend(alloc_reqs)
+
+        if self.server_args.disaggregation_decode_enable_offload_kvcache:
+            self.decode_offload_manager.check_offload_progress()
```
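The overlap loop keeps its one-step pipeline shape but drops the old `DUMMY_FIRST` bootstrap batch and the `next_batch_sampling_info` plumbing: each iteration launches the current batch, pushes `(batch.copy(), batch_result)` onto `self.result_queue`, and only then consumes the previous iteration's result, so CPU post-processing of step i-1 overlaps with device work of step i. A minimal generic sketch of that schedule (plain functions, not the scheduler API):

```python
from collections import deque

def run_batch(batch):
    # stand-in for the asynchronous device launch; returns a result handle
    return {"launched": batch}

def process_batch_result(batch, result):
    print(f"post-processing {batch} <- {result['launched']}")

result_queue = deque()
last_batch = None
for batch in ["b0", "b1", "b2"]:
    result = run_batch(batch)            # launch step i
    result_queue.append((batch, result))
    if last_batch is not None:           # consume step i-1 while i runs
        process_batch_result(*result_queue.popleft())
    last_batch = batch

# drain the final in-flight batch
process_batch_result(*result_queue.popleft())
```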
**`sglang/srt/disaggregation/decode_kvcache_offload_manager.py`** (new file, 185 lines)

```python
import logging
import threading
import time

import torch

from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.mem_cache.memory_pool_host import (
    MHATokenToKVPoolHost,
    MLATokenToKVPoolHost,
)
from sglang.srt.server_args import ServerArgs

logger = logging.getLogger(__name__)


class DecodeKVCacheOffloadManager:
    """Manage decode-side KV cache offloading lifecycle and operations."""

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        tp_group: torch.distributed.ProcessGroup,
        tree_cache: BasePrefixCache,
        server_args: ServerArgs,
    ) -> None:
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.page_size = server_args.page_size
        self.server_args = server_args
        self.request_counter = 0
        self.tree_cache = tree_cache
        kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
        if isinstance(kv_cache, MHATokenToKVPool):
            self.decode_host_mem_pool = MHATokenToKVPoolHost(
                kv_cache,
                server_args.hicache_ratio,
                server_args.hicache_size,
                self.page_size,
                server_args.hicache_mem_layout,
            )
        elif isinstance(kv_cache, MLATokenToKVPool):
            self.decode_host_mem_pool = MLATokenToKVPoolHost(
                kv_cache,
                server_args.hicache_ratio,
                server_args.hicache_size,
                self.page_size,
                server_args.hicache_mem_layout,
            )
        else:
            raise ValueError("Unsupported KV cache type for decode offload")

        self.tp_group = tp_group
        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
        self.cache_controller = HiCacheController(
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            mem_pool_host=self.decode_host_mem_pool,
            page_size=self.page_size,
            tp_group=tp_group,
            io_backend=server_args.hicache_io_backend,
            load_cache_event=threading.Event(),
            storage_backend=server_args.hicache_storage_backend,
            model_name=server_args.served_model_name,
            storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
        )

        self.ongoing_offload = {}
        self.ongoing_backup = {}
        logger.info("Enable offload kv cache for decode side")

    def offload_kv_cache(self, req) -> bool:
        """Offload a finished request's KV cache to storage."""

        if self.cache_controller is None or self.decode_host_mem_pool is None:
            return False

        if req.req_pool_idx == -1:
            return False

        token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]
        if token_indices.dim() == 0 or token_indices.numel() == 0:
            logger.debug(
                f"Request {req.rid} has invalid token_indices: {token_indices}"
            )
            return False

        tokens = req.origin_input_ids + req.output_ids
        aligned_len = (len(tokens) // self.page_size) * self.page_size
        if aligned_len == 0:
            return False

        token_indices = token_indices[:aligned_len]
        tokens = tokens[:aligned_len]

        # Asynchronously offload KV cache from device to host by cache controller
        self.request_counter += 1
        ack_id = self.request_counter
        host_indices = self.cache_controller.write(
            device_indices=token_indices.long(),
            node_id=ack_id,
        )
        if host_indices is None:
            logger.error(f"Not enough host memory for request {req.rid}")
            return False

        self.ongoing_offload[ack_id] = (req, host_indices, tokens, time.time())
        return True

    def check_offload_progress(self):
        """Check the progress of offload from device to host and backup from host to storage."""
        cc = self.cache_controller

        qsizes = torch.tensor(
            [
                len(cc.ack_write_queue),
                cc.ack_backup_queue.qsize(),
            ],
            dtype=torch.int,
        )
        if self.tp_world_size > 1:
            torch.distributed.all_reduce(
                qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
            )

        n_write, n_backup = map(int, qsizes.tolist())
        self._check_offload_progress(n_write)
        self._check_backup_progress(n_backup)

    def _check_offload_progress(self, finish_count):
        """Check the progress of offload from device to host."""
        while finish_count > 0:
            _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
            finish_event.synchronize()
            for ack_id in ack_list:
                req, host_indices, tokens, start_time = self.ongoing_offload.pop(ack_id)

                # Release device
                self.tree_cache.cache_finished_req(req)

                # Trigger async backup from host to storage by cache controller
                self._trigger_backup(req.rid, host_indices, tokens, start_time)
            finish_count -= 1

    def _check_backup_progress(self, finish_count):
        """Check the progress of backup from host to storage."""
        for _ in range(finish_count):
            storage_operation = self.cache_controller.ack_backup_queue.get()
            ack_id = storage_operation.id
            req_id, host_indices, start_time = self.ongoing_backup.pop(ack_id)

            # Release host memory
            self.decode_host_mem_pool.free(host_indices)

            logger.debug(
                f"Finished backup request {req_id}, free host memory, len:{len(host_indices)}, cost time:{time.time() - start_time:.2f} seconds."
            )

    def _trigger_backup(self, req_id, host_indices, tokens, start_time):
        """Trigger async backup from host to storage by cache controller."""

        # Generate page hashes and write to storage
        page_hashes = self._compute_prefix_hash(tokens)
        ack_id = self.cache_controller.write_storage(
            host_indices,
            tokens,
            hash_value=page_hashes,
        )
        self.ongoing_backup[ack_id] = (req_id, host_indices, start_time)

    def _compute_prefix_hash(self, tokens):
        last_hash = ""
        page_hashes = []
        for offset in range(0, len(tokens), self.page_size):
            page_tokens = tokens[offset : offset + self.page_size]
            last_hash = self.cache_controller.get_hash_str(page_tokens, last_hash)
            page_hashes.append(last_hash)
        return page_hashes
```
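`_compute_prefix_hash` chains each page hash with its predecessor, so a page's storage key commits to the entire prefix, not just the page's own tokens. `get_hash_str` lives on the cache controller and is not shown in this diff; a sketch with `hashlib` standing in for it:

```python
import hashlib

PAGE_SIZE = 4

def get_hash_str(page_tokens, prior_hash: str) -> str:
    # Stand-in for HiCacheController.get_hash_str: hash the prior page's
    # hash together with this page's token ids.
    h = hashlib.sha256()
    h.update(prior_hash.encode())
    h.update(",".join(map(str, page_tokens)).encode())
    return h.hexdigest()

def compute_prefix_hash(tokens):
    last_hash, page_hashes = "", []
    for offset in range(0, len(tokens), PAGE_SIZE):
        last_hash = get_hash_str(tokens[offset : offset + PAGE_SIZE], last_hash)
        page_hashes.append(last_hash)
    return page_hashes

a = compute_prefix_hash([1, 2, 3, 4, 5, 6, 7, 8])
b = compute_prefix_hash([9, 9, 9, 9, 5, 6, 7, 8])
# Same second page, different prefix -> different chained hash:
print(a[1] != b[1])  # True
```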
**`sglang/srt/disaggregation/decode_schedule_batch_mixin.py`**

```diff
@@ -76,6 +76,7 @@ class ScheduleBatchDisaggregationDecodeMixin:
             req_pool_indices, dtype=torch.int64, device=self.device
         )
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
+        self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
         self.orig_seq_lens = torch.tensor(
             seq_lens, dtype=torch.int32, device=self.device
         )
@@ -125,31 +126,39 @@ class ScheduleBatchDisaggregationDecodeMixin:
                 req.grammar.finished = req.finished()
         self.output_ids = torch.tensor(self.output_ids, device=self.device)
 
-        # Simulate the eagle run.
-
-        # of 0.
-        if not self.spec_algorithm.is_none():
+        # Simulate the eagle run.
+        if self.spec_algorithm.is_eagle():
 
             b = len(self.reqs)
-
-
-
-
-
-
+            topk = server_args.speculative_eagle_topk
+            topk_p = torch.stack(
+                [
+                    torch.as_tensor(
+                        req.output_topk_p[:topk],
+                        device=self.device,
+                        dtype=torch.float32,
+                    )
+                    for req in self.reqs
+                ],
+                dim=0,
             )
-
-
-
-
+            topk_index = torch.stack(
+                [
+                    torch.as_tensor(
+                        req.output_topk_index[:topk],
+                        device=self.device,
+                        dtype=torch.int64,
+                    )
+                    for req in self.reqs
+                ],
+                dim=0,
             )
-            topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
 
             hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
             hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
 
             # local import to avoid circular import
-            from sglang.srt.speculative.
+            from sglang.srt.speculative.eagle_info import EagleDraftInput
 
             spec_info = EagleDraftInput(
                 topk_p=topk_p,
```
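The rewritten eagle simulation replaces the old flatten-and-`reshape` construction with a direct `torch.stack` over per-request rows, truncating each request's stored top-k to `speculative_eagle_topk` before batching. In isolation, with toy tensors in place of `req.output_topk_p`:

```python
import torch

topk = 2
# toy per-request draft distributions, possibly stored wider than `topk`
output_topk_p = [torch.tensor([0.5, 0.3, 0.2]), torch.tensor([0.9, 0.05, 0.05])]

topk_p = torch.stack(
    [torch.as_tensor(p[:topk], dtype=torch.float32) for p in output_topk_p],
    dim=0,
)
print(topk_p.shape)  # torch.Size([2, 2]) -> (batch, topk), no reshape needed
```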
**`sglang/srt/disaggregation/fake/conn.py`**

```diff
@@ -48,9 +48,12 @@ class FakeKVSender(BaseKVSender):
     def send(
         self,
         kv_indices: npt.NDArray[np.int32],
+        state_indices: Optional[List[int]] = None,
     ):
         self.has_sent = True
-        logger.debug(
+        logger.debug(
+            f"FakeKVSender send with kv_indices: {kv_indices}, state_indices: {state_indices}"
+        )
 
     def failure_exception(self):
         raise Exception("Fake KVSender Exception")
@@ -75,10 +78,15 @@ class FakeKVReceiver(BaseKVReceiver):
         logger.debug("FakeKVReceiver poll success")
         return KVPoll.Success
 
-    def init(
+    def init(
+        self,
+        kv_indices: list[int],
+        aux_index: Optional[int] = None,
+        state_indices: Optional[List[int]] = None,
+    ):
         self.has_init = True
         logger.debug(
-            f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
+            f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}, state_indices: {state_indices}"
         )
 
     def failure_exception(self):
```