sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -44,6 +44,9 @@ from sglang.srt.disaggregation.decode import (
|
|
44
44
|
DecodeTransferQueue,
|
45
45
|
SchedulerDisaggregationDecodeMixin,
|
46
46
|
)
|
47
|
+
from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
|
48
|
+
DecodeKVCacheOffloadManager,
|
49
|
+
)
|
47
50
|
from sglang.srt.disaggregation.prefill import (
|
48
51
|
PrefillBootstrapQueue,
|
49
52
|
SchedulerDisaggregationPrefillMixin,
|
@@ -57,11 +60,6 @@ from sglang.srt.disaggregation.utils import (
|
|
57
60
|
)
|
58
61
|
from sglang.srt.distributed import get_pp_group, get_world_group
|
59
62
|
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
60
|
-
from sglang.srt.hf_transformers_utils import (
|
61
|
-
get_processor,
|
62
|
-
get_tokenizer,
|
63
|
-
get_tokenizer_from_processor,
|
64
|
-
)
|
65
63
|
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
|
66
64
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
67
65
|
from sglang.srt.layers.moe import initialize_moe_config
|
@@ -72,20 +70,26 @@ from sglang.srt.managers.io_struct import (
|
|
72
70
|
ClearHiCacheReqInput,
|
73
71
|
ClearHiCacheReqOutput,
|
74
72
|
CloseSessionReqInput,
|
73
|
+
DestroyWeightsUpdateGroupReqInput,
|
75
74
|
ExpertDistributionReq,
|
76
75
|
ExpertDistributionReqOutput,
|
76
|
+
ExpertDistributionReqType,
|
77
77
|
FlushCacheReqInput,
|
78
78
|
FlushCacheReqOutput,
|
79
79
|
FreezeGCReq,
|
80
80
|
GetInternalStateReq,
|
81
81
|
GetInternalStateReqOutput,
|
82
|
+
GetLoadReqInput,
|
83
|
+
GetLoadReqOutput,
|
82
84
|
GetWeightsByNameReqInput,
|
83
85
|
HealthCheckOutput,
|
86
|
+
InitWeightsSendGroupForRemoteInstanceReqInput,
|
87
|
+
InitWeightsSendGroupForRemoteInstanceReqOutput,
|
84
88
|
InitWeightsUpdateGroupReqInput,
|
85
89
|
LoadLoRAAdapterReqInput,
|
86
90
|
LoadLoRAAdapterReqOutput,
|
87
91
|
MultiTokenizerRegisterReq,
|
88
|
-
|
92
|
+
MultiTokenizerWrapper,
|
89
93
|
OpenSessionReqInput,
|
90
94
|
OpenSessionReqOutput,
|
91
95
|
ProfileReq,
|
@@ -93,6 +97,8 @@ from sglang.srt.managers.io_struct import (
|
|
93
97
|
ResumeMemoryOccupationReqInput,
|
94
98
|
RpcReqInput,
|
95
99
|
RpcReqOutput,
|
100
|
+
SendWeightsToRemoteInstanceReqInput,
|
101
|
+
SendWeightsToRemoteInstanceReqOutput,
|
96
102
|
SetInternalStateReq,
|
97
103
|
SetInternalStateReqOutput,
|
98
104
|
SlowDownReqInput,
|
@@ -110,6 +116,7 @@ from sglang.srt.managers.schedule_batch import (
|
|
110
116
|
FINISH_ABORT,
|
111
117
|
MultimodalInputs,
|
112
118
|
Req,
|
119
|
+
RequestStage,
|
113
120
|
ScheduleBatch,
|
114
121
|
global_server_args_dict,
|
115
122
|
)
|
@@ -134,17 +141,28 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
|
|
134
141
|
from sglang.srt.managers.session_controller import Session
|
135
142
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
136
143
|
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
137
|
-
from sglang.srt.managers.utils import
|
144
|
+
from sglang.srt.managers.utils import validate_input_length
|
138
145
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
139
146
|
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
140
|
-
from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
|
141
147
|
from sglang.srt.mem_cache.radix_cache import RadixCache
|
142
148
|
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
143
|
-
from sglang.srt.model_executor.forward_batch_info import
|
149
|
+
from sglang.srt.model_executor.forward_batch_info import (
|
150
|
+
ForwardBatchOutput,
|
151
|
+
ForwardMode,
|
152
|
+
PPProxyTensors,
|
153
|
+
)
|
144
154
|
from sglang.srt.parser.reasoning_parser import ReasoningParser
|
145
155
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
146
156
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
147
157
|
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
158
|
+
from sglang.srt.tracing.trace import (
|
159
|
+
process_tracing_init,
|
160
|
+
trace_set_proc_propagate_context,
|
161
|
+
trace_set_thread_info,
|
162
|
+
trace_slice_batch,
|
163
|
+
trace_slice_end,
|
164
|
+
trace_slice_start,
|
165
|
+
)
|
148
166
|
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
|
149
167
|
from sglang.srt.utils import (
|
150
168
|
DynamicGradMode,
|
@@ -155,9 +173,10 @@ from sglang.srt.utils import (
|
|
155
173
|
freeze_gc,
|
156
174
|
get_available_gpu_memory,
|
157
175
|
get_bool_env_var,
|
176
|
+
get_int_env_var,
|
158
177
|
get_zmq_socket,
|
159
|
-
is_cpu,
|
160
178
|
kill_itself_when_parent_died,
|
179
|
+
numa_bind_to_node,
|
161
180
|
point_to_point_pyobj,
|
162
181
|
pyspy_dump_schedulers,
|
163
182
|
require_mlp_sync,
|
@@ -166,6 +185,11 @@ from sglang.srt.utils import (
|
|
166
185
|
set_random_seed,
|
167
186
|
suppress_other_loggers,
|
168
187
|
)
|
188
|
+
from sglang.srt.utils.hf_transformers_utils import (
|
189
|
+
get_processor,
|
190
|
+
get_tokenizer,
|
191
|
+
get_tokenizer_from_processor,
|
192
|
+
)
|
169
193
|
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
170
194
|
|
171
195
|
logger = logging.getLogger(__name__)
|
@@ -174,24 +198,59 @@ logger = logging.getLogger(__name__)
|
|
174
198
|
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
|
175
199
|
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
|
176
200
|
|
177
|
-
_is_cpu = is_cpu()
|
178
|
-
|
179
201
|
|
180
202
|
@dataclass
|
181
203
|
class GenerationBatchResult:
|
182
204
|
logits_output: Optional[LogitsProcessorOutput]
|
183
|
-
pp_hidden_states_proxy_tensors: Optional[
|
205
|
+
pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
|
184
206
|
next_token_ids: Optional[List[int]]
|
207
|
+
can_run_cuda_graph: bool
|
208
|
+
|
209
|
+
# For output processing
|
185
210
|
extend_input_len_per_req: List[int]
|
186
211
|
extend_logprob_start_len_per_req: List[int]
|
187
|
-
|
188
|
-
|
212
|
+
|
213
|
+
@classmethod
|
214
|
+
def from_forward_batch_output(
|
215
|
+
cls,
|
216
|
+
forward_batch_output: ForwardBatchOutput,
|
217
|
+
extend_input_len_per_req: List[int],
|
218
|
+
extend_logprob_start_len_per_req: List[int],
|
219
|
+
):
|
220
|
+
# TODO(lsyin): remove this workaround logic and try to unify output classes
|
221
|
+
|
222
|
+
return cls(
|
223
|
+
logits_output=forward_batch_output.logits_output,
|
224
|
+
pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
|
225
|
+
next_token_ids=forward_batch_output.next_token_ids,
|
226
|
+
extend_input_len_per_req=extend_input_len_per_req,
|
227
|
+
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
|
228
|
+
can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
|
229
|
+
)
|
230
|
+
|
231
|
+
@classmethod
|
232
|
+
def from_pp_proxy(
|
233
|
+
cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
|
234
|
+
):
|
235
|
+
# TODO(lsyin): also simplify this logic
|
236
|
+
# Current PP implementation in scheduler is not compatible with ForwardBatchOutput
|
237
|
+
# Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
|
238
|
+
proxy_dict = next_pp_outputs.tensors
|
239
|
+
return cls(
|
240
|
+
logits_output=logits_output,
|
241
|
+
pp_hidden_states_proxy_tensors=None,
|
242
|
+
next_token_ids=next_pp_outputs["next_token_ids"],
|
243
|
+
extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
|
244
|
+
extend_logprob_start_len_per_req=proxy_dict.get(
|
245
|
+
"extend_logprob_start_len_per_req", None
|
246
|
+
),
|
247
|
+
can_run_cuda_graph=can_run_cuda_graph,
|
248
|
+
)
|
189
249
|
|
190
250
|
|
191
251
|
@dataclass
|
192
252
|
class EmbeddingBatchResult:
|
193
253
|
embeddings: torch.Tensor
|
194
|
-
bid: int
|
195
254
|
|
196
255
|
|
197
256
|
class Scheduler(
|
@@ -213,7 +272,6 @@ class Scheduler(
|
|
213
272
|
moe_ep_rank: int,
|
214
273
|
pp_rank: int,
|
215
274
|
dp_rank: Optional[int],
|
216
|
-
dp_balance_meta: Optional[DPBalanceMeta] = None,
|
217
275
|
):
|
218
276
|
# Parse args
|
219
277
|
self.server_args = server_args
|
@@ -226,6 +284,13 @@ class Scheduler(
|
|
226
284
|
self.pp_size = server_args.pp_size
|
227
285
|
self.dp_size = server_args.dp_size
|
228
286
|
self.schedule_policy = server_args.schedule_policy
|
287
|
+
self.enable_priority_scheduling = server_args.enable_priority_scheduling
|
288
|
+
self.schedule_low_priority_values_first = (
|
289
|
+
server_args.schedule_low_priority_values_first
|
290
|
+
)
|
291
|
+
self.priority_scheduling_preemption_threshold = (
|
292
|
+
server_args.priority_scheduling_preemption_threshold
|
293
|
+
)
|
229
294
|
self.enable_lora = server_args.enable_lora
|
230
295
|
self.max_loras_per_batch = server_args.max_loras_per_batch
|
231
296
|
self.enable_overlap = not server_args.disable_overlap_schedule
|
@@ -234,7 +299,10 @@ class Scheduler(
|
|
234
299
|
self.enable_metrics_for_all_schedulers = (
|
235
300
|
server_args.enable_metrics_for_all_schedulers
|
236
301
|
)
|
237
|
-
self.enable_kv_cache_events =
|
302
|
+
self.enable_kv_cache_events = bool(
|
303
|
+
server_args.kv_events_config and tp_rank == 0
|
304
|
+
)
|
305
|
+
self.enable_trace = server_args.enable_trace
|
238
306
|
self.stream_interval = server_args.stream_interval
|
239
307
|
self.spec_algorithm = SpeculativeAlgorithm.from_string(
|
240
308
|
server_args.speculative_algorithm
|
@@ -348,9 +416,39 @@ class Scheduler(
|
|
348
416
|
target_worker=self.tp_worker,
|
349
417
|
dp_rank=dp_rank,
|
350
418
|
)
|
419
|
+
elif self.spec_algorithm.is_standalone():
|
420
|
+
from sglang.srt.speculative.standalone_worker import StandaloneWorker
|
421
|
+
|
422
|
+
self.draft_worker = StandaloneWorker(
|
423
|
+
gpu_id=gpu_id,
|
424
|
+
tp_rank=tp_rank,
|
425
|
+
moe_ep_rank=moe_ep_rank,
|
426
|
+
server_args=server_args,
|
427
|
+
nccl_port=port_args.nccl_port,
|
428
|
+
target_worker=self.tp_worker,
|
429
|
+
dp_rank=dp_rank,
|
430
|
+
)
|
431
|
+
elif self.spec_algorithm.is_ngram():
|
432
|
+
from sglang.srt.speculative.ngram_worker import NGRAMWorker
|
433
|
+
|
434
|
+
self.draft_worker = NGRAMWorker(
|
435
|
+
gpu_id=gpu_id,
|
436
|
+
tp_rank=tp_rank,
|
437
|
+
moe_ep_rank=moe_ep_rank,
|
438
|
+
server_args=server_args,
|
439
|
+
nccl_port=port_args.nccl_port,
|
440
|
+
target_worker=self.tp_worker,
|
441
|
+
dp_rank=dp_rank,
|
442
|
+
)
|
351
443
|
else:
|
352
444
|
self.draft_worker = None
|
353
445
|
|
446
|
+
# Dispatch the model worker
|
447
|
+
if self.spec_algorithm.is_none():
|
448
|
+
self.model_worker = self.tp_worker
|
449
|
+
else:
|
450
|
+
self.model_worker = self.draft_worker
|
451
|
+
|
354
452
|
# Get token and memory info from the model worker
|
355
453
|
(
|
356
454
|
self.max_total_num_tokens,
|
@@ -401,7 +499,7 @@ class Scheduler(
|
|
401
499
|
f"max_prefill_tokens={self.max_prefill_tokens}, "
|
402
500
|
f"max_running_requests={self.max_running_requests}, "
|
403
501
|
f"context_len={self.model_config.context_len}, "
|
404
|
-
f"available_gpu_mem={avail_mem:.2f} GB"
|
502
|
+
f"{'available_cpu_mem' if self.device == 'cpu' else 'available_gpu_mem'}={avail_mem:.2f} GB"
|
405
503
|
)
|
406
504
|
|
407
505
|
# Init memory pool and cache
|
@@ -458,7 +556,12 @@ class Scheduler(
|
|
458
556
|
self.schedule_policy,
|
459
557
|
self.tree_cache,
|
460
558
|
self.enable_hierarchical_cache,
|
559
|
+
self.enable_priority_scheduling,
|
560
|
+
self.schedule_low_priority_values_first,
|
461
561
|
)
|
562
|
+
# Enable preemption for priority scheduling.
|
563
|
+
self.try_preemption = self.enable_priority_scheduling
|
564
|
+
|
462
565
|
assert (
|
463
566
|
server_args.schedule_conservativeness >= 0
|
464
567
|
), "Invalid schedule_conservativeness"
|
@@ -488,7 +591,7 @@ class Scheduler(
|
|
488
591
|
enable=server_args.enable_memory_saver
|
489
592
|
)
|
490
593
|
self.offload_tags = set()
|
491
|
-
self.
|
594
|
+
self.init_profiler()
|
492
595
|
|
493
596
|
self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
|
494
597
|
self.input_blocker = (
|
@@ -499,8 +602,9 @@ class Scheduler(
|
|
499
602
|
|
500
603
|
# Init metrics stats
|
501
604
|
self.init_metrics(tp_rank, pp_rank, dp_rank)
|
502
|
-
|
503
|
-
self.
|
605
|
+
|
606
|
+
if self.enable_kv_cache_events:
|
607
|
+
self.init_kv_events(server_args.kv_events_config)
|
504
608
|
|
505
609
|
# Init disaggregation
|
506
610
|
self.disaggregation_mode = DisaggregationMode(
|
@@ -511,6 +615,9 @@ class Scheduler(
|
|
511
615
|
if get_bool_env_var("SGLANG_GC_LOG"):
|
512
616
|
configure_gc_logger()
|
513
617
|
|
618
|
+
# Init prefill kv split size when deterministic inference is enabled with various attention backends
|
619
|
+
self.init_deterministic_inference_config()
|
620
|
+
|
514
621
|
# Init request dispatcher
|
515
622
|
self._request_dispatcher = TypeBasedDispatcher(
|
516
623
|
[
|
@@ -525,6 +632,15 @@ class Scheduler(
|
|
525
632
|
(CloseSessionReqInput, self.close_session),
|
526
633
|
(UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
|
527
634
|
(InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
|
635
|
+
(DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group),
|
636
|
+
(
|
637
|
+
InitWeightsSendGroupForRemoteInstanceReqInput,
|
638
|
+
self.init_weights_send_group_for_remote_instance,
|
639
|
+
),
|
640
|
+
(
|
641
|
+
SendWeightsToRemoteInstanceReqInput,
|
642
|
+
self.send_weights_to_remote_instance,
|
643
|
+
),
|
528
644
|
(
|
529
645
|
UpdateWeightsFromDistributedReqInput,
|
530
646
|
self.update_weights_from_distributed,
|
@@ -543,9 +659,27 @@ class Scheduler(
|
|
543
659
|
(LoadLoRAAdapterReqInput, self.load_lora_adapter),
|
544
660
|
(UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
|
545
661
|
(MultiTokenizerRegisterReq, self.register_multi_tokenizer),
|
662
|
+
(GetLoadReqInput, self.get_load),
|
546
663
|
]
|
547
664
|
)
|
548
665
|
|
666
|
+
def init_deterministic_inference_config(self):
|
667
|
+
"""Initialize deterministic inference configuration for different attention backends."""
|
668
|
+
if not self.server_args.enable_deterministic_inference:
|
669
|
+
self.truncation_align_size = None
|
670
|
+
return
|
671
|
+
|
672
|
+
backend_sizes = {
|
673
|
+
"flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
|
674
|
+
"triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
|
675
|
+
}
|
676
|
+
env_var, default_size = backend_sizes.get(
|
677
|
+
self.server_args.attention_backend, (None, None)
|
678
|
+
)
|
679
|
+
self.truncation_align_size = (
|
680
|
+
get_int_env_var(env_var, default_size) if env_var else None
|
681
|
+
)
|
682
|
+
|
549
683
|
def init_tokenizer(self):
|
550
684
|
server_args = self.server_args
|
551
685
|
self.is_generation = self.model_config.is_generation
|
@@ -617,15 +751,18 @@ class Scheduler(
|
|
617
751
|
else self.tp_cpu_group
|
618
752
|
),
|
619
753
|
page_size=self.page_size,
|
754
|
+
eviction_policy=server_args.radix_eviction_policy,
|
620
755
|
hicache_ratio=server_args.hicache_ratio,
|
621
756
|
hicache_size=server_args.hicache_size,
|
622
757
|
hicache_write_policy=server_args.hicache_write_policy,
|
623
758
|
hicache_io_backend=server_args.hicache_io_backend,
|
624
759
|
hicache_mem_layout=server_args.hicache_mem_layout,
|
760
|
+
enable_metrics=self.enable_metrics,
|
625
761
|
hicache_storage_backend=server_args.hicache_storage_backend,
|
626
762
|
hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
|
627
763
|
model_name=server_args.served_model_name,
|
628
764
|
storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
|
765
|
+
is_eagle=self.spec_algorithm.is_eagle(),
|
629
766
|
)
|
630
767
|
self.tp_worker.register_hicache_layer_transfer_counter(
|
631
768
|
self.tree_cache.cache_controller.layer_done_counter
|
@@ -641,18 +778,21 @@ class Scheduler(
|
|
641
778
|
page_size=self.page_size,
|
642
779
|
disable=server_args.disable_radix_cache,
|
643
780
|
)
|
644
|
-
elif
|
645
|
-
|
646
|
-
|
647
|
-
)
|
648
|
-
|
649
|
-
|
650
|
-
), "LoRA radix cache only supports FCFS policy"
|
651
|
-
self.tree_cache = LoRARadixCache(
|
781
|
+
elif server_args.enable_lmcache:
|
782
|
+
from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
|
783
|
+
LMCRadixCache,
|
784
|
+
)
|
785
|
+
|
786
|
+
self.tree_cache = LMCRadixCache(
|
652
787
|
req_to_token_pool=self.req_to_token_pool,
|
653
788
|
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
654
789
|
page_size=self.page_size,
|
655
790
|
disable=server_args.disable_radix_cache,
|
791
|
+
model_config=self.model_config,
|
792
|
+
tp_size=self.tp_size,
|
793
|
+
rank=self.tp_rank,
|
794
|
+
tp_group=self.tp_group,
|
795
|
+
eviction_policy=server_args.radix_eviction_policy,
|
656
796
|
)
|
657
797
|
else:
|
658
798
|
self.tree_cache = RadixCache(
|
@@ -661,16 +801,36 @@ class Scheduler(
|
|
661
801
|
page_size=self.page_size,
|
662
802
|
disable=server_args.disable_radix_cache,
|
663
803
|
enable_kv_cache_events=self.enable_kv_cache_events,
|
804
|
+
eviction_policy=server_args.radix_eviction_policy,
|
805
|
+
is_eagle=self.spec_algorithm.is_eagle(),
|
664
806
|
)
|
665
807
|
|
808
|
+
if (
|
809
|
+
server_args.disaggregation_mode == "decode"
|
810
|
+
and server_args.disaggregation_decode_enable_offload_kvcache
|
811
|
+
):
|
812
|
+
self.decode_offload_manager = DecodeKVCacheOffloadManager(
|
813
|
+
req_to_token_pool=self.req_to_token_pool,
|
814
|
+
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
815
|
+
tp_group=(
|
816
|
+
self.attn_tp_cpu_group
|
817
|
+
if self.server_args.enable_dp_attention
|
818
|
+
else self.tp_cpu_group
|
819
|
+
),
|
820
|
+
tree_cache=self.tree_cache,
|
821
|
+
server_args=self.server_args,
|
822
|
+
)
|
823
|
+
else:
|
824
|
+
self.decode_offload_manager = None
|
825
|
+
|
666
826
|
self.decode_mem_cache_buf_multiplier = (
|
667
827
|
1
|
668
828
|
if self.spec_algorithm.is_none()
|
669
829
|
else (
|
670
830
|
server_args.speculative_num_draft_tokens
|
671
831
|
+ (
|
672
|
-
server_args.speculative_eagle_topk
|
673
|
-
* server_args.speculative_num_steps
|
832
|
+
(server_args.speculative_eagle_topk or 1)
|
833
|
+
* (server_args.speculative_num_steps or 1)
|
674
834
|
)
|
675
835
|
)
|
676
836
|
)
|
@@ -693,7 +853,7 @@ class Scheduler(
|
|
693
853
|
self.disagg_metadata_buffers = MetadataBuffers(
|
694
854
|
buffer_size,
|
695
855
|
hidden_size=self.model_config.hf_text_config.hidden_size,
|
696
|
-
|
856
|
+
hidden_states_dtype=self.model_config.dtype,
|
697
857
|
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
|
698
858
|
)
|
699
859
|
|
@@ -713,7 +873,7 @@ class Scheduler(
|
|
713
873
|
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
714
874
|
draft_token_to_kv_pool=(
|
715
875
|
None
|
716
|
-
if self.draft_worker is None
|
876
|
+
if self.draft_worker is None or self.spec_algorithm.is_ngram()
|
717
877
|
else self.draft_worker.model_runner.token_to_kv_pool
|
718
878
|
),
|
719
879
|
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
@@ -742,7 +902,7 @@ class Scheduler(
|
|
742
902
|
self.disagg_metadata_buffers = MetadataBuffers(
|
743
903
|
buffer_size,
|
744
904
|
hidden_size=self.model_config.hf_text_config.hidden_size,
|
745
|
-
|
905
|
+
hidden_states_dtype=self.model_config.dtype,
|
746
906
|
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
|
747
907
|
)
|
748
908
|
|
@@ -750,7 +910,7 @@ class Scheduler(
|
|
750
910
|
token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
|
751
911
|
draft_token_to_kv_pool=(
|
752
912
|
None
|
753
|
-
if self.draft_worker is None
|
913
|
+
if self.draft_worker is None or self.spec_algorithm.is_ngram()
|
754
914
|
else self.draft_worker.model_runner.token_to_kv_pool
|
755
915
|
),
|
756
916
|
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
@@ -845,7 +1005,6 @@ class Scheduler(
|
|
845
1005
|
self.running_mbs = [
|
846
1006
|
ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
|
847
1007
|
]
|
848
|
-
bids = [None] * self.pp_size
|
849
1008
|
pp_outputs: Optional[PPProxyTensors] = None
|
850
1009
|
while True:
|
851
1010
|
server_is_idle = True
|
@@ -866,10 +1025,7 @@ class Scheduler(
|
|
866
1025
|
# (last rank) send the outputs to the next step
|
867
1026
|
if self.pp_group.is_last_rank:
|
868
1027
|
if self.cur_batch:
|
869
|
-
next_token_ids
|
870
|
-
result.next_token_ids,
|
871
|
-
result.bid,
|
872
|
-
)
|
1028
|
+
next_token_ids = result.next_token_ids
|
873
1029
|
if self.cur_batch.return_logprob:
|
874
1030
|
pp_outputs = PPProxyTensors(
|
875
1031
|
{
|
@@ -917,17 +1073,10 @@ class Scheduler(
|
|
917
1073
|
logits_output = LogitsProcessorOutput(**logits_output_args)
|
918
1074
|
else:
|
919
1075
|
logits_output = None
|
920
|
-
|
1076
|
+
|
1077
|
+
output_result = GenerationBatchResult.from_pp_proxy(
|
921
1078
|
logits_output=logits_output,
|
922
|
-
|
923
|
-
next_token_ids=next_pp_outputs["next_token_ids"],
|
924
|
-
extend_input_len_per_req=next_pp_outputs.tensors.get(
|
925
|
-
"extend_input_len_per_req", None
|
926
|
-
),
|
927
|
-
extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
|
928
|
-
"extend_logprob_start_len_per_req", None
|
929
|
-
),
|
930
|
-
bid=bids[next_mb_id],
|
1079
|
+
next_pp_outputs=next_pp_outputs,
|
931
1080
|
can_run_cuda_graph=result.can_run_cuda_graph,
|
932
1081
|
)
|
933
1082
|
self.process_batch_result(mbs[next_mb_id], output_result)
|
@@ -935,8 +1084,6 @@ class Scheduler(
|
|
935
1084
|
|
936
1085
|
# (not last rank)
|
937
1086
|
if not self.pp_group.is_last_rank:
|
938
|
-
if self.cur_batch:
|
939
|
-
bids[mb_id] = result.bid
|
940
1087
|
# carry the outputs to the next stage
|
941
1088
|
# send the outputs from the last round to let the next stage worker run post processing
|
942
1089
|
if pp_outputs:
|
@@ -958,8 +1105,10 @@ class Scheduler(
|
|
958
1105
|
|
959
1106
|
# send out proxy tensors to the next stage
|
960
1107
|
if self.cur_batch:
|
1108
|
+
# FIXME(lsyin): remove this assert
|
1109
|
+
assert result.pp_hidden_states_proxy_tensors.tensors is not None
|
961
1110
|
self.pp_group.send_tensor_dict(
|
962
|
-
result.pp_hidden_states_proxy_tensors,
|
1111
|
+
result.pp_hidden_states_proxy_tensors.tensors,
|
963
1112
|
all_gather_group=self.attn_tp_group,
|
964
1113
|
)
|
965
1114
|
|
@@ -1069,6 +1218,15 @@ class Scheduler(
|
|
1069
1218
|
self.tp_cpu_group,
|
1070
1219
|
src=self.tp_group.ranks[0],
|
1071
1220
|
)
|
1221
|
+
|
1222
|
+
if self.enable_trace:
|
1223
|
+
for req in recv_reqs:
|
1224
|
+
if isinstance(
|
1225
|
+
req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
|
1226
|
+
):
|
1227
|
+
trace_set_proc_propagate_context(req.rid, req.trace_context)
|
1228
|
+
trace_slice_start("", req.rid, anonymous=True)
|
1229
|
+
|
1072
1230
|
return recv_reqs
|
1073
1231
|
|
1074
1232
|
def process_input_requests(self, recv_reqs: List):
|
@@ -1082,27 +1240,13 @@ class Scheduler(
|
|
1082
1240
|
self.return_health_check_ct += 1
|
1083
1241
|
continue
|
1084
1242
|
|
1085
|
-
# If it is a
|
1086
|
-
if
|
1087
|
-
if len(self.waiting_queue) + 1 > self.max_queued_requests:
|
1088
|
-
abort_req = AbortReq(
|
1089
|
-
recv_req.rid,
|
1090
|
-
finished_reason={
|
1091
|
-
"type": "abort",
|
1092
|
-
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
|
1093
|
-
"message": "The request queue is full.",
|
1094
|
-
},
|
1095
|
-
)
|
1096
|
-
self.send_to_tokenizer.send_pyobj(abort_req)
|
1097
|
-
continue
|
1098
|
-
|
1099
|
-
# If it is a MultiTokenizerWarpper, unwrap it and handle the inner request.
|
1100
|
-
if isinstance(recv_req, MultiTokenizerWarpper):
|
1243
|
+
# If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
|
1244
|
+
if isinstance(recv_req, MultiTokenizerWrapper):
|
1101
1245
|
worker_id = recv_req.worker_id
|
1102
1246
|
recv_req = recv_req.obj
|
1103
1247
|
output = self._request_dispatcher(recv_req)
|
1104
1248
|
if output is not None:
|
1105
|
-
output =
|
1249
|
+
output = MultiTokenizerWrapper(worker_id, output)
|
1106
1250
|
self.send_to_tokenizer.send_pyobj(output)
|
1107
1251
|
continue
|
1108
1252
|
|
@@ -1114,12 +1258,20 @@ class Scheduler(
|
|
1114
1258
|
else:
|
1115
1259
|
self.send_to_tokenizer.send_pyobj(output)
|
1116
1260
|
|
1261
|
+
def init_req_max_new_tokens(self, req):
|
1262
|
+
req.sampling_params.max_new_tokens = min(
|
1263
|
+
(
|
1264
|
+
req.sampling_params.max_new_tokens
|
1265
|
+
if req.sampling_params.max_new_tokens is not None
|
1266
|
+
else 1 << 30
|
1267
|
+
),
|
1268
|
+
self.max_req_len - len(req.origin_input_ids) - 1,
|
1269
|
+
)
|
1270
|
+
|
1117
1271
|
def handle_generate_request(
|
1118
1272
|
self,
|
1119
1273
|
recv_req: TokenizedGenerateReqInput,
|
1120
1274
|
):
|
1121
|
-
self.maybe_update_dp_balance_data(recv_req)
|
1122
|
-
|
1123
1275
|
# Create a new request
|
1124
1276
|
if (
|
1125
1277
|
recv_req.session_params is None
|
@@ -1153,8 +1305,13 @@ class Scheduler(
|
|
1153
1305
|
bootstrap_host=recv_req.bootstrap_host,
|
1154
1306
|
bootstrap_port=recv_req.bootstrap_port,
|
1155
1307
|
bootstrap_room=recv_req.bootstrap_room,
|
1308
|
+
disagg_mode=self.disaggregation_mode,
|
1156
1309
|
data_parallel_rank=recv_req.data_parallel_rank,
|
1157
1310
|
vocab_size=self.model_config.vocab_size,
|
1311
|
+
priority=recv_req.priority,
|
1312
|
+
metrics_collector=(
|
1313
|
+
self.metrics_collector if self.enable_metrics else None
|
1314
|
+
),
|
1158
1315
|
)
|
1159
1316
|
req.tokenizer = self.tokenizer
|
1160
1317
|
|
@@ -1177,6 +1334,7 @@ class Scheduler(
|
|
1177
1334
|
req.set_finish_with_abort(
|
1178
1335
|
f"Invalid request: session id {recv_req.session_params.id} does not exist"
|
1179
1336
|
)
|
1337
|
+
self.init_req_max_new_tokens(req)
|
1180
1338
|
self._add_request_to_queue(req)
|
1181
1339
|
return
|
1182
1340
|
else:
|
@@ -1184,6 +1342,7 @@ class Scheduler(
|
|
1184
1342
|
session = self.sessions[recv_req.session_params.id]
|
1185
1343
|
req = session.create_req(recv_req, self.tokenizer)
|
1186
1344
|
if isinstance(req.finished_reason, FINISH_ABORT):
|
1345
|
+
self.init_req_max_new_tokens(req)
|
1187
1346
|
self._add_request_to_queue(req)
|
1188
1347
|
return
|
1189
1348
|
|
@@ -1203,9 +1362,13 @@ class Scheduler(
|
|
1203
1362
|
f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
|
1204
1363
|
)
|
1205
1364
|
)
|
1365
|
+
self.init_req_max_new_tokens(req)
|
1206
1366
|
self._add_request_to_queue(req)
|
1207
1367
|
return
|
1208
1368
|
|
1369
|
+
# initialize before returning
|
1370
|
+
self.init_req_max_new_tokens(req)
|
1371
|
+
|
1209
1372
|
# Validate prompt length
|
1210
1373
|
error_msg = validate_input_length(
|
1211
1374
|
req,
|
@@ -1220,26 +1383,25 @@ class Scheduler(
|
|
1220
1383
|
# Copy more attributes
|
1221
1384
|
if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
|
1222
1385
|
# By default, only return the logprobs for output tokens
|
1223
|
-
|
1386
|
+
# For prefill-only requests with logprob_start_len == -1, set logprob_start_len beyond input sequence
|
1387
|
+
# to skip input logprob computation entirely
|
1388
|
+
if req.is_prefill_only:
|
1389
|
+
req.logprob_start_len = len(req.origin_input_ids)
|
1390
|
+
else:
|
1391
|
+
# TODO: For text generation, evaluate setting logprob_start_len to len(req.origin_input_ids) as well
|
1392
|
+
req.logprob_start_len = len(req.origin_input_ids) - 1
|
1224
1393
|
else:
|
1225
1394
|
req.logprob_start_len = recv_req.logprob_start_len
|
1226
1395
|
|
1227
|
-
if req.logprob_start_len >= len(
|
1396
|
+
if not req.is_prefill_only and req.logprob_start_len >= len(
|
1397
|
+
req.origin_input_ids
|
1398
|
+
):
|
1228
1399
|
error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
|
1229
1400
|
req.logprob_start_len = len(req.origin_input_ids) - 1
|
1230
1401
|
req.set_finish_with_abort(error_msg)
|
1231
1402
|
self._add_request_to_queue(req)
|
1232
1403
|
return
|
1233
1404
|
|
1234
|
-
req.sampling_params.max_new_tokens = min(
|
1235
|
-
(
|
1236
|
-
req.sampling_params.max_new_tokens
|
1237
|
-
if req.sampling_params.max_new_tokens is not None
|
1238
|
-
else 1 << 30
|
1239
|
-
),
|
1240
|
-
self.max_req_len - len(req.origin_input_ids) - 1,
|
1241
|
-
)
|
1242
|
-
|
1243
1405
|
# Init grammar cache for this request
|
1244
1406
|
add_to_grammar_queue = False
|
1245
1407
|
if (
|
@@ -1270,7 +1432,6 @@ class Scheduler(
|
|
1270
1432
|
req.set_finish_with_abort(error_msg)
|
1271
1433
|
|
1272
1434
|
if add_to_grammar_queue:
|
1273
|
-
req.queue_time_start = time.perf_counter()
|
1274
1435
|
self.grammar_queue.append(req)
|
1275
1436
|
else:
|
1276
1437
|
self._add_request_to_queue(req)
|
@@ -1286,19 +1447,6 @@ class Scheduler(
|
|
1286
1447
|
for tokenized_req in recv_req:
|
1287
1448
|
self.handle_generate_request(tokenized_req)
|
1288
1449
|
|
1289
|
-
def _add_request_to_queue(self, req: Req):
|
1290
|
-
req.queue_time_start = time.perf_counter()
|
1291
|
-
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
1292
|
-
self._prefetch_kvcache(req)
|
1293
|
-
self.disagg_prefill_bootstrap_queue.add(
|
1294
|
-
req, self.model_config.num_key_value_heads
|
1295
|
-
)
|
1296
|
-
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
1297
|
-
self.disagg_decode_prealloc_queue.add(req)
|
1298
|
-
else:
|
1299
|
-
self._prefetch_kvcache(req)
|
1300
|
-
self.waiting_queue.append(req)
|
1301
|
-
|
1302
1450
|
def _prefetch_kvcache(self, req: Req):
|
1303
1451
|
if self.enable_hicache_storage:
|
1304
1452
|
req.init_next_round_input(self.tree_cache)
|
@@ -1312,16 +1460,87 @@ class Scheduler(
|
|
1312
1460
|
req.rid, req.last_host_node, new_input_tokens, last_hash
|
1313
1461
|
)
|
1314
1462
|
|
1315
|
-
def
|
1316
|
-
if self.disaggregation_mode == DisaggregationMode.
|
1317
|
-
self.
|
1318
|
-
|
1463
|
+
def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
|
1464
|
+
if self.disaggregation_mode == DisaggregationMode.NULL:
|
1465
|
+
self._set_or_validate_priority(req)
|
1466
|
+
if self._abort_on_queued_limit(req):
|
1467
|
+
return
|
1468
|
+
self._prefetch_kvcache(req)
|
1469
|
+
self.waiting_queue.append(req)
|
1470
|
+
req.time_stats.wait_queue_entry_time = time.perf_counter()
|
1471
|
+
trace_slice_end("process req", req.rid, auto_next_anon=True)
|
1472
|
+
elif self.disaggregation_mode == DisaggregationMode.PREFILL:
|
1473
|
+
self._prefetch_kvcache(req)
|
1474
|
+
self.disagg_prefill_bootstrap_queue.add(
|
1475
|
+
req, self.model_config.num_key_value_heads
|
1319
1476
|
)
|
1477
|
+
req.time_stats.prefill_bootstrap_queue_entry_time = time.perf_counter()
|
1320
1478
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
1321
|
-
|
1322
|
-
|
1479
|
+
self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted)
|
1480
|
+
if not is_retracted:
|
1481
|
+
req.time_stats.decode_prealloc_queue_entry_time = time.perf_counter()
|
1323
1482
|
else:
|
1324
|
-
self.
|
1483
|
+
raise ValueError(f"Invalid {self.disaggregation_mode=}")
|
1484
|
+
|
1485
|
+
def _set_or_validate_priority(self, req: Req):
|
1486
|
+
"""Set the default priority value, or abort the request based on the priority scheduling mode."""
|
1487
|
+
if self.enable_priority_scheduling and req.priority is None:
|
1488
|
+
if self.schedule_low_priority_values_first:
|
1489
|
+
req.priority = sys.maxsize
|
1490
|
+
else:
|
1491
|
+
req.priority = -sys.maxsize - 1
|
1492
|
+
elif not self.enable_priority_scheduling and req.priority is not None:
|
1493
|
+
abort_req = AbortReq(
|
1494
|
+
finished_reason={
|
1495
|
+
"type": "abort",
|
1496
|
+
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
|
1497
|
+
"message": "Using priority is disabled for this server. Please send a new request without a priority.",
|
1498
|
+
},
|
1499
|
+
rid=req.rid,
|
1500
|
+
)
|
1501
|
+
self.send_to_tokenizer.send_pyobj(abort_req)
|
1502
|
+
|
1503
|
+
def _abort_on_queued_limit(self, recv_req: Req) -> bool:
|
1504
|
+
"""Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted."""
|
1505
|
+
if (
|
1506
|
+
self.max_queued_requests is None
|
1507
|
+
or len(self.waiting_queue) + 1 <= self.max_queued_requests
|
1508
|
+
):
|
1509
|
+
return False
|
1510
|
+
|
1511
|
+
# Reject the incoming request by default.
|
1512
|
+
req_to_abort = recv_req
|
1513
|
+
message = "The request queue is full."
|
1514
|
+
if self.enable_priority_scheduling:
|
1515
|
+
# With priority scheduling, consider aboritng an existing request based on the priority.
|
1516
|
+
# direction = 1 => smaller number = higher priority; -1 => larger number = higher priority.
|
1517
|
+
# max(...) + (direction * priority, queue_time_start) picks the least-preferred request.
|
1518
|
+
# Tie: later queue_time_start (newer) is evicted first. Preempt only if strictly better.
|
1519
|
+
direction = 1 if self.schedule_low_priority_values_first else -1
|
1520
|
+
key_fn = lambda item: (
|
1521
|
+
direction * item[1].priority,
|
1522
|
+
item[1].time_stats.wait_queue_entry_time,
|
1523
|
+
)
|
1524
|
+
idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn)
|
1525
|
+
abort_existing_req = (
|
1526
|
+
direction * recv_req.priority < direction * candidate_req.priority
|
1527
|
+
)
|
1528
|
+
if abort_existing_req:
|
1529
|
+
self.waiting_queue.pop(idx)
|
1530
|
+
req_to_abort = candidate_req
|
1531
|
+
message = "The request is aborted by a higher priority request."
|
1532
|
+
|
1533
|
+
self.send_to_tokenizer.send_pyobj(
|
1534
|
+
AbortReq(
|
1535
|
+
finished_reason={
|
1536
|
+
"type": "abort",
|
1537
|
+
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
|
1538
|
+
"message": message,
|
1539
|
+
},
|
1540
|
+
rid=req_to_abort.rid,
|
1541
|
+
)
|
1542
|
+
)
|
1543
|
+
return req_to_abort.rid == recv_req.rid
|
1325
1544
|
|
1326
1545
|
def handle_embedding_request(
|
1327
1546
|
self,
|
@@ -1333,6 +1552,7 @@ class Scheduler(
|
|
1333
1552
|
recv_req.input_ids,
|
1334
1553
|
recv_req.sampling_params,
|
1335
1554
|
token_type_ids=recv_req.token_type_ids,
|
1555
|
+
priority=recv_req.priority,
|
1336
1556
|
)
|
1337
1557
|
req.tokenizer = self.tokenizer
|
1338
1558
|
|
@@ -1409,9 +1629,11 @@ class Scheduler(
|
|
1409
1629
|
_, _, available_size, evictable_size = self._get_token_info()
|
1410
1630
|
protected_size = self.tree_cache.protected_size()
|
1411
1631
|
memory_leak = (available_size + evictable_size) != (
|
1632
|
+
# self.max_total_num_tokens
|
1633
|
+
# if not self.enable_hierarchical_cache
|
1634
|
+
# else self.max_total_num_tokens - protected_size
|
1412
1635
|
self.max_total_num_tokens
|
1413
|
-
|
1414
|
-
else self.max_total_num_tokens - protected_size
|
1636
|
+
- protected_size
|
1415
1637
|
)
|
1416
1638
|
token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
|
1417
1639
|
|
@@ -1462,6 +1684,20 @@ class Scheduler(
         self.stats.gen_throughput = 0
         self.stats.num_queue_reqs = len(self.waiting_queue)
         self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.stats.num_prefill_prealloc_queue_reqs = len(
+                self.disagg_prefill_bootstrap_queue.queue
+            )
+            self.stats.num_prefill_inflight_queue_reqs = len(
+                self.disagg_prefill_inflight_queue
+            )
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.stats.num_decode_prealloc_queue_reqs = len(
+                self.disagg_decode_prealloc_queue.queue
+            )
+            self.stats.num_decode_transfer_queue_reqs = len(
+                self.disagg_decode_transfer_queue.queue
+            )
         self.metrics_collector.log_stats(self.stats)
         self._publish_kv_events()

@@ -1509,7 +1745,12 @@ class Scheduler(
             chunked_req_to_exclude.add(self.chunked_req)
             self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             # chunked request keeps its rid but will get a new req_pool_idx
-            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
+            if self.tp_worker.worker.model_runner.is_hybrid_gdn:
+                self.req_to_token_pool.free(
+                    self.chunked_req.req_pool_idx, free_mamba_cache=False
+                )
+            else:
+                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
         if self.last_batch and self.last_batch.forward_mode.is_extend():
             if self.last_batch.chunked_req is not None:
                 # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req.
@@ -1556,7 +1797,6 @@ class Scheduler(

         # Handle DP attention
         if need_dp_attn_preparation:
-            self.maybe_handle_dp_balance_data()
             ret = self.prepare_mlp_sync_batch(ret)

         return ret
@@ -1572,6 +1812,10 @@ class Scheduler(
         if self.grammar_queue:
             self.move_ready_grammar_requests()

+        if self.try_preemption:
+            # Reset batch_is_full to try preemption with a prefill adder.
+            self.running_batch.batch_is_full = False
+
         # Handle the cases where prefill is not allowed
         if (
             self.running_batch.batch_is_full or len(self.waiting_queue) == 0
@@ -1584,7 +1828,11 @@ class Scheduler(
             # as the space for the chunked request has just been released.
             # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
             # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
-            if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
+            if (
+                self.get_num_allocatable_reqs(running_bs) <= 0
+                and not self.chunked_req
+                and not self.try_preemption
+            ):
                 self.running_batch.batch_is_full = True
                 return None

@@ -1604,6 +1852,7 @@ class Scheduler(
             self.max_prefill_tokens,
             self.chunked_prefill_size,
             running_bs if self.is_mixed_chunk else 0,
+            self.priority_scheduling_preemption_threshold,
         )

         if self.chunked_req is not None:
@@ -1624,15 +1873,19 @@ class Scheduler(
                 self.running_batch.batch_is_full = True
                 break

+            running_bs = len(self.running_batch.reqs) - len(adder.preempt_list)
             if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                 self.running_batch.batch_is_full = True
-                break
-
             if self.disaggregation_mode == DisaggregationMode.PREFILL:
                 # In prefill mode, prealloc queue and transfer queue can also take memory,
                 # so we need to check if the available size for the actual available size.
                 if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                     self.running_batch.batch_is_full = True
+
+            if self.running_batch.batch_is_full:
+                if not self.try_preemption:
+                    break
+                if not adder.preempt_to_schedule(req, self.server_args):
                     break

             if self.enable_hicache_storage:
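The hunk above changes "queue is full, stop" into "queue is full, try to preempt first". A rough sketch of that pattern follows; `SimpleAdder` and its fields are illustrative stand-ins, not sglang's PrefillAdder API, and the priority ordering (lower value wins) is assumed.

```python
# Minimal sketch of "preempt instead of giving up" when capacity is exhausted.
class SimpleAdder:
    def __init__(self, capacity):
        self.capacity = capacity
        self.running = []       # assumed: lower priority value = more important
        self.preempt_list = []  # caller re-queues these afterwards

    def preempt_to_schedule(self, req):
        """Evict the least-important running request if the new one is strictly better."""
        if not self.running:
            return False
        victim = max(self.running, key=lambda r: r["priority"])
        if req["priority"] >= victim["priority"]:
            return False
        self.running.remove(victim)
        self.preempt_list.append(victim)
        return True

    def add(self, req):
        if len(self.running) >= self.capacity and not self.preempt_to_schedule(req):
            return False
        self.running.append(req)
        return True

adder = SimpleAdder(capacity=1)
assert adder.add({"rid": "a", "priority": 5})
assert adder.add({"rid": "b", "priority": 1})            # preempts "a"
assert [r["rid"] for r in adder.preempt_list] == ["a"]   # "a" goes back to the queue
```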
@@ -1642,7 +1895,11 @@ class Scheduler(
                 continue

             req.init_next_round_input(self.tree_cache)
-            res = adder.add_one_req(
+            res = adder.add_one_req(
+                req,
+                has_chunked_req=(self.chunked_req is not None),
+                truncation_align_size=self.truncation_align_size,
+            )

             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
@@ -1663,11 +1920,14 @@ class Scheduler(
         if self.enable_metrics:
             # only record queue time when enable_metrics is True to avoid overhead
             for req in can_run_list:
-                req.
+                req.add_latency(RequestStage.PREFILL_WAITING)

         self.waiting_queue = [
             x for x in self.waiting_queue if x not in set(can_run_list)
         ]
+        if adder.preempt_list:
+            for req in adder.preempt_list:
+                self._add_request_to_queue(req)

         if adder.new_chunked_req is not None:
             assert self.chunked_req is None
@@ -1678,7 +1938,16 @@ class Scheduler(

         # Print stats
         if self.current_scheduler_metrics_enabled():
-            self.log_prefill_stats(adder, can_run_list, running_bs)
+            self.log_prefill_stats(adder, can_run_list, running_bs, 0)
+
+        for req in can_run_list:
+            if req.time_stats.forward_entry_time == 0:
+                # Avoid update chunked request many times
+                req.time_stats.forward_entry_time = time.perf_counter()
+                if self.enable_metrics:
+                    self.metrics_collector.observe_queue_time(
+                        req.time_stats.get_queueing_time(),
+                    )

         # Create a new batch
         new_batch = ScheduleBatch.init_new(
@@ -1733,19 +2002,25 @@ class Scheduler(
             TEST_RETRACT and batch.batch_size() > 10
         ):
             old_ratio = self.new_token_ratio
-
-
-
+            retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
+                self.server_args
+            )
+            self.num_retracted_reqs = len(retracted_reqs)
             self.new_token_ratio = new_token_ratio
+            for req in reqs_to_abort:
+                self.send_to_tokenizer.send_pyobj(
+                    AbortReq(abort_reason=req.to_abort_message, rid=req.rid)
+                )

             logger.info(
                 "KV cache pool is full. Retract requests. "
-                f"#retracted_reqs: {
-                f"#
+                f"#retracted_reqs: {len(retracted_reqs)}, "
+                f"#aborted_retracted_reqs: {len(reqs_to_abort)}, "
+                f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
             )

-
-
+            for req in retracted_reqs:
+                self._add_request_to_queue(req, is_retracted=True)
         else:
             self.new_token_ratio = max(
                 self.new_token_ratio - self.new_token_ratio_decay,
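The retract hunk above splits running requests under KV-pool pressure into two groups: requests that are re-queued and requests that are aborted. A rough sketch of that split, using plain dicts as hypothetical requests rather than sglang's Req/ScheduleBatch types:

```python
# Sketch of the retract-and-requeue flow: retryable requests go back to the
# waiting queue; the rest are reported as aborted to the caller.
def retract_for_memory(running, waiting, can_retry):
    retracted, aborted = [], []
    for req in list(running):
        running.remove(req)
        (retracted if can_retry(req) else aborted).append(req)
    # Retracted requests are re-queued so they restart from prefill later.
    waiting.extend(retracted)
    return retracted, aborted

running = [{"rid": "a", "retryable": True}, {"rid": "b", "retryable": False}]
waiting = []
retracted, aborted = retract_for_memory(running, waiting, lambda r: r["retryable"])
assert [r["rid"] for r in waiting] == ["a"]
assert [r["rid"] for r in aborted] == ["b"]
```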
@@ -1773,37 +2048,25 @@ class Scheduler(

         # Run forward
         if self.is_generation:
+
+            batch_or_worker_batch = batch
+
             if self.spec_algorithm.is_none():
-
+                # FIXME(lsyin): remove this if and finally unify the abstraction
+                batch_or_worker_batch = batch.get_model_worker_batch()

-
-
-
+            forward_batch_output = self.model_worker.forward_batch_generation(
+                batch_or_worker_batch
+            )
+
+            if not self.spec_algorithm.is_none():
+                # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
+                self.udpate_spec_metrics(
+                    batch.batch_size(), forward_batch_output.num_accepted_tokens
                 )
-
-
-
-                    )
-                else:
-                    pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
-                        self.tp_worker.forward_batch_generation(model_worker_batch)
-                    )
-                bid = model_worker_batch.bid
-            else:
-                (
-                    logits_output,
-                    next_token_ids,
-                    bid,
-                    num_accepted_tokens,
-                    can_run_cuda_graph,
-                ) = self.draft_worker.forward_batch_speculative_generation(batch)
-                bs = batch.batch_size()
-                self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
-                self.spec_num_total_forward_ct += bs
-                self.num_generated_tokens += num_accepted_tokens
-
-            if self.pp_group.is_last_rank:
-                batch.output_ids = next_token_ids
+
+            # update batch's output ids
+            batch.output_ids = forward_batch_output.next_token_ids

         # These 2 values are needed for processing the output, but the values can be
         # modified by overlap schedule. So we have to copy them here so that
@@ -1812,6 +2075,7 @@ class Scheduler(
             extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
         else:
             extend_input_len_per_req = None
+
         if batch.return_logprob:
             extend_logprob_start_len_per_req = [
                 req.extend_logprob_start_len for req in batch.reqs
@@ -1819,25 +2083,15 @@ class Scheduler(
             else:
                 extend_logprob_start_len_per_req = None

-
-
-                pp_hidden_states_proxy_tensors=(
-                    pp_hidden_states_proxy_tensors
-                    if not self.pp_group.is_last_rank
-                    else None
-                ),
-                next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
+            return GenerationBatchResult.from_forward_batch_output(
+                forward_batch_output=forward_batch_output,
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
-                bid=bid,
-                can_run_cuda_graph=can_run_cuda_graph,
             )
         else: # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            ret = EmbeddingBatchResult(
-                embeddings=embeddings, bid=model_worker_batch.bid
-            )
+            ret = EmbeddingBatchResult(embeddings=embeddings)
         return ret

     def process_batch_result(
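The two hunks above converge the speculative and non-speculative paths on a single forward-output object that is then wrapped into the batch result by a classmethod. A minimal sketch of that shape, with illustrative names (`ForwardOutput`, `BatchResult`) rather than sglang's actual classes:

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ForwardOutput:
    next_token_ids: List[int]
    num_accepted_tokens: Optional[int] = None  # only set by speculative decoding

@dataclass
class BatchResult:
    next_token_ids: List[int]
    extend_input_len_per_req: Optional[List[int]] = None

    @classmethod
    def from_forward_output(cls, out: ForwardOutput, extend_input_len_per_req=None):
        # One constructor for both paths: the result no longer cares whether
        # the output came from a speculative or a plain forward pass.
        return cls(
            next_token_ids=out.next_token_ids,
            extend_input_len_per_req=extend_input_len_per_req,
        )

out = ForwardOutput(next_token_ids=[7, 11])
result = BatchResult.from_forward_output(out, extend_input_len_per_req=[3, 4])
assert result.next_token_ids == [7, 11]
```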
@@ -1848,8 +2102,14 @@ class Scheduler(
     ):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result, launch_done)
+            if self.enable_trace:
+                trace_slice_batch("decode loop", batch.reqs)
+
         elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result, launch_done)
+            if self.enable_trace:
+                trace_slice_batch("prefill", batch.reqs)
+
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
                 self.tp_worker.resolve_last_batch_result(launch_done)
@@ -2008,12 +2268,13 @@ class Scheduler(
                 if req.finished(): # It is aborted by AbortReq
                     num_ready_reqs += 1
                     continue
+
                 req.grammar = req.grammar.result(timeout=0.03)
                 self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                 if req.grammar is INVALID_GRAMMAR_OBJ:
-                    req.
-
-
+                    error_msg = f"Invalid grammar request: {req.grammar_key=}"
+                    req.set_finish_with_abort(error_msg)
+
                 num_ready_reqs += 1
             except futures._base.TimeoutError:
                 req.grammar_wait_ct += 1
@@ -2045,9 +2306,8 @@ class Scheduler(
                 req.grammar = req.grammar.result()
                 self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                 if req.grammar is INVALID_GRAMMAR_OBJ:
-                    req.
-
-                    )
+                    error_msg = f"Invalid grammar request: {req.grammar_key=}"
+                    req.set_finish_with_abort(error_msg)
             else:
                 num_ready_reqs_max = num_ready_reqs
                 num_timeout_reqs_max = num_timeout_reqs
@@ -2055,12 +2315,14 @@ class Scheduler(
         for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
             req = self.grammar_queue[i]
             req.grammar.cancel()
+            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
             error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
             req.set_finish_with_abort(error_msg)
-
+
         num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max

-        self.
+        for req in self.grammar_queue[:num_ready_reqs]:
+            self._add_request_to_queue(req)
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]

     def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
@@ -2152,9 +2414,8 @@ class Scheduler(
         self.req_to_token_pool.clear()
         self.token_to_kv_pool_allocator.clear()

-        if
-            self.draft_worker.
-            self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()
+        if self.draft_worker:
+            self.draft_worker.clear_cache_pool()

         self.num_generated_tokens = 0
         self.forward_ct_decode = 0
@@ -2174,39 +2435,50 @@ class Scheduler(
             if_success = False
         return if_success

-    def get_load(self):
+    def get_load(self, recv_req: GetLoadReqInput = None) -> GetLoadReqOutput:
         # TODO(lsyin): use dynamically maintained num_waiting_tokens
+
         if self.is_hybrid:
-
+            num_tokens_full = (
                 self.full_tokens_per_layer
                 - self.token_to_kv_pool_allocator.full_available_size()
                 - self.tree_cache.full_evictable_size()
             )
-
+            num_tokens_swa = (
                 self.swa_tokens_per_layer
                 - self.token_to_kv_pool_allocator.swa_available_size()
                 - self.tree_cache.swa_evictable_size()
             )
-
+            num_tokens = max(num_tokens_full, num_tokens_swa)
         else:
-
+            num_tokens = (
                 self.max_total_num_tokens
                 - self.token_to_kv_pool_allocator.available_size()
                 - self.tree_cache.evictable_size()
             )
-
+
+        # Tokens in waiting queue, bootstrap queue, prealloc queue
+        num_tokens += sum(len(req.origin_input_ids) for req in self.waiting_queue)
+        num_waiting_reqs = len(self.waiting_queue)
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-
+            num_tokens += sum(
                 len(req.origin_input_ids)
                 for req in self.disagg_prefill_bootstrap_queue.queue
             )
+            num_waiting_reqs += len(self.disagg_prefill_bootstrap_queue.queue)
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
-
+            num_tokens += sum(
                 len(req.req.origin_input_ids)
                 for req in self.disagg_decode_prealloc_queue.queue
            )
+            num_waiting_reqs += len(self.disagg_decode_prealloc_queue.queue)

-        return
+        return GetLoadReqOutput(
+            dp_rank=self.dp_rank,
+            num_reqs=len(self.running_batch.reqs) + num_waiting_reqs,
+            num_waiting_reqs=num_waiting_reqs,
+            num_tokens=num_tokens,
+        )

     def get_internal_state(self, recv_req: GetInternalStateReq):
         ret = dict(global_server_args_dict)
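The reworked `get_load` above counts tokens held by active KV plus the prompts still waiting in the various queues. A back-of-the-envelope sketch with hypothetical numbers (the real method reads allocator and cache state directly):

```python
# Assumed capacity and pool state, only to illustrate the accounting.
max_total_num_tokens = 10_000
available = 6_000          # free KV slots
evictable = 1_500          # cached but reclaimable slots
in_use = max_total_num_tokens - available - evictable    # 2,500 tokens held by active KV

waiting_prompt_lens = [128, 512, 64]                      # queued requests not yet scheduled
num_tokens = in_use + sum(waiting_prompt_lens)            # pending work measured in tokens
num_waiting_reqs = len(waiting_prompt_lens)

load = {"num_waiting_reqs": num_waiting_reqs, "num_tokens": num_tokens}
assert load == {"num_waiting_reqs": 3, "num_tokens": 3204}
```

Reporting both a request count and a token count lets a router weigh a few long prompts differently from many short ones.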
@@ -2221,10 +2493,9 @@ class Scheduler(
             "token_capacity": int(self.max_total_num_tokens),
         }

-
-
-
-        )
+        ret["memory_usage"]["graph"] = round(
+            self.tp_worker.worker.model_runner.graph_mem_usage, 2
+        )

         if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
             ret["avg_spec_accept_length"] = (
@@ -2233,8 +2504,6 @@ class Scheduler(
         if RECORD_STEP_TIME:
             ret["step_time_dict"] = self.step_time_dict

-        ret["load"] = self.get_load()
-
         return GetInternalStateReqOutput(internal_state=ret)

     def set_internal_state(self, recv_req: SetInternalStateReq):
@@ -2310,7 +2579,7 @@ class Scheduler(
                 if self.enable_hicache_storage:
                     # to release prefetch events associated with the request
                     self.tree_cache.release_aborted_request(req.rid)
-                self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
+                self.send_to_tokenizer.send_pyobj(AbortReq(rid=req.rid))
                 # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
                 if self.disaggregation_mode == DisaggregationMode.DECODE:
                     self.tree_cache.cache_finished_req(req)
@@ -2331,31 +2600,31 @@ class Scheduler(
         # Delete requests not in the waiting queue when PD disaggregation is enabled
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             # Abort requests that have not yet been bootstrapped
-            for
-                logger.debug(f"Abort bootstrap queue request. {req.rid=}")
+            for req in self.disagg_prefill_bootstrap_queue.queue:
                 if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort bootstrap queue request. {req.rid=}")
                     if hasattr(req.disagg_kv_sender, "abort"):
                         req.disagg_kv_sender.abort()

             # Abort in-flight requests
-            for
-                logger.debug(f"Abort inflight queue request. {req.rid=}")
+            for req in self.disagg_prefill_inflight_queue:
                 if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort inflight queue request. {req.rid=}")
                     if hasattr(req.disagg_kv_sender, "abort"):
                         req.disagg_kv_sender.abort()

         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             # Abort requests that have not yet finished preallocation
-            for
-                logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
+            for decode_req in self.disagg_decode_prealloc_queue.queue:
                 if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
                     if hasattr(decode_req.kv_receiver, "abort"):
                         decode_req.kv_receiver.abort()

             # Abort requests waiting for kvcache to release tree cache
-            for
-                logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
+            for decode_req in self.disagg_decode_transfer_queue.queue:
                 if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
                     if hasattr(decode_req.kv_receiver, "abort"):
                         decode_req.kv_receiver.abort()

@@ -2398,6 +2667,22 @@ class Scheduler(
         self.send_to_detokenizer.send_pyobj(recv_req)
         return recv_req

+    def init_weights_send_group_for_remote_instance(
+        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
+    ):
+        """Init the seed and client instance communication group."""
+        success, message = self.tp_worker.init_weights_send_group_for_remote_instance(
+            recv_req
+        )
+        return InitWeightsSendGroupForRemoteInstanceReqOutput(success, message)
+
+    def send_weights_to_remote_instance(
+        self, recv_req: SendWeightsToRemoteInstanceReqInput
+    ):
+        """Send the seed instance weights to the destination instance."""
+        success, message = self.tp_worker.send_weights_to_remote_instance(recv_req)
+        return SendWeightsToRemoteInstanceReqOutput(success, message)
+
     def slow_down(self, recv_req: SlowDownReqInput):
         t = recv_req.forward_sleep_time
         if t is not None and t <= 0:
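Both new handlers above follow a simple request -> (success, message) -> typed response pattern: delegate to the worker, wrap the tuple, and hand it back to the caller. A generic sketch with hypothetical names (not the sglang request/output classes):

```python
from dataclasses import dataclass

@dataclass
class WeightsTransferResult:
    success: bool
    message: str

def handle_weights_transfer(worker_call, recv_req):
    """Delegate to a worker callable and wrap its (success, message) tuple."""
    success, message = worker_call(recv_req)
    return WeightsTransferResult(success, message)

# Stubbed worker call standing in for the tp_worker delegation.
result = handle_weights_transfer(lambda req: (True, f"sent weights for {req}"), "group-0")
assert result.success and "group-0" in result.message
```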
@@ -2406,11 +2691,12 @@ class Scheduler(
         return SlowDownReqOutput()

     def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
-
+        action = recv_req.action
+        if action == ExpertDistributionReqType.START_RECORD:
             get_global_expert_distribution_recorder().start_record()
-        elif
+        elif action == ExpertDistributionReqType.STOP_RECORD:
             get_global_expert_distribution_recorder().stop_record()
-        elif
+        elif action == ExpertDistributionReqType.DUMP_RECORD:
             get_global_expert_distribution_recorder().dump_record()
         else:
             raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
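The handler above is a straightforward dispatch on an explicit action enum. A small standalone sketch of that shape, using a local `Action` enum and recorder stub instead of sglang's `ExpertDistributionReqType` and global recorder:

```python
from enum import Enum, auto

class Action(Enum):
    START_RECORD = auto()
    STOP_RECORD = auto()
    DUMP_RECORD = auto()

class Recorder:
    def __init__(self):
        self.events = []
    def start_record(self): self.events.append("start")
    def stop_record(self): self.events.append("stop")
    def dump_record(self): self.events.append("dump")

def handle(recorder, action):
    # Explicit enum comparison; unknown values fail loudly instead of silently.
    if action == Action.START_RECORD:
        recorder.start_record()
    elif action == Action.STOP_RECORD:
        recorder.stop_record()
    elif action == Action.DUMP_RECORD:
        recorder.dump_record()
    else:
        raise ValueError(f"Unrecognized action: {action!r}")

rec = Recorder()
for a in (Action.START_RECORD, Action.STOP_RECORD, Action.DUMP_RECORD):
    handle(rec, a)
assert rec.events == ["start", "stop", "dump"]
```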
@@ -2493,7 +2779,8 @@ class IdleSleeper:


 def is_health_check_generate_req(recv_req):
-
+    rid = getattr(recv_req, "rid", None)
+    return rid is not None and rid.startswith("HEALTH_CHECK")


 def is_work_request(recv_req):
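The helper above keys off the request-id prefix and tolerates objects that carry no `rid` at all. A standalone sketch (the "HEALTH_CHECK" prefix comes from the diff; the request object here is a stub):

```python
class FakeReq:
    def __init__(self, rid=None):
        if rid is not None:
            self.rid = rid

def is_health_check(recv_req):
    rid = getattr(recv_req, "rid", None)
    return rid is not None and rid.startswith("HEALTH_CHECK")

assert is_health_check(FakeReq("HEALTH_CHECK_42"))
assert not is_health_check(FakeReq("user-req-1"))
assert not is_health_check(FakeReq())  # requests without a rid are never health checks
```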
@@ -2517,10 +2804,12 @@ def run_scheduler_process(
     pp_rank: int,
     dp_rank: Optional[int],
     pipe_writer,
-    balance_meta: Optional[DPBalanceMeta] = None,
 ):
-    # Generate the prefix
+    # Generate the logger prefix
     prefix = ""
+    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
+        # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
+        dp_rank = int(os.environ["SGLANG_DP_RANK"])
     if dp_rank is not None:
         prefix += f" DP{dp_rank}"
     if server_args.tp_size > 1:
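The `SGLANG_DP_RANK` fallback moves above prefix construction so that the logger prefix can include the DP rank even when it only arrives via the environment. A standalone sketch of the idea; `build_log_prefix` and the exact TP component of the prefix are assumptions, not the package's helper:

```python
import os

def build_log_prefix(dp_rank, tp_rank, tp_size, env=os.environ):
    # Resolve dp_rank from the environment first, then build the prefix.
    if dp_rank is None and "SGLANG_DP_RANK" in env:
        dp_rank = int(env["SGLANG_DP_RANK"])
    prefix = ""
    if dp_rank is not None:
        prefix += f" DP{dp_rank}"
    if tp_size > 1:
        prefix += f" TP{tp_rank}"
    return prefix

assert build_log_prefix(None, 0, 2, env={"SGLANG_DP_RANK": "3"}) == " DP3 TP0"
assert build_log_prefix(None, 1, 1, env={}) == ""
```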
@@ -2536,10 +2825,6 @@ def run_scheduler_process(
     kill_itself_when_parent_died()
     parent_process = psutil.Process().parent()

-    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
-    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
-        dp_rank = int(os.environ["SGLANG_DP_RANK"])
-
     # Configure the logger
     configure_logger(server_args, prefix=prefix)
     suppress_other_loggers()
@@ -2547,6 +2832,15 @@ def run_scheduler_process(
     # Set cpu affinity to this gpu process
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+    if (numa_node := server_args.numa_node) is not None:
+        numa_bind_to_node(numa_node[gpu_id])
+
+    # Set up tracing
+    if server_args.enable_trace:
+        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
+        if server_args.disaggregation_mode == "null":
+            thread_label = "Scheduler"
+            trace_set_thread_info(thread_label, tp_rank, dp_rank)

     # Create a scheduler and run the event loop
     try:
@@ -2558,7 +2852,6 @@ def run_scheduler_process(
             moe_ep_rank,
             pp_rank,
             dp_rank,
-            dp_balance_meta=balance_meta,
         )
         pipe_writer.send(
             {