sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +54 -37
- sglang/bench_one_batch_server.py +340 -34
- sglang/bench_serving.py +340 -159
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +9 -2
- sglang/profiler.py +20 -3
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +309 -0
- sglang/srt/configs/load_config.py +33 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +284 -118
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +576 -0
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +6 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +26 -15
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +268 -98
- sglang/srt/disaggregation/decode.py +172 -39
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +203 -555
- sglang/srt/disaggregation/nixl/conn.py +217 -63
- sglang/srt/disaggregation/prefill.py +113 -270
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +203 -97
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +85 -65
- sglang/srt/entrypoints/grpc_server.py +632 -305
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +169 -17
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +327 -34
- sglang/srt/entrypoints/openai/serving_base.py +74 -8
- sglang/srt/entrypoints/openai/serving_chat.py +202 -118
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +20 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +47 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +323 -0
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +21 -16
- sglang/srt/function_call/glm4_moe_detector.py +4 -8
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +61 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +98 -7
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/grpc_request_manager.py +915 -0
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
- sglang/srt/layers/activation.py +11 -7
- sglang/srt/layers/attention/aiter_backend.py +17 -18
- sglang/srt/layers/attention/ascend_backend.py +125 -10
- sglang/srt/layers/attention/attention_registry.py +226 -0
- sglang/srt/layers/attention/base_attn_backend.py +32 -4
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +52 -15
- sglang/srt/layers/attention/flashinfer_backend.py +357 -212
- sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
- sglang/srt/layers/attention/flashmla_backend.py +9 -7
- sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
- sglang/srt/layers/attention/mamba/mamba.py +514 -1
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +23 -0
- sglang/srt/layers/attention/nsa_backend.py +1201 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +249 -42
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
- sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +61 -3
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +19 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +28 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +47 -15
- sglang/srt/layers/linear.py +30 -5
- sglang/srt/layers/logits_processor.py +161 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
- sglang/srt/layers/moe/ep_moe/layer.py +243 -448
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +27 -1
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +86 -20
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +43 -15
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +141 -81
- sglang/srt/layers/quantization/mxfp4.py +17 -34
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -24
- sglang/srt/layers/quantization/w8a8_int8.py +45 -27
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +750 -46
- sglang/srt/layers/sampler.py +84 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +23 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +9 -4
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +33 -7
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +41 -17
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +83 -152
- sglang/srt/managers/data_parallel_controller.py +156 -87
- sglang/srt/managers/detokenizer_manager.py +51 -24
- sglang/srt/managers/io_struct.py +223 -129
- sglang/srt/managers/mm_utils.py +49 -10
- sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +130 -0
- sglang/srt/managers/schedule_batch.py +340 -529
- sglang/srt/managers/schedule_policy.py +158 -18
- sglang/srt/managers/scheduler.py +665 -620
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
- sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
- sglang/srt/managers/tokenizer_manager.py +462 -226
- sglang/srt/managers/tp_worker.py +217 -156
- sglang/srt/managers/utils.py +79 -47
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +42 -28
- sglang/srt/mem_cache/base_prefix_cache.py +3 -3
- sglang/srt/mem_cache/chunk_cache.py +20 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +38 -0
- sglang/srt/mem_cache/hicache_storage.py +44 -2
- sglang/srt/mem_cache/hiradix_cache.py +134 -34
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +602 -208
- sglang/srt/mem_cache/memory_pool_host.py +134 -183
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +263 -78
- sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +115 -58
- sglang/srt/metrics/collector.py +113 -120
- sglang/srt/metrics/func_timer.py +3 -8
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +81 -36
- sglang/srt/model_executor/forward_batch_info.py +40 -50
- sglang/srt/model_executor/model_runner.py +507 -319
- sglang/srt/model_executor/npu_graph_runner.py +11 -5
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +438 -37
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +200 -27
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +40 -56
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +25 -4
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +793 -235
- sglang/srt/models/dots_ocr.py +171 -0
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +570 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -3
- sglang/srt/models/glm4_moe.py +17 -40
- sglang/srt/models/glm4_moe_nextn.py +4 -4
- sglang/srt/models/glm4v.py +3 -2
- sglang/srt/models/glm4v_moe.py +6 -6
- sglang/srt/models/gpt_oss.py +12 -35
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +4 -2
- sglang/srt/models/llama.py +6 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +6 -23
- sglang/srt/models/longcat_flash_nextn.py +4 -15
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +27 -6
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +5 -5
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +70 -4
- sglang/srt/models/qwen2_vl.py +6 -3
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +50 -38
- sglang/srt/models/qwen3_next.py +43 -21
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +791 -0
- sglang/srt/models/qwen3_vl_moe.py +343 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +268 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +61 -0
- sglang/srt/multimodal/processors/base_processor.py +21 -9
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +2 -4
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +20 -10
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +83 -17
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +36 -23
- sglang/srt/sampling/sampling_params.py +75 -0
- sglang/srt/server_args.py +1300 -338
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +161 -0
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
- sglang/srt/speculative/eagle_info.py +786 -0
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +113 -1270
- sglang/srt/speculative/eagle_worker.py +120 -285
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/ngram_info.py +433 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +49 -0
- sglang/srt/speculative/spec_utils.py +641 -0
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +35 -18
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/{utils.py → utils/common.py} +583 -113
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +120 -11
- sglang/test/runners.py +3 -1
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +8 -2
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +3 -4
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +430 -0
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +93 -1
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +432 -16
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
- sglang/srt/entrypoints/grpc_request_manager.py +0 -580
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/prefill.py

@@ -20,7 +20,7 @@ Life cycle of a request in the prefill server
 from __future__ import annotations
 
 import logging
-import
+import time
 from collections import deque
 from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional, Type
@@ -42,14 +42,18 @@ from sglang.srt.disaggregation.utils import (
     poll_and_all_reduce,
     prepare_abort,
 )
-from sglang.srt.managers.schedule_batch import
-
-
-
-
-    point_to_point_pyobj,
-    require_mlp_sync,
+from sglang.srt.managers.schedule_batch import (
+    FINISH_LENGTH,
+    Req,
+    RequestStage,
+    ScheduleBatch,
 )
+from sglang.srt.mem_cache.memory_pool import (
+    HybridLinearKVPool,
+    NSATokenToKVPool,
+    SWAKVPool,
+)
+from sglang.srt.utils import broadcast_pyobj, point_to_point_pyobj, require_mlp_sync
 
 if TYPE_CHECKING:
     from torch.distributed import ProcessGroup
@@ -140,6 +144,28 @@ class PrefillBootstrapQueue:
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
 
+        if hasattr(self.token_to_kv_pool, "get_state_buf_infos"):
+            state_data_ptrs, state_data_lens, state_item_lens = (
+                self.token_to_kv_pool.get_state_buf_infos()
+            )
+            kv_args.state_data_ptrs = state_data_ptrs
+            kv_args.state_data_lens = state_data_lens
+            kv_args.state_item_lens = state_item_lens
+
+            if isinstance(self.token_to_kv_pool, SWAKVPool):
+                kv_args.state_type = "swa"
+            elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):
+                kv_args.state_type = "mamba"
+            elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
+                kv_args.state_type = "nsa"
+            else:
+                kv_args.state_type = "none"
+        else:
+            kv_args.state_data_ptrs = []
+            kv_args.state_data_lens = []
+            kv_args.state_item_lens = []
+            kv_args.state_type = "none"
+
         kv_manager_class: Type[BaseKVManager] = get_kv_class(
             self.transfer_backend, KVClassType.MANAGER
         )
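This branch duck-types on `get_state_buf_infos()`: any pool carrying extra per-request state (SWA window, Mamba state, NSA index cache) advertises three parallel lists (base pointers, total byte lengths, per-slot item lengths) that the KV manager can register for transfer. A minimal sketch of that contract, with illustrative buffer names rather than the actual pool internals:

```python
import torch

class StatefulPoolSketch:
    """Illustrative pool with one auxiliary state buffer per request slot."""

    def __init__(self, num_slots: int, state_dim: int):
        # Stand-in for e.g. a Mamba state or SWA side buffer.
        self.state_buf = torch.zeros((num_slots, state_dim), dtype=torch.float16)

    def get_state_buf_infos(self):
        # Parallel lists: one entry per auxiliary buffer.
        ptrs = [self.state_buf.data_ptr()]      # where the buffer starts
        data_lens = [self.state_buf.nbytes]     # total bytes to register
        item_lens = [self.state_buf[0].nbytes]  # bytes per request slot
        return ptrs, data_lens, item_lens
```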
@@ -170,6 +196,7 @@ class PrefillBootstrapQueue:
             pp_rank=self.pp_rank,
         )
         self._process_req(req)
+        req.add_latency(RequestStage.PREFILL_PREPARE)
         self.queue.append(req)
 
     def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
@@ -256,8 +283,11 @@ class PrefillBootstrapQueue:
 
             num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
             req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
+
             bootstrapped_reqs.append(req)
             indices_to_remove.add(i)
+            req.time_stats.wait_queue_entry_time = time.perf_counter()
+            req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
 
         self.queue = [
             entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
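The `add_latency(RequestStage...)` calls added here and throughout this file mark stage boundaries on the request. One plausible shape for this kind of accounting (a hypothetical sketch, not sglang's actual `Req`/`time_stats` internals) accumulates the time since the previous boundary under the given stage:

```python
import time
from collections import defaultdict

class StageLatencySketch:
    """Hypothetical per-request latency accumulator keyed by stage name."""

    def __init__(self):
        self._last_ts = time.perf_counter()
        self.per_stage = defaultdict(float)

    def add_latency(self, stage: str) -> None:
        # Attribute the time since the previous boundary to `stage`.
        now = time.perf_counter()
        self.per_stage[stage] += now - self._last_ts
        self._last_ts = now

req_stats = StageLatencySketch()
# ... bootstrap work happens here ...
req_stats.add_latency("PREFILL_BOOTSTRAP")
```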
@@ -322,30 +352,21 @@ class SchedulerDisaggregationPrefillMixin:
         if require_mlp_sync(self.server_args):
             batch = self.prepare_mlp_sync_batch(batch)
         self.cur_batch = batch
+
+        batch_result = None
         if batch:
-
-            self.result_queue.append((batch.copy(),
-
-            if self.last_batch is None:
-                # Create a dummy first batch to start the pipeline for overlap schedule.
-                # It is now used for triggering the sampling_info_done event.
-                tmp_batch = ScheduleBatch(
-                    reqs=None,
-                    forward_mode=ForwardMode.DUMMY_FIRST,
-                    next_batch_sampling_info=self.tp_worker.cur_sampling_info,
-                )
-                self.set_next_batch_sampling_info_done(tmp_batch)
+            batch_result = self.run_batch(batch)
+            self.result_queue.append((batch.copy(), batch_result))
 
         if self.last_batch:
             tmp_batch, tmp_result = self.result_queue.popleft()
-            tmp_batch.next_batch_sampling_info = (
-                self.tp_worker.cur_sampling_info if batch else None
-            )
             self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
 
         if len(self.disagg_prefill_inflight_queue) > 0:
             self.process_disagg_prefill_inflight_queue()
 
+        self.launch_batch_sample_if_needed(batch_result)
+
         if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
             self.self_check_during_idle()
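The restructured loop drops the dummy-first-batch trick and the explicit sampling-info synchronization: each iteration launches `run_batch` for the current batch immediately, parks `(batch.copy(), result)` in `result_queue`, and only then finishes CPU-side processing of the previous iteration's result, so the forward pass of batch i overlaps with the post-processing of batch i-1. A minimal sketch of this one-step-delayed pipeline (hypothetical helper names, not the scheduler's actual API):

```python
from collections import deque

def overlap_loop_sketch(batches, run_batch, process_result, launch_sample):
    """Hypothetical skeleton of the one-step-delayed result pipeline."""
    result_queue = deque()
    last_batch = None
    for batch in batches:  # each element may be None (idle iteration)
        batch_result = None
        if batch is not None:
            # Kick off the forward pass for this iteration's batch...
            batch_result = run_batch(batch)
            result_queue.append((batch, batch_result))
        if last_batch is not None:
            # ...while the CPU finishes post-processing the previous one.
            prev_batch, prev_result = result_queue.popleft()
            process_result(prev_batch, prev_result)
        launch_sample(batch_result)  # may receive None on idle iterations
        last_batch = batch
```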
@@ -358,7 +379,6 @@ class SchedulerDisaggregationPrefillMixin:
         self: Scheduler,
         batch: ScheduleBatch,
         result: GenerationBatchResult,
-        launch_done: Optional[threading.Event] = None,
     ) -> None:
         """
         Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
@@ -369,53 +389,47 @@ class SchedulerDisaggregationPrefillMixin:
             next_token_ids,
             extend_input_len_per_req,
             extend_logprob_start_len_per_req,
+            copy_done,
         ) = (
             result.logits_output,
             result.next_token_ids,
             result.extend_input_len_per_req,
             result.extend_logprob_start_len_per_req,
+            result.copy_done,
         )
 
+        if copy_done is not None:
+            copy_done.synchronize()
+
         logprob_pt = 0
         # Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
-
-
-            logits_output
-
-
-
-
-
-
-
-                logits_output.next_token_logprobs.tolist()
-            )
-            if logits_output.input_token_logprobs is not None:
-                logits_output.input_token_logprobs = tuple(
-                    logits_output.input_token_logprobs.tolist()
-                )
+        next_token_ids = result.next_token_ids.tolist()
+        if batch.return_logprob:
+            if logits_output.next_token_logprobs is not None:
+                logits_output.next_token_logprobs = (
+                    logits_output.next_token_logprobs.tolist()
+                )
+            if logits_output.input_token_logprobs is not None:
+                logits_output.input_token_logprobs = tuple(
+                    logits_output.input_token_logprobs.tolist()
+                )
 
         hidden_state_offset = 0
         for i, (req, next_token_id) in enumerate(
             zip(batch.reqs, next_token_ids, strict=True)
         ):
-            req: Req
             if req.is_chunked <= 0:
                 # There is no output_ids for prefill
                 req.output_ids.append(next_token_id)
                 self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
+                req.add_latency(RequestStage.PREFILL_FORWARD)
                 self.disagg_prefill_inflight_queue.append(req)
-                if (
-
-
-                ):
-                    last_hidden_index = (
-                        hidden_state_offset + extend_input_len_per_req[i] - 1
-                    )
+                if self.spec_algorithm.is_eagle() and batch.spec_info is not None:
+                    req.output_topk_p = batch.spec_info.topk_p[i]
+                    req.output_topk_index = batch.spec_info.topk_index[i]
                     req.hidden_states_tensor = (
-
+                        batch.spec_info.hidden_states[i].cpu().clone()
                     )
-                    hidden_state_offset += extend_input_len_per_req[i]
                 else:
                     req.hidden_states_tensor = None
                 if req.return_logprob:
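The new `copy_done` field is the usual CUDA-event handshake: the worker enqueues an asynchronous device-to-host copy of the batch outputs, records an event, and the consumer blocks on that event only at the point where it actually needs the host values (here, right before `.tolist()`). A self-contained sketch of the pattern, assuming a CUDA device (a hypothetical function, not the actual `GenerationBatchResult` plumbing):

```python
import torch

def forward_with_async_copy(logits_gpu: torch.Tensor):
    """Return host token ids plus an event signaling the D2H copy finished."""
    next_token_ids = logits_gpu.argmax(dim=-1)
    # Pinned host memory is required for the copy to be truly asynchronous.
    host_ids = torch.empty(
        next_token_ids.shape, dtype=next_token_ids.dtype, pin_memory=True
    )
    host_ids.copy_(next_token_ids, non_blocking=True)
    copy_done = torch.cuda.Event()
    copy_done.record()  # recorded on the current stream, after the copy
    return host_ids, copy_done

if torch.cuda.is_available():
    logits = torch.randn(4, 32000, device="cuda")
    ids, done = forward_with_async_copy(logits)
    done.synchronize()  # only now is `ids` safe to read on the CPU
    print(ids.tolist())
```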
@@ -434,6 +448,7 @@ class SchedulerDisaggregationPrefillMixin:
                 )
                 logprob_pt += num_input_logprobs
             self.send_kv_chunk(req, last_chunk=True)
+            req.time_stats.prefill_transfer_queue_entry_time = time.perf_counter()
 
             if req.grammar is not None:
                 # FIXME: this try-except block is for handling unexpected xgrammar issue.
@@ -473,8 +488,6 @@ class SchedulerDisaggregationPrefillMixin:
         if self.enable_overlap:
             self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
 
-        # We need to remove the sync in the following function for overlap schedule.
-        self.set_next_batch_sampling_info_done(batch)
         self.maybe_send_health_check_signal()
 
     def process_disagg_prefill_inflight_queue(
@@ -531,6 +544,9 @@ class SchedulerDisaggregationPrefillMixin:
         else:
             assert False, f"Unexpected polling state {poll=}"
 
+        for req in done_reqs:
+            req.time_stats.completion_time = time.perf_counter()
+
         # Stream requests which have finished transfer
         self.stream_output(
             done_reqs,
@@ -539,6 +555,7 @@ class SchedulerDisaggregationPrefillMixin:
         )
         for req in done_reqs:
             req: Req
+            req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE)
             self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
             req.metadata_buffer_index = -1
@@ -609,232 +626,58 @@ class SchedulerDisaggregationPrefillMixin:
             .numpy()
         )
         req.start_send_idx = end_idx
+        state_indices = None
         if last_chunk:
             self.disagg_metadata_buffers.set_buf(req)
+
+            # Prepare extra pool indices for hybrid models
+            if isinstance(
+                self.token_to_kv_pool_allocator.get_kvcache(), HybridLinearKVPool
+            ):
+                # Mamba hybrid model: send single mamba state index
+                state_indices = [
+                    self.req_to_token_pool.req_index_to_mamba_index_mapping[
+                        req.req_pool_idx
+                    ]
+                    .cpu()
+                    .numpy()
+                ]
+            elif isinstance(self.token_to_kv_pool_allocator.get_kvcache(), SWAKVPool):
+                # SWA hybrid model: send last window KV indices
+                seq_len = len(req.fill_ids)
+                window_size = self.sliding_window_size
+                window_start = max(0, seq_len - window_size)
+                window_start = (window_start // page_size) * page_size
+
+                window_kv_indices_full = self.req_to_token_pool.req_to_token[
+                    req.req_pool_idx, window_start:seq_len
+                ]
+
+                # Translate to SWA pool indices
+                window_kv_indices_swa = (
+                    self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                        window_kv_indices_full
+                    )
+                )
+                state_indices = window_kv_indices_swa.cpu().numpy()
+                state_indices = kv_to_page_indices(state_indices, page_size)
+            elif isinstance(
+                self.token_to_kv_pool_allocator.get_kvcache(), NSATokenToKVPool
+            ):
+                seq_len = len(req.fill_ids)
+                kv_indices_full = self.req_to_token_pool.req_to_token[
+                    req.req_pool_idx, :seq_len
+                ]
+                state_indices = kv_indices_full.cpu().numpy()
+                state_indices = kv_to_page_indices(state_indices, page_size)
+
         page_indices = kv_to_page_indices(kv_indices, page_size)
         if len(page_indices) == 0:
             logger.info(
                 f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
             )
             return
-        req.disagg_kv_sender.send(page_indices)
-
-    # PP
-    @DynamicGradMode()
-    def event_loop_pp_disagg_prefill(self: Scheduler):
-        """
-        An event loop for the prefill server in pipeline parallelism.
-
-        Rules:
-        1. Each stage runs in the same order and is notified by the previous stage.
-        2. Each send/recv operation is blocking and matched by the neighboring stage.
-
-        Regular Schedule:
-        ====================================================================
-        Stage i                   | Stage i+1
-        send ith req              | recv ith req
-        send ith proxy            | recv ith proxy
-        send prev (i+1)th carry   | recv prev (i+1)th carry
-        ====================================================================
-
-        Prefill Server Schedule:
-        ====================================================================
-        Stage i                        | Stage i+1
-        send ith req                   | recv ith req
-        send ith bootstrap req         | recv ith bootstrap req
-        send ith transferred req       | recv ith transferred req
-        send ith proxy                 | recv ith proxy
-        send prev (i+1)th carry        | recv prev (i+1)th carry
-        send prev (i+1)th release req  | recv prev (i+1)th release req
-        ====================================================================
-
-        There are two additional elements compared to the regular schedule:
-
-        1. Bootstrap Requests:
-            a. Instead of polling the status on the current workers, we should wait for the previous stage to notify to avoid desynchronization.
-            b. The first stage polls the status and propagates the bootstrapped requests down to all other stages.
-            c. If the first stage polls successfully, by nature, other ranks are also successful because they performed a handshake together.
-
-        2. Transferred Requests + Release Requests:
-            a. The first stage polls the transfer finished requests, performs an intersection with the next stage's finished requests, and propagates down to the last stage.
-            b. The last stage receives the requests that have finished transfer on all stages (consensus), then sends them to the first stage to release the memory.
-            c. The first stage receives the release requests, releases the memory, and then propagates the release requests down to the last stage.
-        """
-        from sglang.srt.managers.scheduler import GenerationBatchResult
-
-        mbs = [None] * self.pp_size
-        last_mbs = [None] * self.pp_size
-        self.running_mbs = [
-            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
-        ]
-        bids = [None] * self.pp_size
-        pp_outputs: Optional[PPProxyTensors] = None
-
-        # Either success or failed
-        bootstrapped_rids: List[str] = []
-        transferred_rids: List[str] = []
-        release_rids: Optional[List[str]] = None
-
-        # transferred microbatch
-        tmbs = [None] * self.pp_size
-
-        ENABLE_RELEASE = True  # For debug
-
-        while True:
-            server_is_idle = True
-
-            for mb_id in range(self.pp_size):
-                self.running_batch = self.running_mbs[mb_id]
-                self.last_batch = last_mbs[mb_id]
-
-                recv_reqs = self.recv_requests()
-
-                self.process_input_requests(recv_reqs)
-
-                if self.pp_group.is_first_rank:
-                    # First rank, pop the bootstrap reqs from the bootstrap queue
-                    bootstrapped_reqs, failed_reqs = (
-                        self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
-                            return_failed_reqs=True
-                        )
-                    )
-                    bootstrapped_rids = [req.rid for req in bootstrapped_reqs] + [
-                        req.rid for req in failed_reqs
-                    ]
-                    self.waiting_queue.extend(bootstrapped_reqs)
-                else:
-                    # Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus
-                    bootstrapped_rids = self.recv_pyobj_from_prev_stage()
-                    bootstrapped_reqs = (
-                        self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
-                            rids_to_check=bootstrapped_rids
-                        )
-                    )
-                    self.waiting_queue.extend(bootstrapped_reqs)
-
-                if self.pp_group.is_first_rank:
-                    transferred_rids = self.get_transferred_rids()
-                # if other ranks,
-                else:
-                    # 1. recv previous stage's transferred reqs info
-                    prev_transferred_rids = self.recv_pyobj_from_prev_stage()
-                    # 2. get the current stage's transferred reqs info
-                    curr_transferred_rids = self.get_transferred_rids()
-                    # 3. new consensus rids = intersection(previous consensus rids, transfer finished rids)
-                    transferred_rids = list(
-                        set(prev_transferred_rids) & set(curr_transferred_rids)
-                    )
-
-                tmbs[mb_id] = transferred_rids
-
-                self.process_prefill_chunk()
-                mbs[mb_id] = self.get_new_batch_prefill()
-                self.running_mbs[mb_id] = self.running_batch
-
-                self.cur_batch = mbs[mb_id]
-                if self.cur_batch:
-                    server_is_idle = False
-                    result = self.run_batch(self.cur_batch)
-
-                # send the outputs to the next step
-                if self.pp_group.is_last_rank:
-                    if self.cur_batch:
-                        next_token_ids, bids[mb_id] = (
-                            result.next_token_ids,
-                            result.bid,
-                        )
-                        pp_outputs = PPProxyTensors(
-                            {
-                                "next_token_ids": next_token_ids,
-                            }
-                        )
-                        # send the output from the last round to let the next stage worker run post processing
-                        self.pp_group.send_tensor_dict(
-                            pp_outputs.tensors,
-                            all_gather_group=self.attn_tp_group,
-                        )
-
-                if ENABLE_RELEASE:
-                    if self.pp_group.is_last_rank:
-                        # At the last stage, all stages has reached the consensus to release memory for transferred_rids
-                        release_rids = transferred_rids
-                        # send to the first rank
-                        self.send_pyobj_to_next_stage(release_rids)
-
-                # receive outputs and post-process (filter finished reqs) the coming microbatch
-                next_mb_id = (mb_id + 1) % self.pp_size
-                next_pp_outputs = None
-                next_release_rids = None
-
-                if mbs[next_mb_id] is not None:
-                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
-                        self.pp_group.recv_tensor_dict(
-                            all_gather_group=self.attn_tp_group
-                        )
-                    )
-                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
-                    output_result = GenerationBatchResult(
-                        logits_output=None,
-                        pp_hidden_states_proxy_tensors=None,
-                        next_token_ids=next_pp_outputs["next_token_ids"],
-                        extend_input_len_per_req=None,
-                        extend_logprob_start_len_per_req=None,
-                        bid=bids[next_mb_id],
-                        can_run_cuda_graph=result.can_run_cuda_graph,
-                    )
-                    self.process_batch_result_disagg_prefill(
-                        mbs[next_mb_id], output_result
-                    )
-
-                last_mbs[next_mb_id] = mbs[next_mb_id]
-
-                if ENABLE_RELEASE:
-                    if tmbs[next_mb_id] is not None:
-                        # recv consensus rids from the previous rank
-                        next_release_rids = self.recv_pyobj_from_prev_stage()
-                        self.process_disagg_prefill_inflight_queue(next_release_rids)
-
-                # carry the outputs to the next stage
-                if not self.pp_group.is_last_rank:
-                    if self.cur_batch:
-                        bids[mb_id] = result.bid
-                    if pp_outputs:
-                        # send the outputs from the last round to let the next stage worker run post processing
-                        self.pp_group.send_tensor_dict(
-                            pp_outputs.tensors,
-                            all_gather_group=self.attn_tp_group,
-                        )
-                    if ENABLE_RELEASE:
-                        if release_rids is not None:
-                            self.send_pyobj_to_next_stage(release_rids)
-
-                if not self.pp_group.is_last_rank:
-                    # send out reqs to the next stage
-                    self.send_pyobj_to_next_stage(recv_reqs)
-                    self.send_pyobj_to_next_stage(bootstrapped_rids)
-                    self.send_pyobj_to_next_stage(transferred_rids)
-
-                    # send out proxy tensors to the next stage
-                    if self.cur_batch:
-                        self.pp_group.send_tensor_dict(
-                            result.pp_hidden_states_proxy_tensors,
-                            all_gather_group=self.attn_tp_group,
-                        )
-
-                pp_outputs = next_pp_outputs
-                release_rids = next_release_rids
-
-                self.running_batch.batch_is_full = False
-
-            if not ENABLE_RELEASE:
-                if len(self.disagg_prefill_inflight_queue) > 0:
-                    self.process_disagg_prefill_inflight_queue()
-
-            # When the server is idle, self-check and re-init some states
-            if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0:
-                self.check_memory()
-                self.check_tree_cache()
-                self.new_token_ratio = self.init_new_token_ratio
+        req.disagg_kv_sender.send(page_indices, state_indices)
 
     def send_pyobj_to_next_stage(self, data):
         if self.attn_tp_rank == 0:
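In the SWA branch above, the window start is rounded down to a page boundary before the indices are translated, so a page that only partially overlaps the attention window is still shipped whole. With illustrative numbers:

```python
seq_len, window_size, page_size = 1000, 256, 64

window_start = max(0, seq_len - window_size)             # 744
window_start = (window_start // page_size) * page_size   # 704: aligned down to a page
num_window_tokens = seq_len - window_start               # 296 tokens (>= window_size)
print(window_start, num_window_tokens)                   # 704 296
```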
sglang/srt/disaggregation/utils.py

@@ -5,7 +5,7 @@ import random
 from collections import deque
 from contextlib import nullcontext
 from enum import Enum
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, Optional, Type
 
 import numpy as np
 import torch
@@ -85,7 +85,7 @@ class MetadataBuffers:
         self,
         size: int,
         hidden_size: int,
-
+        hidden_states_dtype: torch.dtype,
         max_top_logprobs_num: int = 128,
         custom_mem_pool: torch.cuda.MemPool = None,
     ):
@@ -107,7 +107,9 @@ class MetadataBuffers:
         # We transfer the metadata of first output token to decode
         # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
         self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
-
+        self.cached_tokens = torch.zeros(
+            (size, 16), dtype=torch.int32, device=device
+        )
         self.output_token_logprobs_val = torch.zeros(
             (size, 16), dtype=torch.float32, device=device
         )
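The `(size, 16)` shapes are the RDMA padding called out in the comment: only element `[idx][0]` is meaningful, but a row of 16 int32 values is exactly 64 bytes, the stated minimum transfer size. Quick check:

```python
import torch

row = torch.zeros(16, dtype=torch.int32)
print(row.nbytes)  # 64 bytes: 16 elements * 4 bytes each
```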
@@ -120,33 +122,49 @@ class MetadataBuffers:
         self.output_top_logprobs_idx = torch.zeros(
             (size, max_top_logprobs_num), dtype=torch.int32, device=device
         )
+        # For PD + spec decode
+        self.output_topk_p = torch.zeros(
+            (size, 16), dtype=torch.float32, device=device
+        )
+        self.output_topk_index = torch.zeros(
+            (size, 16), dtype=torch.int64, device=device
+        )
         self.output_hidden_states = torch.zeros(
-            (size, hidden_size), dtype=
+            (size, hidden_size), dtype=hidden_states_dtype, device=device
         )
 
     def get_buf_infos(self):
         ptrs = [
             self.output_ids.data_ptr(),
+            self.cached_tokens.data_ptr(),
             self.output_token_logprobs_val.data_ptr(),
             self.output_token_logprobs_idx.data_ptr(),
             self.output_top_logprobs_val.data_ptr(),
             self.output_top_logprobs_idx.data_ptr(),
+            self.output_topk_p.data_ptr(),
+            self.output_topk_index.data_ptr(),
             self.output_hidden_states.data_ptr(),
         ]
         data_lens = [
             self.output_ids.nbytes,
+            self.cached_tokens.nbytes,
             self.output_token_logprobs_val.nbytes,
             self.output_token_logprobs_idx.nbytes,
             self.output_top_logprobs_val.nbytes,
             self.output_top_logprobs_idx.nbytes,
+            self.output_topk_p.nbytes,
+            self.output_topk_index.nbytes,
             self.output_hidden_states.nbytes,
         ]
         item_lens = [
             self.output_ids[0].nbytes,
+            self.cached_tokens[0].nbytes,
             self.output_token_logprobs_val[0].nbytes,
             self.output_token_logprobs_idx[0].nbytes,
             self.output_top_logprobs_val[0].nbytes,
             self.output_top_logprobs_idx[0].nbytes,
+            self.output_topk_p[0].nbytes,
+            self.output_topk_index[0].nbytes,
             self.output_hidden_states[0].nbytes,
         ]
         return ptrs, data_lens, item_lens
@@ -154,16 +172,20 @@ class MetadataBuffers:
     def get_buf(self, idx: int):
         return (
            self.output_ids[idx],
+            self.cached_tokens[idx],
            self.output_token_logprobs_val[idx],
            self.output_token_logprobs_idx[idx],
            self.output_top_logprobs_val[idx],
            self.output_top_logprobs_idx[idx],
+            self.output_topk_p[idx],
+            self.output_topk_index[idx],
            self.output_hidden_states[idx],
         )
 
     def set_buf(self, req: Req):
 
         self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
+        self.cached_tokens[req.metadata_buffer_index][0] = req.cached_tokens
         if req.return_logprob:
             if req.output_token_logprobs_val:  # not none or empty list
                 self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
@@ -186,8 +208,17 @@ class MetadataBuffers:
                 ] = torch.tensor(
                     req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
                 )
-        #
+        # For PD + spec decode
         if req.hidden_states_tensor is not None:
+            # speculative_eagle_topk should not be greater than 16 currently
+            topk = req.output_topk_p.size(0)
+
+            self.output_topk_p[req.metadata_buffer_index, :topk].copy_(
+                req.output_topk_p
+            )
+            self.output_topk_index[req.metadata_buffer_index, :topk].copy_(
+                req.output_topk_index
+            )
             self.output_hidden_states[req.metadata_buffer_index].copy_(
                 req.hidden_states_tensor
             )
sglang/srt/distributed/device_communicators/all_reduce_utils.py (new content)

@@ -0,0 +1,16 @@
+MiB = 1024 * 1024
+
+SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
+    9: {
+        2: 64 * MiB,  # 64 MB
+        4: 64 * MiB,  # 64 MB
+        6: 128 * MiB,  # 128 MB
+        8: 128 * MiB,  # 128 MB
+    },
+    10: {
+        2: 64 * MiB,  # 64 MB
+        4: 64 * MiB,  # 64 MB
+        6: 128 * MiB,  # 128 MB
+        8: 128 * MiB,  # 128 MB
+    },
+}
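The outer keys of this table look like CUDA compute-capability majors (9 for Hopper, 10 for Blackwell) and the inner keys like world sizes; under that reading, a caller would gate the symmetric-memory all-reduce on message size roughly as follows (a hypothetical helper, not the actual sglang call site):

```python
import torch

MiB = 1024 * 1024
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
    9: {2: 64 * MiB, 4: 64 * MiB, 6: 128 * MiB, 8: 128 * MiB},
    10: {2: 64 * MiB, 4: 64 * MiB, 6: 128 * MiB, 8: 128 * MiB},
}

def symm_mem_max_size(world_size: int) -> int:
    """Max message size (bytes) for the symm-mem all-reduce path, 0 if unsupported."""
    major, _minor = torch.cuda.get_device_capability()  # e.g. (9, 0) on Hopper
    return SYMM_MEM_ALL_REDUCE_MAX_SIZES.get(major, {}).get(world_size, 0)

# e.g. gate the fast path on payload size:
# use_symm_mem = tensor.nbytes <= symm_mem_max_size(world_size)
```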