sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +192 -113
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +132 -57
- sglang/srt/entrypoints/openai/protocol.py +115 -7
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +207 -58
- sglang/srt/entrypoints/openai/serving_completions.py +17 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +49 -4
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +106 -82
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +53 -7
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +225 -57
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +78 -49
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +215 -314
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +147 -19
- sglang/srt/managers/scheduler.py +501 -304
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +321 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +15 -21
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +58 -34
- sglang/srt/mem_cache/hiradix_cache.py +227 -80
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -223
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +519 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +55 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +98 -57
- sglang/srt/model_executor/model_runner.py +433 -158
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +833 -152
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +14 -5
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +124 -14
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +26 -5
- sglang/srt/models/qwen3_moe.py +71 -12
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +10 -3
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1030 -254
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +253 -136
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +445 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +22 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -44,6 +44,9 @@ from sglang.srt.disaggregation.decode import (
     DecodeTransferQueue,
     SchedulerDisaggregationDecodeMixin,
 )
+from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
+    DecodeKVCacheOffloadManager,
+)
 from sglang.srt.disaggregation.prefill import (
     PrefillBootstrapQueue,
     SchedulerDisaggregationPrefillMixin,
@@ -57,11 +60,6 @@ from sglang.srt.disaggregation.utils import (
 )
 from sglang.srt.distributed import get_pp_group, get_world_group
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
-from sglang.srt.hf_transformers_utils import (
-    get_processor,
-    get_tokenizer,
-    get_tokenizer_from_processor,
-)
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe import initialize_moe_config
@@ -72,20 +70,26 @@ from sglang.srt.managers.io_struct import (
     ClearHiCacheReqInput,
     ClearHiCacheReqOutput,
     CloseSessionReqInput,
+    DestroyWeightsUpdateGroupReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
+    ExpertDistributionReqType,
     FlushCacheReqInput,
     FlushCacheReqOutput,
     FreezeGCReq,
     GetInternalStateReq,
     GetInternalStateReqOutput,
+    GetLoadReqInput,
+    GetLoadReqOutput,
     GetWeightsByNameReqInput,
     HealthCheckOutput,
+    InitWeightsSendGroupForRemoteInstanceReqInput,
+    InitWeightsSendGroupForRemoteInstanceReqOutput,
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
     MultiTokenizerRegisterReq,
-
+    MultiTokenizerWrapper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -93,6 +97,8 @@ from sglang.srt.managers.io_struct import (
     ResumeMemoryOccupationReqInput,
     RpcReqInput,
     RpcReqOutput,
+    SendWeightsToRemoteInstanceReqInput,
+    SendWeightsToRemoteInstanceReqOutput,
     SetInternalStateReq,
     SetInternalStateReqOutput,
     SlowDownReqInput,
@@ -110,6 +116,7 @@ from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     MultimodalInputs,
     Req,
+    RequestStage,
     ScheduleBatch,
     global_server_args_dict,
 )
@@ -134,17 +141,28 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
-from sglang.srt.managers.utils import
+from sglang.srt.managers.utils import validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
-from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
-from sglang.srt.model_executor.forward_batch_info import
-
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatchOutput,
+    ForwardMode,
+    PPProxyTensors,
+)
+from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.tracing.trace import (
+    process_tracing_init,
+    trace_set_proc_propagate_context,
+    trace_set_thread_info,
+    trace_slice_batch,
+    trace_slice_end,
+    trace_slice_start,
+)
 from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
 from sglang.srt.utils import (
     DynamicGradMode,
@@ -155,9 +173,10 @@ from sglang.srt.utils import (
     freeze_gc,
     get_available_gpu_memory,
     get_bool_env_var,
+    get_int_env_var,
     get_zmq_socket,
-    is_cpu,
     kill_itself_when_parent_died,
+    numa_bind_to_node,
     point_to_point_pyobj,
     pyspy_dump_schedulers,
     require_mlp_sync,
@@ -166,6 +185,11 @@ from sglang.srt.utils import (
     set_random_seed,
     suppress_other_loggers,
 )
+from sglang.srt.utils.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -174,24 +198,59 @@ logger = logging.getLogger(__name__)
 TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
 GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
 
-_is_cpu = is_cpu()
-
 
 @dataclass
 class GenerationBatchResult:
     logits_output: Optional[LogitsProcessorOutput]
-    pp_hidden_states_proxy_tensors: Optional[
+    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
     next_token_ids: Optional[List[int]]
+    can_run_cuda_graph: bool
+
+    # For output processing
     extend_input_len_per_req: List[int]
     extend_logprob_start_len_per_req: List[int]
-
-
+
+    @classmethod
+    def from_forward_batch_output(
+        cls,
+        forward_batch_output: ForwardBatchOutput,
+        extend_input_len_per_req: List[int],
+        extend_logprob_start_len_per_req: List[int],
+    ):
+        # TODO(lsyin): remove this workaround logic and try to unify output classes
+
+        return cls(
+            logits_output=forward_batch_output.logits_output,
+            pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
+            next_token_ids=forward_batch_output.next_token_ids,
+            extend_input_len_per_req=extend_input_len_per_req,
+            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
+            can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
+        )
+
+    @classmethod
+    def from_pp_proxy(
+        cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
+    ):
+        # TODO(lsyin): also simplify this logic
+        # Current PP implementation in scheduler is not compatible with ForwardBatchOutput
+        # Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
+        proxy_dict = next_pp_outputs.tensors
+        return cls(
+            logits_output=logits_output,
+            pp_hidden_states_proxy_tensors=None,
+            next_token_ids=next_pp_outputs["next_token_ids"],
+            extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
+            extend_logprob_start_len_per_req=proxy_dict.get(
+                "extend_logprob_start_len_per_req", None
+            ),
+            can_run_cuda_graph=can_run_cuda_graph,
+        )
 
 
 @dataclass
 class EmbeddingBatchResult:
     embeddings: torch.Tensor
-    bid: int
 
 
 class Scheduler(
@@ -213,7 +272,6 @@ class Scheduler(
         moe_ep_rank: int,
         pp_rank: int,
         dp_rank: Optional[int],
-        dp_balance_meta: Optional[DPBalanceMeta] = None,
     ):
         # Parse args
         self.server_args = server_args
@@ -226,6 +284,13 @@ class Scheduler(
         self.pp_size = server_args.pp_size
         self.dp_size = server_args.dp_size
         self.schedule_policy = server_args.schedule_policy
+        self.enable_priority_scheduling = server_args.enable_priority_scheduling
+        self.schedule_low_priority_values_first = (
+            server_args.schedule_low_priority_values_first
+        )
+        self.priority_scheduling_preemption_threshold = (
+            server_args.priority_scheduling_preemption_threshold
+        )
         self.enable_lora = server_args.enable_lora
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = not server_args.disable_overlap_schedule
@@ -234,7 +299,10 @@ class Scheduler(
         self.enable_metrics_for_all_schedulers = (
             server_args.enable_metrics_for_all_schedulers
         )
-        self.enable_kv_cache_events =
+        self.enable_kv_cache_events = bool(
+            server_args.kv_events_config and tp_rank == 0
+        )
+        self.enable_trace = server_args.enable_trace
         self.stream_interval = server_args.stream_interval
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
@@ -348,9 +416,39 @@ class Scheduler(
                 target_worker=self.tp_worker,
                 dp_rank=dp_rank,
             )
+        elif self.spec_algorithm.is_standalone():
+            from sglang.srt.speculative.standalone_worker import StandaloneWorker
+
+            self.draft_worker = StandaloneWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
+        elif self.spec_algorithm.is_ngram():
+            from sglang.srt.speculative.ngram_worker import NGRAMWorker
+
+            self.draft_worker = NGRAMWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
         else:
             self.draft_worker = None
 
+        # Dispatch the model worker
+        if self.spec_algorithm.is_none():
+            self.model_worker = self.tp_worker
+        else:
+            self.model_worker = self.draft_worker
+
         # Get token and memory info from the model worker
         (
             self.max_total_num_tokens,
@@ -401,7 +499,7 @@ class Scheduler(
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
             f"context_len={self.model_config.context_len}, "
-            f"available_gpu_mem={avail_mem:.2f} GB"
+            f"{'available_cpu_mem' if self.device == 'cpu' else 'available_gpu_mem'}={avail_mem:.2f} GB"
         )
 
         # Init memory pool and cache
@@ -458,7 +556,12 @@ class Scheduler(
             self.schedule_policy,
             self.tree_cache,
             self.enable_hierarchical_cache,
+            self.enable_priority_scheduling,
+            self.schedule_low_priority_values_first,
         )
+        # Enable preemption for priority scheduling.
+        self.try_preemption = self.enable_priority_scheduling
+
         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
@@ -488,7 +591,7 @@ class Scheduler(
             enable=server_args.enable_memory_saver
         )
         self.offload_tags = set()
-        self.
+        self.init_profiler()
 
         self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
         self.input_blocker = (
@@ -499,7 +602,9 @@ class Scheduler(
 
         # Init metrics stats
         self.init_metrics(tp_rank, pp_rank, dp_rank)
-
+
+        if self.enable_kv_cache_events:
+            self.init_kv_events(server_args.kv_events_config)
 
         # Init disaggregation
         self.disaggregation_mode = DisaggregationMode(
@@ -510,6 +615,9 @@ class Scheduler(
         if get_bool_env_var("SGLANG_GC_LOG"):
             configure_gc_logger()
 
+        # Init prefill kv split size when deterministic inference is enabled with various attention backends
+        self.init_deterministic_inference_config()
+
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
@@ -524,6 +632,15 @@ class Scheduler(
                 (CloseSessionReqInput, self.close_session),
                 (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                 (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
+                (DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group),
+                (
+                    InitWeightsSendGroupForRemoteInstanceReqInput,
+                    self.init_weights_send_group_for_remote_instance,
+                ),
+                (
+                    SendWeightsToRemoteInstanceReqInput,
+                    self.send_weights_to_remote_instance,
+                ),
                 (
                     UpdateWeightsFromDistributedReqInput,
                     self.update_weights_from_distributed,
@@ -542,17 +659,26 @@ class Scheduler(
                 (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                 (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
                 (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
+                (GetLoadReqInput, self.get_load),
             ]
         )
 
-
-
-
-
-
-            assert dp_balance_meta is not None
+    def init_deterministic_inference_config(self):
+        """Initialize deterministic inference configuration for different attention backends."""
+        if not self.server_args.enable_deterministic_inference:
+            self.truncation_align_size = None
+            return
 
-
+        backend_sizes = {
+            "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
+            "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
+        }
+        env_var, default_size = backend_sizes.get(
+            self.server_args.attention_backend, (None, None)
+        )
+        self.truncation_align_size = (
+            get_int_env_var(env_var, default_size) if env_var else None
+        )
 
     def init_tokenizer(self):
         server_args = self.server_args
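
Note: the `init_deterministic_inference_config` hunk above resolves a per-backend prefill split size from an environment variable. The following standalone sketch mirrors that lookup for illustration only; it assumes just the two environment variable names shown in the diff and approximates `get_int_env_var` with `os.environ.get` plus `int()`.

```python
import os

# Hypothetical standalone version of the lookup in init_deterministic_inference_config.
BACKEND_SIZES = {
    "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
    "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
}


def resolve_truncation_align_size(attention_backend: str, deterministic: bool):
    """Return the prefill split/truncation size, or None when it does not apply."""
    if not deterministic:
        return None
    env_var, default_size = BACKEND_SIZES.get(attention_backend, (None, None))
    if env_var is None:
        return None  # backends other than flashinfer/triton get no alignment size
    return int(os.environ.get(env_var, default_size))


# Example: with no env override, the flashinfer backend falls back to 4096.
print(resolve_truncation_align_size("flashinfer", deterministic=True))        # 4096
print(resolve_truncation_align_size("some_other_backend", deterministic=True))  # None
```
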
@@ -625,15 +751,18 @@ class Scheduler(
                     else self.tp_cpu_group
                 ),
                 page_size=self.page_size,
+                eviction_policy=server_args.radix_eviction_policy,
                 hicache_ratio=server_args.hicache_ratio,
                 hicache_size=server_args.hicache_size,
                 hicache_write_policy=server_args.hicache_write_policy,
                 hicache_io_backend=server_args.hicache_io_backend,
                 hicache_mem_layout=server_args.hicache_mem_layout,
+                enable_metrics=self.enable_metrics,
                 hicache_storage_backend=server_args.hicache_storage_backend,
                 hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
                 model_name=server_args.served_model_name,
                 storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
+                is_eagle=self.spec_algorithm.is_eagle(),
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
                 self.tree_cache.cache_controller.layer_done_counter
@@ -649,18 +778,21 @@ class Scheduler(
                 page_size=self.page_size,
                 disable=server_args.disable_radix_cache,
             )
-        elif
-
-
-            )
-
-
-            ), "LoRA radix cache only supports FCFS policy"
-            self.tree_cache = LoRARadixCache(
+        elif server_args.enable_lmcache:
+            from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
+                LMCRadixCache,
+            )
+
+            self.tree_cache = LMCRadixCache(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 page_size=self.page_size,
                 disable=server_args.disable_radix_cache,
+                model_config=self.model_config,
+                tp_size=self.tp_size,
+                rank=self.tp_rank,
+                tp_group=self.tp_group,
+                eviction_policy=server_args.radix_eviction_policy,
             )
         else:
             self.tree_cache = RadixCache(
@@ -669,16 +801,36 @@ class Scheduler(
                 page_size=self.page_size,
                 disable=server_args.disable_radix_cache,
                 enable_kv_cache_events=self.enable_kv_cache_events,
+                eviction_policy=server_args.radix_eviction_policy,
+                is_eagle=self.spec_algorithm.is_eagle(),
             )
 
+        if (
+            server_args.disaggregation_mode == "decode"
+            and server_args.disaggregation_decode_enable_offload_kvcache
+        ):
+            self.decode_offload_manager = DecodeKVCacheOffloadManager(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                tp_group=(
+                    self.attn_tp_cpu_group
+                    if self.server_args.enable_dp_attention
+                    else self.tp_cpu_group
+                ),
+                tree_cache=self.tree_cache,
+                server_args=self.server_args,
+            )
+        else:
+            self.decode_offload_manager = None
+
         self.decode_mem_cache_buf_multiplier = (
             1
             if self.spec_algorithm.is_none()
             else (
                 server_args.speculative_num_draft_tokens
                 + (
-                    server_args.speculative_eagle_topk
-                    * server_args.speculative_num_steps
+                    (server_args.speculative_eagle_topk or 1)
+                    * (server_args.speculative_num_steps or 1)
                 )
             )
         )
@@ -701,7 +853,7 @@ class Scheduler(
             self.disagg_metadata_buffers = MetadataBuffers(
                 buffer_size,
                 hidden_size=self.model_config.hf_text_config.hidden_size,
-
+                hidden_states_dtype=self.model_config.dtype,
                 custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
             )
 
@@ -721,7 +873,7 @@ class Scheduler(
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 draft_token_to_kv_pool=(
                     None
-                    if self.draft_worker is None
+                    if self.draft_worker is None or self.spec_algorithm.is_ngram()
                     else self.draft_worker.model_runner.token_to_kv_pool
                 ),
                 req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -750,7 +902,7 @@ class Scheduler(
             self.disagg_metadata_buffers = MetadataBuffers(
                 buffer_size,
                 hidden_size=self.model_config.hf_text_config.hidden_size,
-
+                hidden_states_dtype=self.model_config.dtype,
                 custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
             )
 
@@ -758,7 +910,7 @@ class Scheduler(
                 token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                 draft_token_to_kv_pool=(
                     None
-                    if self.draft_worker is None
+                    if self.draft_worker is None or self.spec_algorithm.is_ngram()
                     else self.draft_worker.model_runner.token_to_kv_pool
                 ),
                 req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -853,7 +1005,6 @@ class Scheduler(
         self.running_mbs = [
             ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
         ]
-        bids = [None] * self.pp_size
         pp_outputs: Optional[PPProxyTensors] = None
         while True:
             server_is_idle = True
@@ -874,10 +1025,7 @@ class Scheduler(
                 # (last rank) send the outputs to the next step
                 if self.pp_group.is_last_rank:
                     if self.cur_batch:
-                        next_token_ids
-                            result.next_token_ids,
-                            result.bid,
-                        )
+                        next_token_ids = result.next_token_ids
                         if self.cur_batch.return_logprob:
                             pp_outputs = PPProxyTensors(
                                 {
@@ -925,17 +1073,10 @@ class Scheduler(
                             logits_output = LogitsProcessorOutput(**logits_output_args)
                         else:
                             logits_output = None
-
+
+                        output_result = GenerationBatchResult.from_pp_proxy(
                             logits_output=logits_output,
-
-                            next_token_ids=next_pp_outputs["next_token_ids"],
-                            extend_input_len_per_req=next_pp_outputs.tensors.get(
-                                "extend_input_len_per_req", None
-                            ),
-                            extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
-                                "extend_logprob_start_len_per_req", None
-                            ),
-                            bid=bids[next_mb_id],
+                            next_pp_outputs=next_pp_outputs,
                             can_run_cuda_graph=result.can_run_cuda_graph,
                         )
                         self.process_batch_result(mbs[next_mb_id], output_result)
@@ -943,8 +1084,6 @@ class Scheduler(
 
                 # (not last rank)
                 if not self.pp_group.is_last_rank:
-                    if self.cur_batch:
-                        bids[mb_id] = result.bid
                     # carry the outputs to the next stage
                     # send the outputs from the last round to let the next stage worker run post processing
                     if pp_outputs:
@@ -966,8 +1105,10 @@ class Scheduler(
 
                 # send out proxy tensors to the next stage
                 if self.cur_batch:
+                    # FIXME(lsyin): remove this assert
+                    assert result.pp_hidden_states_proxy_tensors.tensors is not None
                     self.pp_group.send_tensor_dict(
-                        result.pp_hidden_states_proxy_tensors,
+                        result.pp_hidden_states_proxy_tensors.tensors,
                         all_gather_group=self.attn_tp_group,
                     )
 
@@ -1077,6 +1218,15 @@ class Scheduler(
                 self.tp_cpu_group,
                 src=self.tp_group.ranks[0],
             )
+
+        if self.enable_trace:
+            for req in recv_reqs:
+                if isinstance(
+                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                ):
+                    trace_set_proc_propagate_context(req.rid, req.trace_context)
+                    trace_slice_start("", req.rid, anonymous=True)
+
         return recv_reqs
 
     def process_input_requests(self, recv_reqs: List):
@@ -1090,27 +1240,13 @@ class Scheduler(
                 self.return_health_check_ct += 1
                 continue
 
-            # If it is a
-            if
-                if len(self.waiting_queue) + 1 > self.max_queued_requests:
-                    abort_req = AbortReq(
-                        recv_req.rid,
-                        finished_reason={
-                            "type": "abort",
-                            "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
-                            "message": "The request queue is full.",
-                        },
-                    )
-                    self.send_to_tokenizer.send_pyobj(abort_req)
-                    continue
-
-            # If it is a MultiTokenizerWarpper, unwrap it and handle the inner request.
-            if isinstance(recv_req, MultiTokenizerWarpper):
+            # If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
+            if isinstance(recv_req, MultiTokenizerWrapper):
                 worker_id = recv_req.worker_id
                 recv_req = recv_req.obj
                 output = self._request_dispatcher(recv_req)
                 if output is not None:
-                    output =
+                    output = MultiTokenizerWrapper(worker_id, output)
                     self.send_to_tokenizer.send_pyobj(output)
                 continue
 
@@ -1122,16 +1258,20 @@ class Scheduler(
             else:
                 self.send_to_tokenizer.send_pyobj(output)
 
+    def init_req_max_new_tokens(self, req):
+        req.sampling_params.max_new_tokens = min(
+            (
+                req.sampling_params.max_new_tokens
+                if req.sampling_params.max_new_tokens is not None
+                else 1 << 30
+            ),
+            self.max_req_len - len(req.origin_input_ids) - 1,
+        )
+
     def handle_generate_request(
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
-        if (
-            self.server_args.enable_dp_attention
-            and self.server_args.load_balance_method == "minimum_tokens"
-        ):
-            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
-
         # Create a new request
         if (
             recv_req.session_params is None
@@ -1165,8 +1305,13 @@ class Scheduler(
                 bootstrap_host=recv_req.bootstrap_host,
                 bootstrap_port=recv_req.bootstrap_port,
                 bootstrap_room=recv_req.bootstrap_room,
+                disagg_mode=self.disaggregation_mode,
                 data_parallel_rank=recv_req.data_parallel_rank,
                 vocab_size=self.model_config.vocab_size,
+                priority=recv_req.priority,
+                metrics_collector=(
+                    self.metrics_collector if self.enable_metrics else None
+                ),
             )
             req.tokenizer = self.tokenizer
 
@@ -1189,6 +1334,7 @@ class Scheduler(
                 req.set_finish_with_abort(
                     f"Invalid request: session id {recv_req.session_params.id} does not exist"
                 )
+                self.init_req_max_new_tokens(req)
                 self._add_request_to_queue(req)
                 return
             else:
@@ -1196,6 +1342,7 @@ class Scheduler(
             session = self.sessions[recv_req.session_params.id]
             req = session.create_req(recv_req, self.tokenizer)
             if isinstance(req.finished_reason, FINISH_ABORT):
+                self.init_req_max_new_tokens(req)
                 self._add_request_to_queue(req)
                 return
 
@@ -1215,9 +1362,13 @@ class Scheduler(
                     f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                 )
             )
+            self.init_req_max_new_tokens(req)
             self._add_request_to_queue(req)
             return
 
+        # initialize before returning
+        self.init_req_max_new_tokens(req)
+
         # Validate prompt length
         error_msg = validate_input_length(
             req,
@@ -1232,26 +1383,25 @@ class Scheduler(
         # Copy more attributes
         if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
             # By default, only return the logprobs for output tokens
-
+            # For prefill-only requests with logprob_start_len == -1, set logprob_start_len beyond input sequence
+            # to skip input logprob computation entirely
+            if req.is_prefill_only:
+                req.logprob_start_len = len(req.origin_input_ids)
+            else:
+                # TODO: For text generation, evaluate setting logprob_start_len to len(req.origin_input_ids) as well
+                req.logprob_start_len = len(req.origin_input_ids) - 1
         else:
             req.logprob_start_len = recv_req.logprob_start_len
 
-        if req.logprob_start_len >= len(
+        if not req.is_prefill_only and req.logprob_start_len >= len(
+            req.origin_input_ids
+        ):
             error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
             req.logprob_start_len = len(req.origin_input_ids) - 1
             req.set_finish_with_abort(error_msg)
             self._add_request_to_queue(req)
             return
 
-        req.sampling_params.max_new_tokens = min(
-            (
-                req.sampling_params.max_new_tokens
-                if req.sampling_params.max_new_tokens is not None
-                else 1 << 30
-            ),
-            self.max_req_len - len(req.origin_input_ids) - 1,
-        )
-
         # Init grammar cache for this request
         add_to_grammar_queue = False
         if (
@@ -1282,7 +1432,6 @@ class Scheduler(
                 req.set_finish_with_abort(error_msg)
 
         if add_to_grammar_queue:
-            req.queue_time_start = time.perf_counter()
            self.grammar_queue.append(req)
         else:
             self._add_request_to_queue(req)
@@ -1298,19 +1447,6 @@ class Scheduler(
         for tokenized_req in recv_req:
             self.handle_generate_request(tokenized_req)
 
-    def _add_request_to_queue(self, req: Req):
-        req.queue_time_start = time.perf_counter()
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self._prefetch_kvcache(req)
-            self.disagg_prefill_bootstrap_queue.add(
-                req, self.model_config.num_key_value_heads
-            )
-        elif self.disaggregation_mode == DisaggregationMode.DECODE:
-            self.disagg_decode_prealloc_queue.add(req)
-        else:
-            self._prefetch_kvcache(req)
-            self.waiting_queue.append(req)
-
     def _prefetch_kvcache(self, req: Req):
         if self.enable_hicache_storage:
             req.init_next_round_input(self.tree_cache)
@@ -1324,16 +1460,87 @@ class Scheduler(
                 req.rid, req.last_host_node, new_input_tokens, last_hash
             )
 
-    def
-        if self.disaggregation_mode == DisaggregationMode.
-            self.
-
+    def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
+        if self.disaggregation_mode == DisaggregationMode.NULL:
+            self._set_or_validate_priority(req)
+            if self._abort_on_queued_limit(req):
+                return
+            self._prefetch_kvcache(req)
+            self.waiting_queue.append(req)
+            req.time_stats.wait_queue_entry_time = time.perf_counter()
+            trace_slice_end("process req", req.rid, auto_next_anon=True)
+        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self._prefetch_kvcache(req)
+            self.disagg_prefill_bootstrap_queue.add(
+                req, self.model_config.num_key_value_heads
             )
+            req.time_stats.prefill_bootstrap_queue_entry_time = time.perf_counter()
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
-
-
+            self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted)
+            if not is_retracted:
+                req.time_stats.decode_prealloc_queue_entry_time = time.perf_counter()
         else:
-            self.
+            raise ValueError(f"Invalid {self.disaggregation_mode=}")
+
+    def _set_or_validate_priority(self, req: Req):
+        """Set the default priority value, or abort the request based on the priority scheduling mode."""
+        if self.enable_priority_scheduling and req.priority is None:
+            if self.schedule_low_priority_values_first:
+                req.priority = sys.maxsize
+            else:
+                req.priority = -sys.maxsize - 1
+        elif not self.enable_priority_scheduling and req.priority is not None:
+            abort_req = AbortReq(
+                finished_reason={
+                    "type": "abort",
+                    "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
+                    "message": "Using priority is disabled for this server. Please send a new request without a priority.",
+                },
+                rid=req.rid,
+            )
+            self.send_to_tokenizer.send_pyobj(abort_req)
+
+    def _abort_on_queued_limit(self, recv_req: Req) -> bool:
+        """Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted."""
+        if (
+            self.max_queued_requests is None
+            or len(self.waiting_queue) + 1 <= self.max_queued_requests
+        ):
+            return False
+
+        # Reject the incoming request by default.
+        req_to_abort = recv_req
+        message = "The request queue is full."
+        if self.enable_priority_scheduling:
+            # With priority scheduling, consider aboritng an existing request based on the priority.
+            # direction = 1 => smaller number = higher priority; -1 => larger number = higher priority.
+            # max(...) + (direction * priority, queue_time_start) picks the least-preferred request.
+            # Tie: later queue_time_start (newer) is evicted first. Preempt only if strictly better.
+            direction = 1 if self.schedule_low_priority_values_first else -1
+            key_fn = lambda item: (
+                direction * item[1].priority,
+                item[1].time_stats.wait_queue_entry_time,
+            )
+            idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn)
+            abort_existing_req = (
+                direction * recv_req.priority < direction * candidate_req.priority
+            )
+            if abort_existing_req:
+                self.waiting_queue.pop(idx)
+                req_to_abort = candidate_req
+                message = "The request is aborted by a higher priority request."
+
+        self.send_to_tokenizer.send_pyobj(
+            AbortReq(
+                finished_reason={
+                    "type": "abort",
+                    "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
+                    "message": message,
+                },
+                rid=req_to_abort.rid,
+            )
+        )
+        return req_to_abort.rid == recv_req.rid
 
     def handle_embedding_request(
         self,
@@ -1345,6 +1552,7 @@ class Scheduler(
             recv_req.input_ids,
             recv_req.sampling_params,
             token_type_ids=recv_req.token_type_ids,
+            priority=recv_req.priority,
         )
         req.tokenizer = self.tokenizer

@@ -1421,9 +1629,11 @@ class Scheduler(
         _, _, available_size, evictable_size = self._get_token_info()
         protected_size = self.tree_cache.protected_size()
         memory_leak = (available_size + evictable_size) != (
+            # self.max_total_num_tokens
+            # if not self.enable_hierarchical_cache
+            # else self.max_total_num_tokens - protected_size
             self.max_total_num_tokens
-            if not self.enable_hierarchical_cache
-            else self.max_total_num_tokens - protected_size
+            - protected_size
         )
         token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"

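
The reworked leak check above relies on a simple accounting identity: once a batch has drained, every token slot is either free, evictable from the radix cache, or protected, so `available + evictable` must equal `max_total_num_tokens - protected_size`. A toy illustration with made-up numbers:

# Hypothetical numbers, only to show the invariant the leak check relies on.
max_total_num_tokens = 1000
protected_size = 50        # cache nodes pinned by in-flight requests
available_size = 700       # free slots in the allocator
evictable_size = 250       # cached slots that could be reclaimed

memory_leak = (available_size + evictable_size) != (max_total_num_tokens - protected_size)
print(memory_leak)  # False: 700 + 250 == 950, nothing has leaked
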
@@ -1474,6 +1684,20 @@ class Scheduler(
             self.stats.gen_throughput = 0
             self.stats.num_queue_reqs = len(self.waiting_queue)
             self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
+            if self.disaggregation_mode == DisaggregationMode.PREFILL:
+                self.stats.num_prefill_prealloc_queue_reqs = len(
+                    self.disagg_prefill_bootstrap_queue.queue
+                )
+                self.stats.num_prefill_inflight_queue_reqs = len(
+                    self.disagg_prefill_inflight_queue
+                )
+            if self.disaggregation_mode == DisaggregationMode.DECODE:
+                self.stats.num_decode_prealloc_queue_reqs = len(
+                    self.disagg_decode_prealloc_queue.queue
+                )
+                self.stats.num_decode_transfer_queue_reqs = len(
+                    self.disagg_decode_transfer_queue.queue
+                )
             self.metrics_collector.log_stats(self.stats)
         self._publish_kv_events()

@@ -1521,7 +1745,12 @@ class Scheduler(
             chunked_req_to_exclude.add(self.chunked_req)
             self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             # chunked request keeps its rid but will get a new req_pool_idx
-            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
+            if self.tp_worker.worker.model_runner.is_hybrid_gdn:
+                self.req_to_token_pool.free(
+                    self.chunked_req.req_pool_idx, free_mamba_cache=False
+                )
+            else:
+                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
         if self.last_batch and self.last_batch.forward_mode.is_extend():
             if self.last_batch.chunked_req is not None:
                 # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req.
@@ -1568,11 +1797,6 @@ class Scheduler(

         # Handle DP attention
         if need_dp_attn_preparation:
-            if (
-                self.server_args.load_balance_method == "minimum_tokens"
-                and self.forward_ct % 40 == 0
-            ):
-                self.handle_dp_balance_data(ret)
             ret = self.prepare_mlp_sync_batch(ret)

         return ret
@@ -1588,6 +1812,10 @@ class Scheduler(
         if self.grammar_queue:
             self.move_ready_grammar_requests()

+        if self.try_preemption:
+            # Reset batch_is_full to try preemption with a prefill adder.
+            self.running_batch.batch_is_full = False
+
         # Handle the cases where prefill is not allowed
         if (
             self.running_batch.batch_is_full or len(self.waiting_queue) == 0
@@ -1600,7 +1828,11 @@ class Scheduler(
             # as the space for the chunked request has just been released.
             # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
             # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
-            if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
+            if (
+                self.get_num_allocatable_reqs(running_bs) <= 0
+                and not self.chunked_req
+                and not self.try_preemption
+            ):
                 self.running_batch.batch_is_full = True
                 return None

@@ -1620,6 +1852,7 @@ class Scheduler(
             self.max_prefill_tokens,
             self.chunked_prefill_size,
             running_bs if self.is_mixed_chunk else 0,
+            self.priority_scheduling_preemption_threshold,
         )

         if self.chunked_req is not None:
@@ -1640,15 +1873,19 @@ class Scheduler(
                 self.running_batch.batch_is_full = True
                 break

+            running_bs = len(self.running_batch.reqs) - len(adder.preempt_list)
             if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                 self.running_batch.batch_is_full = True
-                break
-
             if self.disaggregation_mode == DisaggregationMode.PREFILL:
                 # In prefill mode, prealloc queue and transfer queue can also take memory,
                 # so we need to check if the available size for the actual available size.
                 if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                     self.running_batch.batch_is_full = True
+
+            if self.running_batch.batch_is_full:
+                if not self.try_preemption:
+                    break
+                if not adder.preempt_to_schedule(req, self.server_args):
                     break

             if self.enable_hicache_storage:
@@ -1658,7 +1895,11 @@ class Scheduler(
                     continue

             req.init_next_round_input(self.tree_cache)
-            res = adder.add_one_req(
+            res = adder.add_one_req(
+                req,
+                has_chunked_req=(self.chunked_req is not None),
+                truncation_align_size=self.truncation_align_size,
+            )

             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
@@ -1679,11 +1920,14 @@ class Scheduler(
         if self.enable_metrics:
             # only record queue time when enable_metrics is True to avoid overhead
             for req in can_run_list:
-                req.
+                req.add_latency(RequestStage.PREFILL_WAITING)

         self.waiting_queue = [
             x for x in self.waiting_queue if x not in set(can_run_list)
         ]
+        if adder.preempt_list:
+            for req in adder.preempt_list:
+                self._add_request_to_queue(req)

         if adder.new_chunked_req is not None:
             assert self.chunked_req is None
@@ -1694,7 +1938,16 @@ class Scheduler(

         # Print stats
         if self.current_scheduler_metrics_enabled():
-            self.log_prefill_stats(adder, can_run_list, running_bs)
+            self.log_prefill_stats(adder, can_run_list, running_bs, 0)
+
+        for req in can_run_list:
+            if req.time_stats.forward_entry_time == 0:
+                # Avoid update chunked request many times
+                req.time_stats.forward_entry_time = time.perf_counter()
+                if self.enable_metrics:
+                    self.metrics_collector.observe_queue_time(
+                        req.time_stats.get_queueing_time(),
+                    )

         # Create a new batch
         new_batch = ScheduleBatch.init_new(
@@ -1749,19 +2002,25 @@ class Scheduler(
             TEST_RETRACT and batch.batch_size() > 10
         ):
             old_ratio = self.new_token_ratio
-
-
-
+            retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
+                self.server_args
+            )
+            self.num_retracted_reqs = len(retracted_reqs)
             self.new_token_ratio = new_token_ratio
+            for req in reqs_to_abort:
+                self.send_to_tokenizer.send_pyobj(
+                    AbortReq(abort_reason=req.to_abort_message, rid=req.rid)
+                )

             logger.info(
                 "KV cache pool is full. Retract requests. "
-                f"#retracted_reqs: {
-                f"#
+                f"#retracted_reqs: {len(retracted_reqs)}, "
+                f"#aborted_retracted_reqs: {len(reqs_to_abort)}, "
+                f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
             )

-
-
+            for req in retracted_reqs:
+                self._add_request_to_queue(req, is_retracted=True)
         else:
             self.new_token_ratio = max(
                 self.new_token_ratio - self.new_token_ratio_decay,
@@ -1789,37 +2048,25 @@ class Scheduler(

         # Run forward
         if self.is_generation:
+
+            batch_or_worker_batch = batch
+
             if self.spec_algorithm.is_none():
-
+                # FIXME(lsyin): remove this if and finally unify the abstraction
+                batch_or_worker_batch = batch.get_model_worker_batch()

-
-
-
+            forward_batch_output = self.model_worker.forward_batch_generation(
+                batch_or_worker_batch
+            )
+
+            if not self.spec_algorithm.is_none():
+                # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
+                self.udpate_spec_metrics(
+                    batch.batch_size(), forward_batch_output.num_accepted_tokens
                 )
-
-
-
-                )
-            else:
-                pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
-                    self.tp_worker.forward_batch_generation(model_worker_batch)
-                )
-                bid = model_worker_batch.bid
-            else:
-                (
-                    logits_output,
-                    next_token_ids,
-                    bid,
-                    num_accepted_tokens,
-                    can_run_cuda_graph,
-                ) = self.draft_worker.forward_batch_speculative_generation(batch)
-                bs = batch.batch_size()
-                self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
-                self.spec_num_total_forward_ct += bs
-                self.num_generated_tokens += num_accepted_tokens
-
-                if self.pp_group.is_last_rank:
-                    batch.output_ids = next_token_ids
+
+            # update batch's output ids
+            batch.output_ids = forward_batch_output.next_token_ids

             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that
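
The hunk above funnels both the speculative and non-speculative generation paths through a single `forward_batch_output` object instead of two differently shaped tuples. A hypothetical sketch of such a container and its hand-off (the field names `next_token_ids` and `num_accepted_tokens` come from the diff; the class shape itself is assumed, not the actual sglang definition):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ForwardBatchOutputSketch:
    # Tokens sampled for each request in the batch.
    next_token_ids: List[int]
    # Only populated on the speculative-decoding path.
    num_accepted_tokens: Optional[int] = None

def run_batch_sketch(worker, batch, spec_enabled: bool):
    out = worker.forward_batch_generation(batch)  # assumed to return a ForwardBatchOutputSketch
    if spec_enabled and out.num_accepted_tokens is not None:
        print(f"accepted {out.num_accepted_tokens} speculative tokens")
    batch.output_ids = out.next_token_ids  # same hand-off as the new scheduler code
    return batch
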
@@ -1828,6 +2075,7 @@ class Scheduler(
                 extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
             else:
                 extend_input_len_per_req = None
+
             if batch.return_logprob:
                 extend_logprob_start_len_per_req = [
                     req.extend_logprob_start_len for req in batch.reqs
@@ -1835,25 +2083,15 @@ class Scheduler(
             else:
                 extend_logprob_start_len_per_req = None

-
-
-                pp_hidden_states_proxy_tensors=(
-                    pp_hidden_states_proxy_tensors
-                    if not self.pp_group.is_last_rank
-                    else None
-                ),
-                next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
+            return GenerationBatchResult.from_forward_batch_output(
+                forward_batch_output=forward_batch_output,
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
-                bid=bid,
-                can_run_cuda_graph=can_run_cuda_graph,
             )
         else: # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            ret = EmbeddingBatchResult(
-                embeddings=embeddings, bid=model_worker_batch.bid
-            )
+            ret = EmbeddingBatchResult(embeddings=embeddings)
             return ret

     def process_batch_result(
@@ -1864,8 +2102,14 @@ class Scheduler(
     ):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result, launch_done)
+            if self.enable_trace:
+                trace_slice_batch("decode loop", batch.reqs)
+
         elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result, launch_done)
+            if self.enable_trace:
+                trace_slice_batch("prefill", batch.reqs)
+
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
                 self.tp_worker.resolve_last_batch_result(launch_done)
@@ -1897,86 +2141,6 @@ class Scheduler(
             disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )

-    def handle_dp_balance_data(self, local_batch: ScheduleBatch):
-        def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
-            """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
-            recv_list = self.recv_dp_balance_id_this_term
-            assert len(recv_list) <= 511, (
-                "The number of requests received this round is too large. "
-                "Please increase gather_tensor_size and onfly_info_size."
-            )
-            # The maximum size of the tensor used for gathering data from all workers.
-            gather_tensor_size = 512
-
-            # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
-            recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
-            recv_tensor[0] = holding_tokens_list
-            recv_tensor[1] = len(
-                recv_list
-            ) # The first element is the length of the list.
-            recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
-                recv_list, dtype=torch.int32
-            )
-
-            if self.tp_rank == 0:
-                gathered_list = [
-                    torch.zeros(gather_tensor_size, dtype=torch.int32)
-                    for _ in range(self.balance_meta.num_workers)
-                ]
-            else:
-                gathered_list = None
-
-            torch.distributed.gather(
-                recv_tensor, gathered_list, group=self.tp_cpu_group
-            )
-
-            gathered_id_list_per_worker = None
-            if self.tp_rank == 0:
-                gathered_id_list_per_worker = []
-                holding_tokens_list = []
-                for tensor in gathered_list:
-                    holding_tokens_list.append(tensor[0].item())
-                    list_length = tensor[1].item()
-                    gathered_id_list_per_worker.append(
-                        tensor[2 : list_length + 2].tolist()
-                    )
-
-            return gathered_id_list_per_worker, holding_tokens_list
-
-        def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
-            meta = self.balance_meta
-
-            with meta.mutex:
-                onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
-                assert len(new_recv_rid_lists) == len(
-                    onfly_list
-                ), "num_worker not equal"
-                # 1.Check if the rid received by each worker this round is present in onfly.
-                # If it is, remove the corresponding onfly item.
-                worker_id = 0
-                for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
-                    for new_recv_rid in new_recv_rids:
-                        assert (
-                            new_recv_rid in on_fly_reqs
-                        ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
-                        del on_fly_reqs[new_recv_rid]
-                    worker_id += 1
-                # 2. Atomically write local_tokens and onfly into shm under the mutex
-                meta.set_shared_onfly_info(onfly_list)
-                meta.set_shared_local_tokens(local_tokens)
-
-        holding_tokens = self.get_load()
-
-        new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
-            holding_tokens
-        )
-
-        self.recv_dp_balance_id_this_term.clear()
-        if self.tp_rank == 0: # only first worker write info
-            write_shared_dp_balance_info(
-                new_recv_dp_balance_id_list, holding_token_list
-            )
-
     @staticmethod
     def prepare_mlp_sync_batch_raw(
         local_batch: ScheduleBatch,
@@ -2104,12 +2268,13 @@ class Scheduler(
                 if req.finished(): # It is aborted by AbortReq
                     num_ready_reqs += 1
                     continue
+
                 req.grammar = req.grammar.result(timeout=0.03)
                 self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                 if req.grammar is INVALID_GRAMMAR_OBJ:
-                    req.
-
-
+                    error_msg = f"Invalid grammar request: {req.grammar_key=}"
+                    req.set_finish_with_abort(error_msg)
+
                 num_ready_reqs += 1
             except futures._base.TimeoutError:
                 req.grammar_wait_ct += 1
@@ -2141,9 +2306,8 @@ class Scheduler(
                 req.grammar = req.grammar.result()
                 self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                 if req.grammar is INVALID_GRAMMAR_OBJ:
-                    req.
-
-                    )
+                    error_msg = f"Invalid grammar request: {req.grammar_key=}"
+                    req.set_finish_with_abort(error_msg)
             else:
                 num_ready_reqs_max = num_ready_reqs
                 num_timeout_reqs_max = num_timeout_reqs
@@ -2151,12 +2315,14 @@ class Scheduler(
         for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
             req = self.grammar_queue[i]
             req.grammar.cancel()
+            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
             error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
             req.set_finish_with_abort(error_msg)
-
+
         num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max

-        self.
+        for req in self.grammar_queue[:num_ready_reqs]:
+            self._add_request_to_queue(req)
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]

     def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
@@ -2248,9 +2414,8 @@ class Scheduler(
         self.req_to_token_pool.clear()
         self.token_to_kv_pool_allocator.clear()

-        if
-            self.draft_worker.
-            self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()
+        if self.draft_worker:
+            self.draft_worker.clear_cache_pool()

         self.num_generated_tokens = 0
         self.forward_ct_decode = 0
@@ -2270,39 +2435,50 @@ class Scheduler(
             if_success = False
         return if_success

-    def get_load(self):
+    def get_load(self, recv_req: GetLoadReqInput = None) -> GetLoadReqOutput:
         # TODO(lsyin): use dynamically maintained num_waiting_tokens
+
         if self.is_hybrid:
-
+            num_tokens_full = (
                 self.full_tokens_per_layer
                 - self.token_to_kv_pool_allocator.full_available_size()
                 - self.tree_cache.full_evictable_size()
             )
-
+            num_tokens_swa = (
                 self.swa_tokens_per_layer
                 - self.token_to_kv_pool_allocator.swa_available_size()
                 - self.tree_cache.swa_evictable_size()
             )
-
+            num_tokens = max(num_tokens_full, num_tokens_swa)
         else:
-
+            num_tokens = (
                 self.max_total_num_tokens
                 - self.token_to_kv_pool_allocator.available_size()
                 - self.tree_cache.evictable_size()
             )
-
+
+        # Tokens in waiting queue, bootstrap queue, prealloc queue
+        num_tokens += sum(len(req.origin_input_ids) for req in self.waiting_queue)
+        num_waiting_reqs = len(self.waiting_queue)
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-
+            num_tokens += sum(
                 len(req.origin_input_ids)
                 for req in self.disagg_prefill_bootstrap_queue.queue
             )
+            num_waiting_reqs += len(self.disagg_prefill_bootstrap_queue.queue)
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
-
+            num_tokens += sum(
                 len(req.req.origin_input_ids)
                 for req in self.disagg_decode_prealloc_queue.queue
             )
+            num_waiting_reqs += len(self.disagg_decode_prealloc_queue.queue)

-        return
+        return GetLoadReqOutput(
+            dp_rank=self.dp_rank,
+            num_reqs=len(self.running_batch.reqs) + num_waiting_reqs,
+            num_waiting_reqs=num_waiting_reqs,
+            num_tokens=num_tokens,
+        )

     def get_internal_state(self, recv_req: GetInternalStateReq):
         ret = dict(global_server_args_dict)
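
`get_load` now returns a structured per-rank load report instead of a bare token count. The token figure is "tokens currently held" (capacity minus free minus evictable) plus the prompt tokens of every request still waiting in a queue. A simplified, standalone version of that arithmetic (names and numbers are hypothetical, not the sglang API):

def estimate_load(max_total_num_tokens, available, evictable, waiting_prompt_lens):
    # Tokens already resident in the KV cache pool.
    num_tokens = max_total_num_tokens - available - evictable
    # Plus tokens that will be needed by requests still queued.
    num_tokens += sum(waiting_prompt_lens)
    return {
        "num_waiting_reqs": len(waiting_prompt_lens),
        "num_tokens": num_tokens,
    }

print(estimate_load(1000, available=600, evictable=100, waiting_prompt_lens=[32, 128]))
# {'num_waiting_reqs': 2, 'num_tokens': 460}
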
@@ -2317,10 +2493,9 @@ class Scheduler(
             "token_capacity": int(self.max_total_num_tokens),
         }

-
-
-
-        )
+        ret["memory_usage"]["graph"] = round(
+            self.tp_worker.worker.model_runner.graph_mem_usage, 2
+        )

         if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
             ret["avg_spec_accept_length"] = (
@@ -2329,8 +2504,6 @@ class Scheduler(
         if RECORD_STEP_TIME:
             ret["step_time_dict"] = self.step_time_dict

-        ret["load"] = self.get_load()
-
         return GetInternalStateReqOutput(internal_state=ret)

     def set_internal_state(self, recv_req: SetInternalStateReq):
@@ -2406,7 +2579,7 @@ class Scheduler(
             if self.enable_hicache_storage:
                 # to release prefetch events associated with the request
                 self.tree_cache.release_aborted_request(req.rid)
-            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
+            self.send_to_tokenizer.send_pyobj(AbortReq(rid=req.rid))
             # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
             if self.disaggregation_mode == DisaggregationMode.DECODE:
                 self.tree_cache.cache_finished_req(req)
@@ -2427,31 +2600,31 @@ class Scheduler(
         # Delete requests not in the waiting queue when PD disaggregation is enabled
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             # Abort requests that have not yet been bootstrapped
-            for
-                logger.debug(f"Abort bootstrap queue request. {req.rid=}")
+            for req in self.disagg_prefill_bootstrap_queue.queue:
                 if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort bootstrap queue request. {req.rid=}")
                     if hasattr(req.disagg_kv_sender, "abort"):
                         req.disagg_kv_sender.abort()

             # Abort in-flight requests
-            for
-                logger.debug(f"Abort inflight queue request. {req.rid=}")
+            for req in self.disagg_prefill_inflight_queue:
                 if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort inflight queue request. {req.rid=}")
                     if hasattr(req.disagg_kv_sender, "abort"):
                         req.disagg_kv_sender.abort()

         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             # Abort requests that have not yet finished preallocation
-            for
-                logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
+            for decode_req in self.disagg_decode_prealloc_queue.queue:
                 if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
                     if hasattr(decode_req.kv_receiver, "abort"):
                         decode_req.kv_receiver.abort()

             # Abort requests waiting for kvcache to release tree cache
-            for
-                logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
+            for decode_req in self.disagg_decode_transfer_queue.queue:
                 if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
                     if hasattr(decode_req.kv_receiver, "abort"):
                         decode_req.kv_receiver.abort()

@@ -2494,6 +2667,22 @@ class Scheduler(
         self.send_to_detokenizer.send_pyobj(recv_req)
         return recv_req

+    def init_weights_send_group_for_remote_instance(
+        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
+    ):
+        """Init the seed and client instance communication group."""
+        success, message = self.tp_worker.init_weights_send_group_for_remote_instance(
+            recv_req
+        )
+        return InitWeightsSendGroupForRemoteInstanceReqOutput(success, message)
+
+    def send_weights_to_remote_instance(
+        self, recv_req: SendWeightsToRemoteInstanceReqInput
+    ):
+        """Send the seed instance weights to the destination instance."""
+        success, message = self.tp_worker.send_weights_to_remote_instance(recv_req)
+        return SendWeightsToRemoteInstanceReqOutput(success, message)
+
     def slow_down(self, recv_req: SlowDownReqInput):
         t = recv_req.forward_sleep_time
         if t is not None and t <= 0:
@@ -2502,11 +2691,12 @@ class Scheduler(
         return SlowDownReqOutput()

     def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
-
+        action = recv_req.action
+        if action == ExpertDistributionReqType.START_RECORD:
             get_global_expert_distribution_recorder().start_record()
-        elif
+        elif action == ExpertDistributionReqType.STOP_RECORD:
             get_global_expert_distribution_recorder().stop_record()
-        elif
+        elif action == ExpertDistributionReqType.DUMP_RECORD:
             get_global_expert_distribution_recorder().dump_record()
         else:
             raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
@@ -2589,7 +2779,8 @@ class IdleSleeper:


 def is_health_check_generate_req(recv_req):
-
+    rid = getattr(recv_req, "rid", None)
+    return rid is not None and rid.startswith("HEALTH_CHECK")


 def is_work_request(recv_req):
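
The rewritten helper above simply checks for the reserved health-check rid prefix. A quick self-contained check of that convention, reproducing the same logic (the rid value is made up):

def is_health_check_generate_req(recv_req):
    # Same logic as the new helper: health-check requests carry a reserved rid prefix.
    rid = getattr(recv_req, "rid", None)
    return rid is not None and rid.startswith("HEALTH_CHECK")

class _FakeReq:
    rid = "HEALTH_CHECK_abc123"  # hypothetical rid

assert is_health_check_generate_req(_FakeReq())
assert not is_health_check_generate_req(object())  # no rid attribute -> not a health check
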
@@ -2613,10 +2804,12 @@ def run_scheduler_process(
     pp_rank: int,
     dp_rank: Optional[int],
     pipe_writer,
-    balance_meta: Optional[DPBalanceMeta] = None,
 ):
-    # Generate the prefix
+    # Generate the logger prefix
     prefix = ""
+    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
+        # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
+        dp_rank = int(os.environ["SGLANG_DP_RANK"])
     if dp_rank is not None:
         prefix += f" DP{dp_rank}"
     if server_args.tp_size > 1:
@@ -2632,10 +2825,6 @@ def run_scheduler_process(
     kill_itself_when_parent_died()
     parent_process = psutil.Process().parent()

-    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
-    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
-        dp_rank = int(os.environ["SGLANG_DP_RANK"])
-
     # Configure the logger
     configure_logger(server_args, prefix=prefix)
     suppress_other_loggers()
@@ -2643,6 +2832,15 @@ def run_scheduler_process(
     # Set cpu affinity to this gpu process
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+    if (numa_node := server_args.numa_node) is not None:
+        numa_bind_to_node(numa_node[gpu_id])
+
+    # Set up tracing
+    if server_args.enable_trace:
+        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
+        if server_args.disaggregation_mode == "null":
+            thread_label = "Scheduler"
+            trace_set_thread_info(thread_label, tp_rank, dp_rank)

     # Create a scheduler and run the event loop
     try:
@@ -2654,7 +2852,6 @@ def run_scheduler_process(
             moe_ep_rank,
             pp_rank,
             dp_rank,
-            dp_balance_meta=balance_meta,
         )
         pipe_writer.send(
             {