sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py (CHANGED)
```diff
@@ -25,12 +25,14 @@ from concurrent import futures
 from dataclasses import dataclass
 from http import HTTPStatus
 from types import SimpleNamespace
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Deque, Dict, List, Optional, Tuple, Union

 import psutil
 import setproctitle
 import torch
 import zmq
+from torch.cuda import Stream as CudaStream
+from torch.cuda import StreamContext as CudaStreamContext
 from torch.distributed import barrier

 from sglang.global_config import global_config
@@ -44,6 +46,9 @@ from sglang.srt.disaggregation.decode import (
     DecodeTransferQueue,
     SchedulerDisaggregationDecodeMixin,
 )
+from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
+    DecodeKVCacheOffloadManager,
+)
 from sglang.srt.disaggregation.prefill import (
     PrefillBootstrapQueue,
     SchedulerDisaggregationPrefillMixin,
@@ -57,11 +62,6 @@ from sglang.srt.disaggregation.utils import (
 )
 from sglang.srt.distributed import get_pp_group, get_world_group
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
-from sglang.srt.hf_transformers_utils import (
-    get_processor,
-    get_tokenizer,
-    get_tokenizer_from_processor,
-)
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe import initialize_moe_config
@@ -72,20 +72,26 @@ from sglang.srt.managers.io_struct import (
     ClearHiCacheReqInput,
     ClearHiCacheReqOutput,
     CloseSessionReqInput,
+    DestroyWeightsUpdateGroupReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
+    ExpertDistributionReqType,
     FlushCacheReqInput,
     FlushCacheReqOutput,
     FreezeGCReq,
     GetInternalStateReq,
     GetInternalStateReqOutput,
+    GetLoadReqInput,
+    GetLoadReqOutput,
     GetWeightsByNameReqInput,
     HealthCheckOutput,
+    InitWeightsSendGroupForRemoteInstanceReqInput,
+    InitWeightsSendGroupForRemoteInstanceReqOutput,
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
     MultiTokenizerRegisterReq,
-    MultiTokenizerWarpper,
+    MultiTokenizerWrapper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -93,6 +99,8 @@ from sglang.srt.managers.io_struct import (
     ResumeMemoryOccupationReqInput,
     RpcReqInput,
     RpcReqOutput,
+    SendWeightsToRemoteInstanceReqInput,
+    SendWeightsToRemoteInstanceReqOutput,
     SetInternalStateReq,
     SetInternalStateReqOutput,
     SlowDownReqInput,
@@ -106,10 +114,13 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.mm_utils import init_embedding_cache
+from sglang.srt.managers.overlap_utils import FutureIndices, FutureMap
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
+    ModelWorkerBatch,
     MultimodalInputs,
     Req,
+    RequestStage,
     ScheduleBatch,
     global_server_args_dict,
 )
@@ -132,19 +143,28 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
     SchedulerUpdateWeightsMixin,
 )
 from sglang.srt.managers.session_controller import Session
-from sglang.srt.managers.
-from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
-from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
+from sglang.srt.managers.utils import validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
-from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.tracing.trace import (
+    process_tracing_init,
+    trace_set_proc_propagate_context,
+    trace_set_thread_info,
+    trace_slice_batch,
+    trace_slice_end,
+    trace_slice_start,
+)
 from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
 from sglang.srt.utils import (
     DynamicGradMode,
@@ -155,9 +175,10 @@ from sglang.srt.utils import (
     freeze_gc,
     get_available_gpu_memory,
     get_bool_env_var,
+    get_int_env_var,
     get_zmq_socket,
-    is_cpu,
     kill_itself_when_parent_died,
+    numa_bind_to_node,
     point_to_point_pyobj,
     pyspy_dump_schedulers,
     require_mlp_sync,
@@ -166,6 +187,11 @@ from sglang.srt.utils import (
     set_random_seed,
     suppress_other_loggers,
 )
+from sglang.srt.utils.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -174,24 +200,67 @@ logger = logging.getLogger(__name__)
 TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
 GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))

-_is_cpu = is_cpu()
-

 @dataclass
 class GenerationBatchResult:
-    logits_output: Optional[LogitsProcessorOutput]
-    pp_hidden_states_proxy_tensors: Optional[
-    next_token_ids: Optional[
-
-
-
-
+    logits_output: Optional[LogitsProcessorOutput] = None
+    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] = None
+    next_token_ids: Optional[torch.Tensor] = None
+    num_accepted_tokens: Optional[int] = None
+    can_run_cuda_graph: bool = False
+
+    # For output processing
+    extend_input_len_per_req: Optional[List[int]] = None
+    extend_logprob_start_len_per_req: Optional[List[int]] = None
+
+    # For overlap scheduling
+    copy_done: Optional[torch.cuda.Event] = None
+    delay_sample_launch: bool = False
+    forward_batch: Optional[ForwardBatch] = None
+    future_indices: Optional[FutureIndices] = None
+
+    def copy_to_cpu(self, return_logprob: bool = False):
+        """Copy tensors to CPU in overlap scheduling.
+        Only the tensors which are needed for processing results are copied,
+        e.g., next_token_ids, logits outputs
+        """
+        if return_logprob:
+            if self.logits_output.next_token_logits is not None:
+                self.logits_output.next_token_logits = (
+                    self.logits_output.next_token_logits.to("cpu", non_blocking=True)
+                )
+            if self.logits_output.input_token_logprobs is not None:
+                self.logits_output.input_token_logprobs = (
+                    self.logits_output.input_token_logprobs.to("cpu", non_blocking=True)
+                )
+        if self.logits_output.hidden_states is not None:
+            self.logits_output.hidden_states = self.logits_output.hidden_states.to(
+                "cpu", non_blocking=True
+            )
+        self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True)
+        self.copy_done.record()
+
+    @classmethod
+    def from_pp_proxy(
+        cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
+    ):
+        # TODO(lsyin): refactor PP and avoid using dict
+        proxy_dict = next_pp_outputs.tensors
+        return cls(
+            logits_output=logits_output,
+            pp_hidden_states_proxy_tensors=None,
+            next_token_ids=next_pp_outputs["next_token_ids"],
+            extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
+            extend_logprob_start_len_per_req=proxy_dict.get(
+                "extend_logprob_start_len_per_req", None
+            ),
+            can_run_cuda_graph=can_run_cuda_graph,
+        )


 @dataclass
 class EmbeddingBatchResult:
     embeddings: torch.Tensor
-    bid: int


 class Scheduler(
```
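The `copy_to_cpu` helper added above relies on a standard PyTorch pattern: enqueue non-blocking device-to-host copies, then record a CUDA event (`copy_done`) so the scheduler can later check for completion without a full device synchronize. A minimal, self-contained sketch of that pattern (plain PyTorch; the function and tensor names are illustrative, not sglang's API):

```python
from typing import Tuple

import torch


def async_copy_to_cpu(next_token_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.cuda.Event]:
    """Queue a device-to-host copy and return an event that marks its completion."""
    # The copy is only truly asynchronous into pinned memory; with pageable memory
    # PyTorch may fall back to a synchronous copy, but the event pattern still holds.
    host_ids = next_token_ids.to("cpu", non_blocking=True)
    copy_done = torch.cuda.Event()
    copy_done.record()  # recorded on the current stream, after the copy was enqueued
    return host_ids, copy_done


if __name__ == "__main__" and torch.cuda.is_available():
    ids = torch.randint(0, 32000, (8,), device="cuda")
    host_ids, copy_done = async_copy_to_cpu(ids)
    copy_done.synchronize()  # a scheduler loop would poll copy_done.query() instead
    print(host_ids.tolist())
```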
```diff
@@ -204,6 +273,48 @@ class Scheduler(
 ):
     """A scheduler that manages a tensor parallel GPU worker."""

+    def launch_draft_worker(
+        self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
+    ):
+        if self.spec_algorithm.is_eagle():
+            from sglang.srt.speculative.eagle_worker import EAGLEWorker
+
+            self.draft_worker = EAGLEWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
+        elif self.spec_algorithm.is_standalone():
+            from sglang.srt.speculative.standalone_worker import StandaloneWorker
+
+            self.draft_worker = StandaloneWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
+        elif self.spec_algorithm.is_ngram():
+            from sglang.srt.speculative.ngram_worker import NGRAMWorker
+
+            self.draft_worker = NGRAMWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
+        else:
+            self.draft_worker = None
+
     def __init__(
         self,
         server_args: ServerArgs,
@@ -213,7 +324,6 @@ class Scheduler(
         moe_ep_rank: int,
         pp_rank: int,
         dp_rank: Optional[int],
-        dp_balance_meta: Optional[DPBalanceMeta] = None,
     ):
         # Parse args
         self.server_args = server_args
@@ -226,6 +336,13 @@ class Scheduler(
         self.pp_size = server_args.pp_size
         self.dp_size = server_args.dp_size
         self.schedule_policy = server_args.schedule_policy
+        self.enable_priority_scheduling = server_args.enable_priority_scheduling
+        self.schedule_low_priority_values_first = (
+            server_args.schedule_low_priority_values_first
+        )
+        self.priority_scheduling_preemption_threshold = (
+            server_args.priority_scheduling_preemption_threshold
+        )
         self.enable_lora = server_args.enable_lora
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = not server_args.disable_overlap_schedule
@@ -234,7 +351,10 @@ class Scheduler(
         self.enable_metrics_for_all_schedulers = (
             server_args.enable_metrics_for_all_schedulers
         )
-        self.enable_kv_cache_events =
+        self.enable_kv_cache_events = bool(
+            server_args.kv_events_config and tp_rank == 0
+        )
+        self.enable_trace = server_args.enable_trace
         self.stream_interval = server_args.stream_interval
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
@@ -320,12 +440,10 @@ class Scheduler(
             logger.info("Overlap scheduler is disabled for embedding models.")

         # Launch a tensor parallel worker
-        if self.enable_overlap:
-            TpWorkerClass = TpModelWorkerClient
-        else:
-            TpWorkerClass = TpModelWorker

-
+        from sglang.srt.managers.tp_worker import TpModelWorker
+
+        self.tp_worker = TpModelWorker(
             server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
@@ -336,20 +454,15 @@ class Scheduler(
         )

         # Launch a draft worker for speculative decoding
-
-
+        self.launch_draft_worker(
+            gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
+        )

-
-
-
-                moe_ep_rank=moe_ep_rank,
-                server_args=server_args,
-                nccl_port=port_args.nccl_port,
-                target_worker=self.tp_worker,
-                dp_rank=dp_rank,
-            )
+        # Dispatch the model worker
+        if self.spec_algorithm.is_none():
+            self.model_worker = self.tp_worker
         else:
-            self.
+            self.model_worker = self.draft_worker

         # Get token and memory info from the model worker
         (
@@ -366,8 +479,8 @@ class Scheduler(
             _,
             _,
         ) = self.tp_worker.get_worker_info()
-        if global_server_args_dict["
-            global_server_args_dict["
+        if global_server_args_dict["pp_max_micro_batch_size"] is None:
+            global_server_args_dict["pp_max_micro_batch_size"] = max(
                 self.max_running_requests // server_args.pp_size, 1
             )

@@ -401,7 +514,7 @@ class Scheduler(
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
             f"context_len={self.model_config.context_len}, "
-            f"available_gpu_mem={avail_mem:.2f} GB"
+            f"{'available_cpu_mem' if self.device == 'cpu' else 'available_gpu_mem'}={avail_mem:.2f} GB"
         )

         # Init memory pool and cache
@@ -427,9 +540,11 @@ class Scheduler(
         self.kv_transfer_speed_gb_s: float = 0.0
         self.kv_transfer_latency_ms: float = 0.0
         self.sessions: Dict[str, Session] = {}
-        self.
+        self.default_stream: CudaStream = torch.get_device_module(
+            self.device
+        ).current_stream()
         if self.device == "cpu":
-            self.
+            self.default_stream.synchronize = lambda: None  # No-op for CPU
         self.forward_sleep_time = None

         # Init chunked prefill
@@ -458,7 +573,12 @@ class Scheduler(
             self.schedule_policy,
             self.tree_cache,
             self.enable_hierarchical_cache,
+            self.enable_priority_scheduling,
+            self.schedule_low_priority_values_first,
         )
+        # Enable preemption for priority scheduling.
+        self.try_preemption = self.enable_priority_scheduling
+
         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
@@ -488,7 +608,7 @@ class Scheduler(
             enable=server_args.enable_memory_saver
         )
         self.offload_tags = set()
-        self.
+        self.init_profiler()

         self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
         self.input_blocker = (
@@ -499,8 +619,9 @@ class Scheduler(

         # Init metrics stats
         self.init_metrics(tp_rank, pp_rank, dp_rank)
-
-        self.
+
+        if self.enable_kv_cache_events:
+            self.init_kv_events(server_args.kv_events_config)

         # Init disaggregation
         self.disaggregation_mode = DisaggregationMode(
@@ -511,6 +632,12 @@ class Scheduler(
         if get_bool_env_var("SGLANG_GC_LOG"):
             configure_gc_logger()

+        # Init prefill kv split size when deterministic inference is enabled with various attention backends
+        self.init_deterministic_inference_config()
+
+        # Init overlap
+        self.init_overlap()
+
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
@@ -525,6 +652,15 @@ class Scheduler(
                 (CloseSessionReqInput, self.close_session),
                 (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                 (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
+                (DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group),
+                (
+                    InitWeightsSendGroupForRemoteInstanceReqInput,
+                    self.init_weights_send_group_for_remote_instance,
+                ),
+                (
+                    SendWeightsToRemoteInstanceReqInput,
+                    self.send_weights_to_remote_instance,
+                ),
                 (
                     UpdateWeightsFromDistributedReqInput,
                     self.update_weights_from_distributed,
@@ -543,9 +679,27 @@ class Scheduler(
                 (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                 (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
                 (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
+                (GetLoadReqInput, self.get_load),
             ]
         )

+    def init_deterministic_inference_config(self):
+        """Initialize deterministic inference configuration for different attention backends."""
+        if not self.server_args.enable_deterministic_inference:
+            self.truncation_align_size = None
+            return
+
+        backend_sizes = {
+            "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
+            "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
+        }
+        env_var, default_size = backend_sizes.get(
+            self.server_args.attention_backend, (None, None)
+        )
+        self.truncation_align_size = (
+            get_int_env_var(env_var, default_size) if env_var else None
+        )
+
     def init_tokenizer(self):
         server_args = self.server_args
         self.is_generation = self.model_config.is_generation
```
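The new `init_deterministic_inference_config` above resolves a per-backend prefill split/alignment size from an environment variable, defaulting to 4096. A rough standalone equivalent using plain `os.environ` (the helper below is hypothetical and only mirrors the mapping shown in the hunk; it is not sglang's API):

```python
import os
from typing import Optional

_BACKEND_ENV_VARS = {
    "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
    "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
}


def resolve_truncation_align_size(
    attention_backend: str, enable_deterministic_inference: bool
) -> Optional[int]:
    """Return the prefill split/alignment size, or None when it does not apply."""
    if not enable_deterministic_inference:
        return None
    env_var, default_size = _BACKEND_ENV_VARS.get(attention_backend, (None, None))
    if env_var is None:
        return None  # backend has no deterministic-prefill knob in this mapping
    return int(os.environ.get(env_var, default_size))


# Example: export SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE=8192 to override.
print(resolve_truncation_align_size("triton", True))  # 4096 unless overridden
print(resolve_truncation_align_size("fa3", True))     # None (no mapping above)
```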
```diff
@@ -617,15 +771,18 @@ class Scheduler(
                     else self.tp_cpu_group
                 ),
                 page_size=self.page_size,
+                eviction_policy=server_args.radix_eviction_policy,
                 hicache_ratio=server_args.hicache_ratio,
                 hicache_size=server_args.hicache_size,
                 hicache_write_policy=server_args.hicache_write_policy,
                 hicache_io_backend=server_args.hicache_io_backend,
                 hicache_mem_layout=server_args.hicache_mem_layout,
+                enable_metrics=self.enable_metrics,
                 hicache_storage_backend=server_args.hicache_storage_backend,
                 hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
                 model_name=server_args.served_model_name,
                 storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
+                is_eagle=self.spec_algorithm.is_eagle(),
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
                 self.tree_cache.cache_controller.layer_done_counter
@@ -640,19 +797,23 @@ class Scheduler(
                 sliding_window_size=self.sliding_window_size,
                 page_size=self.page_size,
                 disable=server_args.disable_radix_cache,
+                is_eagle=self.spec_algorithm.is_eagle(),
             )
-        elif
-
-
-            )
-
-
-            ), "LoRA radix cache only supports FCFS policy"
-            self.tree_cache = LoRARadixCache(
+        elif server_args.enable_lmcache:
+            from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
+                LMCRadixCache,
+            )
+
+            self.tree_cache = LMCRadixCache(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 page_size=self.page_size,
                 disable=server_args.disable_radix_cache,
+                model_config=self.model_config,
+                tp_size=self.tp_size,
+                rank=self.tp_rank,
+                tp_group=self.tp_group,
+                eviction_policy=server_args.radix_eviction_policy,
             )
         else:
             self.tree_cache = RadixCache(
@@ -661,16 +822,36 @@ class Scheduler(
                 page_size=self.page_size,
                 disable=server_args.disable_radix_cache,
                 enable_kv_cache_events=self.enable_kv_cache_events,
+                eviction_policy=server_args.radix_eviction_policy,
+                is_eagle=self.spec_algorithm.is_eagle(),
             )

+        if (
+            server_args.disaggregation_mode == "decode"
+            and server_args.disaggregation_decode_enable_offload_kvcache
+        ):
+            self.decode_offload_manager = DecodeKVCacheOffloadManager(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                tp_group=(
+                    self.attn_tp_cpu_group
+                    if self.server_args.enable_dp_attention
+                    else self.tp_cpu_group
+                ),
+                tree_cache=self.tree_cache,
+                server_args=self.server_args,
+            )
+        else:
+            self.decode_offload_manager = None
+
         self.decode_mem_cache_buf_multiplier = (
             1
             if self.spec_algorithm.is_none()
             else (
                 server_args.speculative_num_draft_tokens
                 + (
-                    server_args.speculative_eagle_topk
-                    * server_args.speculative_num_steps
+                    (server_args.speculative_eagle_topk or 1)
+                    * (server_args.speculative_num_steps or 1)
                 )
             )
         )
@@ -693,7 +874,7 @@ class Scheduler(
             self.disagg_metadata_buffers = MetadataBuffers(
                 buffer_size,
                 hidden_size=self.model_config.hf_text_config.hidden_size,
-
+                hidden_states_dtype=self.model_config.dtype,
                 custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
             )

@@ -713,7 +894,7 @@ class Scheduler(
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 draft_token_to_kv_pool=(
                     None
-                    if self.draft_worker is None
+                    if self.draft_worker is None or self.spec_algorithm.is_ngram()
                     else self.draft_worker.model_runner.token_to_kv_pool
                 ),
                 req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -742,7 +923,7 @@ class Scheduler(
             self.disagg_metadata_buffers = MetadataBuffers(
                 buffer_size,
                 hidden_size=self.model_config.hf_text_config.hidden_size,
-
+                hidden_states_dtype=self.model_config.dtype,
                 custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
             )

@@ -750,7 +931,7 @@ class Scheduler(
                 token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                 draft_token_to_kv_pool=(
                     None
-                    if self.draft_worker is None
+                    if self.draft_worker is None or self.spec_algorithm.is_ngram()
                     else self.draft_worker.model_runner.token_to_kv_pool
                 ),
                 req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -771,6 +952,32 @@ class Scheduler(
             # The prefill requests that are in the middle of kv sending
             self.disagg_prefill_inflight_queue: List[Req] = []

+    def init_overlap(self):
+        if not self.enable_overlap:
+            return
+
+        self.forward_stream: CudaStream = torch.get_device_module(self.device).Stream()
+        self.forward_stream_ctx: CudaStreamContext = torch.get_device_module(
+            self.device
+        ).stream(self.forward_stream)
+        self.copy_stream: CudaStream = torch.get_device_module(self.device).Stream()
+        self.copy_stream_ctx: CudaStreamContext = torch.get_device_module(
+            self.device
+        ).stream(self.copy_stream)
+
+        self.future_map = FutureMap(self.max_running_requests, self.device)
+        self.batch_record_buf = [None] * 2
+        self.batch_record_ct = 0
+
+    def record_batch_in_overlap(self, model_worker_batch: ModelWorkerBatch):
+        # FIXME(lsyin): hacky way to keep a reference to avoid GPU tensors being freed by torch GC
+        # NOTE: More Reliable: record all tensors into the forward stream
+        # NOTE: - for all future tensors, we shall always read from future map
+        #       - for all non-future tensors (produced only by schedule stream),
+        #         we shall keep its reference not being release during all the forwarding pass
+        self.batch_record_ct = (self.batch_record_ct + 1) % 2
+        self.batch_record_buf[self.batch_record_ct] = model_worker_batch
+
     def init_moe_config(self):
         if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
             initialize_moe_config(self.server_args)
```
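`init_overlap` above creates dedicated forward and copy streams through `torch.get_device_module(device).Stream()`. A minimal CUDA-only sketch of launching work on such a side stream and ordering it against the default stream (plain PyTorch; the tensors and names are placeholders, not the scheduler's objects):

```python
import torch

assert torch.cuda.is_available(), "this sketch requires a CUDA device"

forward_stream = torch.cuda.Stream()   # side stream, analogous to forward_stream above
default_stream = torch.cuda.current_stream()

x = torch.randn(4, 1024, device="cuda")
w = torch.randn(1024, 1024, device="cuda")

# Order the side stream after whatever produced x and w on the default stream.
forward_stream.wait_stream(default_stream)
with torch.cuda.stream(forward_stream):
    y = x @ w                          # enqueued on forward_stream, returns immediately

# Order the default stream after the side-stream work before reading y.
default_stream.wait_stream(forward_stream)
print(y.sum().item())
```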
```diff
@@ -797,9 +1004,11 @@ class Scheduler(
     @DynamicGradMode()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
-        self.result_queue = deque()
+        self.result_queue: Deque[Tuple[ScheduleBatch, GenerationBatchResult]] = deque()

         while True:
+            self.launch_last_batch_sample_if_needed()
+
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)

@@ -807,30 +1016,13 @@ class Scheduler(
             self.cur_batch = batch

             if batch:
-                batch.launch_done = threading.Event()
                 result = self.run_batch(batch)
                 self.result_queue.append((batch.copy(), result))

-                if self.last_batch is None:
-                    # Create a dummy first batch to start the pipeline for overlap schedule.
-                    # It is now used for triggering the sampling_info_done event.
-                    tmp_batch = ScheduleBatch(
-                        reqs=None,
-                        forward_mode=ForwardMode.DUMMY_FIRST,
-                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
-                    )
-                    self.process_batch_result(tmp_batch, None, batch.launch_done)
-
             if self.last_batch:
                 # Process the results of the last batch
                 tmp_batch, tmp_result = self.result_queue.popleft()
-                tmp_batch
-                    self.tp_worker.cur_sampling_info if batch else None
-                )
-                # NOTE: we should use current launched batch's launch_done event Instead of the last batch's
-                self.process_batch_result(
-                    tmp_batch, tmp_result, batch.launch_done if batch else None
-                )
+                self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
                 # When the server is idle, do self-check and re-init some states
                 self.self_check_during_idle()
```
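The reworked `event_loop_overlap` above launches the current batch, queues its result, and only processes the previous iteration's result, so CPU post-processing overlaps with the GPU forward that was just launched. A stripped-down sketch of that one-step pipeline using a plain deque (generic Python, not the scheduler's real types):

```python
from collections import deque
from typing import Deque, Optional, Tuple


def run_batch(batch: str) -> str:
    # Stand-in for the asynchronous GPU launch; returns a result handle.
    return f"result({batch})"


def process_batch_result(batch: str, result: str) -> None:
    print(f"processed {batch} -> {result}")


result_queue: Deque[Tuple[str, str]] = deque()
last_batch: Optional[str] = None

for batch in ["b0", "b1", "b2", None]:  # None models an idle tick that drains the queue
    if batch is not None:
        result_queue.append((batch, run_batch(batch)))  # launch the current batch
    if last_batch is not None:
        process_batch_result(*result_queue.popleft())   # finish the previous batch
    last_batch = batch
```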
@@ -845,7 +1037,6 @@ class Scheduler(
|
|
845
1037
|
self.running_mbs = [
|
846
1038
|
ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
|
847
1039
|
]
|
848
|
-
bids = [None] * self.pp_size
|
849
1040
|
pp_outputs: Optional[PPProxyTensors] = None
|
850
1041
|
while True:
|
851
1042
|
server_is_idle = True
|
@@ -866,10 +1057,7 @@ class Scheduler(
|
|
866
1057
|
# (last rank) send the outputs to the next step
|
867
1058
|
if self.pp_group.is_last_rank:
|
868
1059
|
if self.cur_batch:
|
869
|
-
next_token_ids
|
870
|
-
result.next_token_ids,
|
871
|
-
result.bid,
|
872
|
-
)
|
1060
|
+
next_token_ids = result.next_token_ids
|
873
1061
|
if self.cur_batch.return_logprob:
|
874
1062
|
pp_outputs = PPProxyTensors(
|
875
1063
|
{
|
@@ -917,17 +1105,10 @@ class Scheduler(
|
|
917
1105
|
logits_output = LogitsProcessorOutput(**logits_output_args)
|
918
1106
|
else:
|
919
1107
|
logits_output = None
|
920
|
-
|
1108
|
+
|
1109
|
+
output_result = GenerationBatchResult.from_pp_proxy(
|
921
1110
|
logits_output=logits_output,
|
922
|
-
|
923
|
-
next_token_ids=next_pp_outputs["next_token_ids"],
|
924
|
-
extend_input_len_per_req=next_pp_outputs.tensors.get(
|
925
|
-
"extend_input_len_per_req", None
|
926
|
-
),
|
927
|
-
extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
|
928
|
-
"extend_logprob_start_len_per_req", None
|
929
|
-
),
|
930
|
-
bid=bids[next_mb_id],
|
1111
|
+
next_pp_outputs=next_pp_outputs,
|
931
1112
|
can_run_cuda_graph=result.can_run_cuda_graph,
|
932
1113
|
)
|
933
1114
|
self.process_batch_result(mbs[next_mb_id], output_result)
|
@@ -935,8 +1116,6 @@ class Scheduler(
|
|
935
1116
|
|
936
1117
|
# (not last rank)
|
937
1118
|
if not self.pp_group.is_last_rank:
|
938
|
-
if self.cur_batch:
|
939
|
-
bids[mb_id] = result.bid
|
940
1119
|
# carry the outputs to the next stage
|
941
1120
|
# send the outputs from the last round to let the next stage worker run post processing
|
942
1121
|
if pp_outputs:
|
@@ -958,8 +1137,10 @@ class Scheduler(
|
|
958
1137
|
|
959
1138
|
# send out proxy tensors to the next stage
|
960
1139
|
if self.cur_batch:
|
1140
|
+
# FIXME(lsyin): remove this assert
|
1141
|
+
assert result.pp_hidden_states_proxy_tensors.tensors is not None
|
961
1142
|
self.pp_group.send_tensor_dict(
|
962
|
-
result.pp_hidden_states_proxy_tensors,
|
1143
|
+
result.pp_hidden_states_proxy_tensors.tensors,
|
963
1144
|
all_gather_group=self.attn_tp_group,
|
964
1145
|
)
|
965
1146
|
|
@@ -1069,6 +1250,15 @@ class Scheduler(
|
|
1069
1250
|
self.tp_cpu_group,
|
1070
1251
|
src=self.tp_group.ranks[0],
|
1071
1252
|
)
|
1253
|
+
|
1254
|
+
if self.enable_trace:
|
1255
|
+
for req in recv_reqs:
|
1256
|
+
if isinstance(
|
1257
|
+
req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
|
1258
|
+
):
|
1259
|
+
trace_set_proc_propagate_context(req.rid, req.trace_context)
|
1260
|
+
trace_slice_start("", req.rid, anonymous=True)
|
1261
|
+
|
1072
1262
|
return recv_reqs
|
1073
1263
|
|
1074
1264
|
def process_input_requests(self, recv_reqs: List):
|
@@ -1082,27 +1272,13 @@ class Scheduler(
|
|
1082
1272
|
self.return_health_check_ct += 1
|
1083
1273
|
continue
|
1084
1274
|
|
1085
|
-
# If it is a
|
1086
|
-
if
|
1087
|
-
if len(self.waiting_queue) + 1 > self.max_queued_requests:
|
1088
|
-
abort_req = AbortReq(
|
1089
|
-
recv_req.rid,
|
1090
|
-
finished_reason={
|
1091
|
-
"type": "abort",
|
1092
|
-
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
|
1093
|
-
"message": "The request queue is full.",
|
1094
|
-
},
|
1095
|
-
)
|
1096
|
-
self.send_to_tokenizer.send_pyobj(abort_req)
|
1097
|
-
continue
|
1098
|
-
|
1099
|
-
# If it is a MultiTokenizerWarpper, unwrap it and handle the inner request.
|
1100
|
-
if isinstance(recv_req, MultiTokenizerWarpper):
|
1275
|
+
# If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
|
1276
|
+
if isinstance(recv_req, MultiTokenizerWrapper):
|
1101
1277
|
worker_id = recv_req.worker_id
|
1102
1278
|
recv_req = recv_req.obj
|
1103
1279
|
output = self._request_dispatcher(recv_req)
|
1104
1280
|
if output is not None:
|
1105
|
-
output =
|
1281
|
+
output = MultiTokenizerWrapper(worker_id, output)
|
1106
1282
|
self.send_to_tokenizer.send_pyobj(output)
|
1107
1283
|
continue
|
1108
1284
|
|
@@ -1114,12 +1290,20 @@ class Scheduler(
|
|
1114
1290
|
else:
|
1115
1291
|
self.send_to_tokenizer.send_pyobj(output)
|
1116
1292
|
|
1293
|
+
def init_req_max_new_tokens(self, req):
|
1294
|
+
req.sampling_params.max_new_tokens = min(
|
1295
|
+
(
|
1296
|
+
req.sampling_params.max_new_tokens
|
1297
|
+
if req.sampling_params.max_new_tokens is not None
|
1298
|
+
else 1 << 30
|
1299
|
+
),
|
1300
|
+
self.max_req_len - len(req.origin_input_ids) - 1,
|
1301
|
+
)
|
1302
|
+
|
1117
1303
|
def handle_generate_request(
|
1118
1304
|
self,
|
1119
1305
|
recv_req: TokenizedGenerateReqInput,
|
1120
1306
|
):
|
1121
|
-
self.maybe_update_dp_balance_data(recv_req)
|
1122
|
-
|
1123
1307
|
# Create a new request
|
1124
1308
|
if (
|
1125
1309
|
recv_req.session_params is None
|
@@ -1153,8 +1337,13 @@ class Scheduler(
|
|
1153
1337
|
bootstrap_host=recv_req.bootstrap_host,
|
1154
1338
|
bootstrap_port=recv_req.bootstrap_port,
|
1155
1339
|
bootstrap_room=recv_req.bootstrap_room,
|
1340
|
+
disagg_mode=self.disaggregation_mode,
|
1156
1341
|
data_parallel_rank=recv_req.data_parallel_rank,
|
1157
1342
|
vocab_size=self.model_config.vocab_size,
|
1343
|
+
priority=recv_req.priority,
|
1344
|
+
metrics_collector=(
|
1345
|
+
self.metrics_collector if self.enable_metrics else None
|
1346
|
+
),
|
1158
1347
|
)
|
1159
1348
|
req.tokenizer = self.tokenizer
|
1160
1349
|
|
@@ -1177,6 +1366,7 @@ class Scheduler(
|
|
1177
1366
|
req.set_finish_with_abort(
|
1178
1367
|
f"Invalid request: session id {recv_req.session_params.id} does not exist"
|
1179
1368
|
)
|
1369
|
+
self.init_req_max_new_tokens(req)
|
1180
1370
|
self._add_request_to_queue(req)
|
1181
1371
|
return
|
1182
1372
|
else:
|
@@ -1184,6 +1374,7 @@ class Scheduler(
|
|
1184
1374
|
session = self.sessions[recv_req.session_params.id]
|
1185
1375
|
req = session.create_req(recv_req, self.tokenizer)
|
1186
1376
|
if isinstance(req.finished_reason, FINISH_ABORT):
|
1377
|
+
self.init_req_max_new_tokens(req)
|
1187
1378
|
self._add_request_to_queue(req)
|
1188
1379
|
return
|
1189
1380
|
|
@@ -1203,9 +1394,13 @@ class Scheduler(
|
|
1203
1394
|
f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
|
1204
1395
|
)
|
1205
1396
|
)
|
1397
|
+
self.init_req_max_new_tokens(req)
|
1206
1398
|
self._add_request_to_queue(req)
|
1207
1399
|
return
|
1208
1400
|
|
1401
|
+
# initialize before returning
|
1402
|
+
self.init_req_max_new_tokens(req)
|
1403
|
+
|
1209
1404
|
# Validate prompt length
|
1210
1405
|
error_msg = validate_input_length(
|
1211
1406
|
req,
|
@@ -1220,26 +1415,25 @@ class Scheduler(
         # Copy more attributes
         if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
             # By default, only return the logprobs for output tokens
-
+            # For prefill-only requests with logprob_start_len == -1, set logprob_start_len beyond input sequence
+            # to skip input logprob computation entirely
+            if req.is_prefill_only:
+                req.logprob_start_len = len(req.origin_input_ids)
+            else:
+                # TODO: For text generation, evaluate setting logprob_start_len to len(req.origin_input_ids) as well
+                req.logprob_start_len = len(req.origin_input_ids) - 1
         else:
             req.logprob_start_len = recv_req.logprob_start_len

-        if req.logprob_start_len >= len(
+        if not req.is_prefill_only and req.logprob_start_len >= len(
+            req.origin_input_ids
+        ):
             error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
             req.logprob_start_len = len(req.origin_input_ids) - 1
             req.set_finish_with_abort(error_msg)
             self._add_request_to_queue(req)
             return

-        req.sampling_params.max_new_tokens = min(
-            (
-                req.sampling_params.max_new_tokens
-                if req.sampling_params.max_new_tokens is not None
-                else 1 << 30
-            ),
-            self.max_req_len - len(req.origin_input_ids) - 1,
-        )
-
         # Init grammar cache for this request
         add_to_grammar_queue = False
         if (
@@ -1270,7 +1464,6 @@ class Scheduler(
                 req.set_finish_with_abort(error_msg)

         if add_to_grammar_queue:
-            req.queue_time_start = time.perf_counter()
             self.grammar_queue.append(req)
         else:
             self._add_request_to_queue(req)
@@ -1286,19 +1479,6 @@ class Scheduler(
         for tokenized_req in recv_req:
             self.handle_generate_request(tokenized_req)

-    def _add_request_to_queue(self, req: Req):
-        req.queue_time_start = time.perf_counter()
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self._prefetch_kvcache(req)
-            self.disagg_prefill_bootstrap_queue.add(
-                req, self.model_config.num_key_value_heads
-            )
-        elif self.disaggregation_mode == DisaggregationMode.DECODE:
-            self.disagg_decode_prealloc_queue.add(req)
-        else:
-            self._prefetch_kvcache(req)
-            self.waiting_queue.append(req)
-
     def _prefetch_kvcache(self, req: Req):
         if self.enable_hicache_storage:
             req.init_next_round_input(self.tree_cache)
@@ -1312,16 +1492,87 @@ class Scheduler(
                 req.rid, req.last_host_node, new_input_tokens, last_hash
             )

-    def
-        if self.disaggregation_mode == DisaggregationMode.
-            self.
-
+    def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
+        if self.disaggregation_mode == DisaggregationMode.NULL:
+            self._set_or_validate_priority(req)
+            if self._abort_on_queued_limit(req):
+                return
+            self._prefetch_kvcache(req)
+            self.waiting_queue.append(req)
+            req.time_stats.wait_queue_entry_time = time.perf_counter()
+            trace_slice_end("process req", req.rid, auto_next_anon=True)
+        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self._prefetch_kvcache(req)
+            self.disagg_prefill_bootstrap_queue.add(
+                req, self.model_config.num_key_value_heads
             )
+            req.time_stats.prefill_bootstrap_queue_entry_time = time.perf_counter()
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
-
-
+            self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted)
+            if not is_retracted:
+                req.time_stats.decode_prealloc_queue_entry_time = time.perf_counter()
         else:
-            self.
+            raise ValueError(f"Invalid {self.disaggregation_mode=}")
+
+    def _set_or_validate_priority(self, req: Req):
+        """Set the default priority value, or abort the request based on the priority scheduling mode."""
+        if self.enable_priority_scheduling and req.priority is None:
+            if self.schedule_low_priority_values_first:
+                req.priority = sys.maxsize
+            else:
+                req.priority = -sys.maxsize - 1
+        elif not self.enable_priority_scheduling and req.priority is not None:
+            abort_req = AbortReq(
+                finished_reason={
+                    "type": "abort",
+                    "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
+                    "message": "Using priority is disabled for this server. Please send a new request without a priority.",
+                },
+                rid=req.rid,
+            )
+            self.send_to_tokenizer.send_pyobj(abort_req)
+
+    def _abort_on_queued_limit(self, recv_req: Req) -> bool:
+        """Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted."""
+        if (
+            self.max_queued_requests is None
+            or len(self.waiting_queue) + 1 <= self.max_queued_requests
+        ):
+            return False
+
+        # Reject the incoming request by default.
+        req_to_abort = recv_req
+        message = "The request queue is full."
+        if self.enable_priority_scheduling:
+            # With priority scheduling, consider aboritng an existing request based on the priority.
+            # direction = 1 => smaller number = higher priority; -1 => larger number = higher priority.
+            # max(...) + (direction * priority, queue_time_start) picks the least-preferred request.
+            # Tie: later queue_time_start (newer) is evicted first. Preempt only if strictly better.
+            direction = 1 if self.schedule_low_priority_values_first else -1
+            key_fn = lambda item: (
+                direction * item[1].priority,
+                item[1].time_stats.wait_queue_entry_time,
+            )
+            idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn)
+            abort_existing_req = (
+                direction * recv_req.priority < direction * candidate_req.priority
+            )
+            if abort_existing_req:
+                self.waiting_queue.pop(idx)
+                req_to_abort = candidate_req
+                message = "The request is aborted by a higher priority request."
+
+        self.send_to_tokenizer.send_pyobj(
+            AbortReq(
+                finished_reason={
+                    "type": "abort",
+                    "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
+                    "message": message,
+                },
+                rid=req_to_abort.rid,
+            )
+        )
+        return req_to_abort.rid == recv_req.rid

     def handle_embedding_request(
         self,
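The new `_abort_on_queued_limit` above decides, when the waiting queue is full, whether to reject the incoming request or evict the least-preferred queued one. A small self-contained sketch of that selection rule, using hypothetical `(priority, entry_time, rid)` tuples instead of sglang's `Req` objects:

```python
# Illustrative sketch only: the eviction choice made when the queue is full.
from typing import List, Tuple

Req = Tuple[int, float, str]  # (priority, wait_queue_entry_time, rid)


def pick_request_to_abort(
    waiting: List[Req],
    incoming: Req,
    low_value_is_high_priority: bool,
) -> str:
    """Return the rid that gets aborted when the queue is full."""
    direction = 1 if low_value_is_high_priority else -1
    # The least-preferred waiting request: worst priority first, then the
    # most recently enqueued one on ties.
    idx, candidate = max(
        enumerate(waiting), key=lambda item: (direction * item[1][0], item[1][1])
    )
    # Evict an existing request only if the incoming one is strictly better;
    # otherwise the incoming request itself is rejected.
    if direction * incoming[0] < direction * candidate[0]:
        waiting.pop(idx)
        return candidate[2]
    return incoming[2]


waiting = [(1, 10.0, "a"), (5, 11.0, "b"), (5, 12.0, "c")]
# Lower value = higher priority; incoming priority 2 beats the worst (5),
# and of the two priority-5 requests the newer one ("c") is evicted.
print(pick_request_to_abort(waiting, (2, 13.0, "new"), True))  # -> "c"
```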
@@ -1333,6 +1584,7 @@ class Scheduler(
             recv_req.input_ids,
             recv_req.sampling_params,
             token_type_ids=recv_req.token_type_ids,
+            priority=recv_req.priority,
         )
         req.tokenizer = self.tokenizer

@@ -1409,9 +1661,11 @@ class Scheduler(
         _, _, available_size, evictable_size = self._get_token_info()
         protected_size = self.tree_cache.protected_size()
         memory_leak = (available_size + evictable_size) != (
+            # self.max_total_num_tokens
+            # if not self.enable_hierarchical_cache
+            # else self.max_total_num_tokens - protected_size
             self.max_total_num_tokens
-
-            else self.max_total_num_tokens - protected_size
+            - protected_size
         )
         token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"

@@ -1462,6 +1716,20 @@ class Scheduler(
         self.stats.gen_throughput = 0
         self.stats.num_queue_reqs = len(self.waiting_queue)
         self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.stats.num_prefill_prealloc_queue_reqs = len(
+                self.disagg_prefill_bootstrap_queue.queue
+            )
+            self.stats.num_prefill_inflight_queue_reqs = len(
+                self.disagg_prefill_inflight_queue
+            )
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.stats.num_decode_prealloc_queue_reqs = len(
+                self.disagg_decode_prealloc_queue.queue
+            )
+            self.stats.num_decode_transfer_queue_reqs = len(
+                self.disagg_decode_transfer_queue.queue
+            )
         self.metrics_collector.log_stats(self.stats)
         self._publish_kv_events()

@@ -1509,7 +1777,12 @@ class Scheduler(
             chunked_req_to_exclude.add(self.chunked_req)
             self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             # chunked request keeps its rid but will get a new req_pool_idx
-            self.
+            if self.tp_worker.worker.model_runner.mambaish_config is not None:
+                self.req_to_token_pool.free(
+                    self.chunked_req.req_pool_idx, free_mamba_cache=False
+                )
+            else:
+                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
         if self.last_batch and self.last_batch.forward_mode.is_extend():
             if self.last_batch.chunked_req is not None:
                 # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req.
@@ -1556,13 +1829,12 @@ class Scheduler(

         # Handle DP attention
         if need_dp_attn_preparation:
-            self.maybe_handle_dp_balance_data()
             ret = self.prepare_mlp_sync_batch(ret)

         return ret

     def get_num_allocatable_reqs(self, running_bs):
-        res = global_server_args_dict["
+        res = global_server_args_dict["pp_max_micro_batch_size"] - running_bs
         if self.pp_size > 1:
             res = min(res, self.req_to_token_pool.available_size())
         return res
@@ -1572,6 +1844,10 @@ class Scheduler(
         if self.grammar_queue:
             self.move_ready_grammar_requests()

+        if self.try_preemption:
+            # Reset batch_is_full to try preemption with a prefill adder.
+            self.running_batch.batch_is_full = False
+
         # Handle the cases where prefill is not allowed
         if (
             self.running_batch.batch_is_full or len(self.waiting_queue) == 0
@@ -1584,7 +1860,11 @@ class Scheduler(
         # as the space for the chunked request has just been released.
         # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
         # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
-        if
+        if (
+            self.get_num_allocatable_reqs(running_bs) <= 0
+            and not self.chunked_req
+            and not self.try_preemption
+        ):
             self.running_batch.batch_is_full = True
             return None

@@ -1604,6 +1884,7 @@ class Scheduler(
             self.max_prefill_tokens,
             self.chunked_prefill_size,
             running_bs if self.is_mixed_chunk else 0,
+            self.priority_scheduling_preemption_threshold,
         )

         if self.chunked_req is not None:
@@ -1624,15 +1905,19 @@ class Scheduler(
                 self.running_batch.batch_is_full = True
                 break

+            running_bs = len(self.running_batch.reqs) - len(adder.preempt_list)
             if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                 self.running_batch.batch_is_full = True
-                break
-
             if self.disaggregation_mode == DisaggregationMode.PREFILL:
                 # In prefill mode, prealloc queue and transfer queue can also take memory,
                 # so we need to check if the available size for the actual available size.
                 if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                     self.running_batch.batch_is_full = True
+
+            if self.running_batch.batch_is_full:
+                if not self.try_preemption:
+                    break
+                if not adder.preempt_to_schedule(req, self.server_args):
                     break

             if self.enable_hicache_storage:
@@ -1642,7 +1927,11 @@ class Scheduler(
                     continue

             req.init_next_round_input(self.tree_cache)
-            res = adder.add_one_req(
+            res = adder.add_one_req(
+                req,
+                has_chunked_req=(self.chunked_req is not None),
+                truncation_align_size=self.truncation_align_size,
+            )

             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
@@ -1663,11 +1952,14 @@ class Scheduler(
         if self.enable_metrics:
             # only record queue time when enable_metrics is True to avoid overhead
             for req in can_run_list:
-                req.
+                req.add_latency(RequestStage.PREFILL_WAITING)

         self.waiting_queue = [
             x for x in self.waiting_queue if x not in set(can_run_list)
         ]
+        if adder.preempt_list:
+            for req in adder.preempt_list:
+                self._add_request_to_queue(req)

         if adder.new_chunked_req is not None:
             assert self.chunked_req is None
@@ -1678,7 +1970,16 @@ class Scheduler(

         # Print stats
         if self.current_scheduler_metrics_enabled():
-            self.log_prefill_stats(adder, can_run_list, running_bs)
+            self.log_prefill_stats(adder, can_run_list, running_bs, 0)
+
+        for req in can_run_list:
+            if req.time_stats.forward_entry_time == 0:
+                # Avoid update chunked request many times
+                req.time_stats.forward_entry_time = time.perf_counter()
+                if self.enable_metrics:
+                    self.metrics_collector.observe_queue_time(
+                        req.time_stats.get_queueing_time(),
+                    )

         # Create a new batch
         new_batch = ScheduleBatch.init_new(
@@ -1733,19 +2034,25 @@ class Scheduler(
             TEST_RETRACT and batch.batch_size() > 10
         ):
             old_ratio = self.new_token_ratio
-
-
-
+            retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
+                self.server_args
+            )
+            self.num_retracted_reqs = len(retracted_reqs)
             self.new_token_ratio = new_token_ratio
+            for req in reqs_to_abort:
+                self.send_to_tokenizer.send_pyobj(
+                    AbortReq(abort_reason=req.to_abort_message, rid=req.rid)
+                )

             logger.info(
                 "KV cache pool is full. Retract requests. "
-                f"#retracted_reqs: {
-                f"#
+                f"#retracted_reqs: {len(retracted_reqs)}, "
+                f"#aborted_retracted_reqs: {len(reqs_to_abort)}, "
+                f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
             )

-
-
+            for req in retracted_reqs:
+                self._add_request_to_queue(req, is_retracted=True)
         else:
             self.new_token_ratio = max(
                 self.new_token_ratio - self.new_token_ratio_decay,
@@ -1773,37 +2080,66 @@ class Scheduler(

         # Run forward
         if self.is_generation:
+
+            batch_or_worker_batch = batch
+
             if self.spec_algorithm.is_none():
-
+                # FIXME(lsyin): remove this if and finally unify the abstraction
+                batch_or_worker_batch = batch.get_model_worker_batch()

-
-
-
+            if self.enable_overlap:
+                # FIXME: remove this assert
+                assert isinstance(batch_or_worker_batch, ModelWorkerBatch)
+                model_worker_batch = batch_or_worker_batch
+                self.record_batch_in_overlap(model_worker_batch)
+
+                # Sampling info will be modified during forward
+                model_worker_batch.sampling_info = (
+                    model_worker_batch.sampling_info.copy_for_forward()
                 )
-
-
-
-
-
-
-
+
+                bs = len(model_worker_batch.seq_lens)
+                future_indices = self.future_map.alloc_future_indices(bs)
+
+                with self.forward_stream_ctx:
+                    self.forward_stream.wait_stream(self.default_stream)
+                    self.future_map.resolve_future(model_worker_batch)
+                    if batch.sampling_info.grammars is not None:
+                        model_worker_batch.delay_sample_launch = True
+                    batch_result = self.model_worker.forward_batch_generation(
+                        batch_or_worker_batch
                     )
-
+                    # FIXME(lsyin): maybe move this to forward_batch_generation
+                    batch_result.copy_done = torch.get_device_module(
+                        self.device
+                    ).Event()
+                    if not model_worker_batch.delay_sample_launch:
+                        self.future_map.store_to_map(
+                            future_indices, batch_result.next_token_ids
+                        )
+                        batch_result.copy_to_cpu()
+                    else:
+                        batch_result.future_indices = future_indices
+
+                # FIXME(lsyin): move this assignment elsewhere
+                maybe_future_next_token_ids = -future_indices.indices
             else:
-                (
-
-
-
-
-
-                )
-
-
-
-
-
-
-
+                batch_result = self.model_worker.forward_batch_generation(
+                    batch_or_worker_batch
+                )
+                maybe_future_next_token_ids = batch_result.next_token_ids
+
+            if not self.spec_algorithm.is_none():
+                # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
+                self.update_spec_metrics(
+                    batch.batch_size(), batch_result.num_accepted_tokens
+                )
+
+            # NOTE: maybe_future_next_token_ids is used in ScheduleBatch,
+            # which can probably be replaced by future_indices later [TODO(lsyin)].
+            # we shall still keep the original outputs, e.g. next_token_ids
+            # in the GenerationBatchOutput for processing after copy_done.
+            batch.output_ids = maybe_future_next_token_ids

             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that
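The overlap path above hands the scheduler negative "future" indices (`-future_indices.indices`) as placeholder output ids and fills the real token ids in later on the forward stream (`alloc_future_indices` / `store_to_map` / `resolve_future`). A simplified stand-in of that idea, not sglang's actual FutureMap (the real `resolve_future` takes a model worker batch, not a tensor):

```python
# Illustrative sketch only: negative indices as placeholders for token ids
# that are produced asynchronously and resolved later.
import torch


class FutureMap:
    def __init__(self, capacity: int = 1024):
        self.buf = torch.zeros(capacity, dtype=torch.long)
        self.cursor = 1  # start at 1 so -index is always strictly negative

    def alloc_future_indices(self, bs: int) -> torch.Tensor:
        idx = torch.arange(self.cursor, self.cursor + bs, dtype=torch.long)
        self.cursor += bs
        return idx

    def store_to_map(self, indices: torch.Tensor, next_token_ids: torch.Tensor):
        # Called once sampling has produced the real token ids.
        self.buf[indices] = next_token_ids

    def resolve_future(self, output_ids: torch.Tensor) -> torch.Tensor:
        # Placeholders were encoded as negative indices; swap them for the
        # real token ids, leaving already-concrete ids untouched.
        mask = output_ids < 0
        resolved = output_ids.clone()
        resolved[mask] = self.buf[-output_ids[mask]]
        return resolved


fm = FutureMap()
indices = fm.alloc_future_indices(2)            # tensor([1, 2])
placeholders = -indices                          # scheduler keeps these for now
fm.store_to_map(indices, torch.tensor([42, 7]))  # forward stream fills them in
print(fm.resolve_future(placeholders))           # tensor([42, 7])
```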
@@ -1812,6 +2148,7 @@ class Scheduler(
                 extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
             else:
                 extend_input_len_per_req = None
+
             if batch.return_logprob:
                 extend_logprob_start_len_per_req = [
                     req.extend_logprob_start_len for req in batch.reqs
@@ -1819,43 +2156,60 @@ class Scheduler(
             else:
                 extend_logprob_start_len_per_req = None

-
-
-
-                    pp_hidden_states_proxy_tensors
-                    if not self.pp_group.is_last_rank
-                    else None
-                ),
-                next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
-                extend_input_len_per_req=extend_input_len_per_req,
-                extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
-                bid=bid,
-                can_run_cuda_graph=can_run_cuda_graph,
+            batch_result.extend_input_len_per_req = extend_input_len_per_req
+            batch_result.extend_logprob_start_len_per_req = (
+                extend_logprob_start_len_per_req
             )
+            return batch_result
         else: # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            ret = EmbeddingBatchResult(
-                embeddings=embeddings, bid=model_worker_batch.bid
-            )
+            ret = EmbeddingBatchResult(embeddings=embeddings)
             return ret

+    def launch_last_batch_sample_if_needed(
+        self,
+    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
+        if len(self.result_queue) == 0:
+            return
+
+        tmp_batch, tmp_result = self.result_queue.popleft()
+
+        tmp_result: GenerationBatchResult
+        if not tmp_result.delay_sample_launch:
+            self.result_queue.appendleft((tmp_batch, tmp_result))
+            return
+
+        with self.forward_stream_ctx:
+            self.forward_stream.wait_stream(self.default_stream)
+            tmp_result.next_token_ids = self.model_worker.model_runner.sample(
+                tmp_result.logits_output,
+                tmp_result.forward_batch,
+            )
+            future_indices = tmp_result.future_indices
+            self.future_map.store_to_map(future_indices, tmp_result.next_token_ids)
+            tmp_result.copy_to_cpu()
+        self.result_queue.appendleft((tmp_batch, tmp_result))
+
     def process_batch_result(
         self,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
-        launch_done: Optional[threading.Event] = None,
     ):
         if batch.forward_mode.is_decode():
-            self.process_batch_result_decode(batch, result
+            self.process_batch_result_decode(batch, result)
+            if self.enable_trace:
+                trace_slice_batch("decode loop", batch.reqs)
+
         elif batch.forward_mode.is_extend():
-            self.process_batch_result_prefill(batch, result
+            self.process_batch_result_prefill(batch, result)
+            if self.enable_trace:
+                trace_slice_batch("prefill", batch.reqs)
+
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
-
-
-            elif batch.forward_mode.is_dummy_first():
-                self.set_next_batch_sampling_info_done(batch)
+                if result.copy_done is not None:
+                    result.copy_done.synchronize()

         self.maybe_send_health_check_signal()

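The new `launch_last_batch_sample_if_needed` above peeks at the newest queued result and, only if its sampling was deferred, runs the sample step before putting it back. A toy sketch of that bookkeeping pattern with hypothetical stand-in result objects (argmax stands in for the real sampler call):

```python
# Illustrative sketch only: deferred-sampling bookkeeping over a result deque.
from collections import deque
from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeResult:
    delay_sample_launch: bool
    logits: List[List[float]] = field(default_factory=list)
    next_token_ids: List[int] = field(default_factory=list)


def launch_last_batch_sample_if_needed(result_queue: deque) -> None:
    if len(result_queue) == 0:
        return
    batch, result = result_queue.popleft()
    if not result.delay_sample_launch:
        # Nothing to do; put it back unchanged.
        result_queue.appendleft((batch, result))
        return
    # "Sample" now, then re-insert the completed result at the front.
    result.next_token_ids = [
        max(range(len(row)), key=row.__getitem__) for row in result.logits
    ]
    result.delay_sample_launch = False
    result_queue.appendleft((batch, result))


q = deque([("batch0", FakeResult(True, logits=[[0.1, 0.7, 0.2]]))])
launch_last_batch_sample_if_needed(q)
print(q[0][1].next_token_ids)  # [1]
```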
@@ -2008,12 +2362,13 @@ class Scheduler(
                 if req.finished(): # It is aborted by AbortReq
                     num_ready_reqs += 1
                     continue
+
                 req.grammar = req.grammar.result(timeout=0.03)
                 self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                 if req.grammar is INVALID_GRAMMAR_OBJ:
-                    req.
-
-
+                    error_msg = f"Invalid grammar request: {req.grammar_key=}"
+                    req.set_finish_with_abort(error_msg)
+
                 num_ready_reqs += 1
             except futures._base.TimeoutError:
                 req.grammar_wait_ct += 1
@@ -2045,9 +2400,8 @@ class Scheduler(
                 req.grammar = req.grammar.result()
                 self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                 if req.grammar is INVALID_GRAMMAR_OBJ:
-                    req.
-
-                    )
+                    error_msg = f"Invalid grammar request: {req.grammar_key=}"
+                    req.set_finish_with_abort(error_msg)
         else:
             num_ready_reqs_max = num_ready_reqs
             num_timeout_reqs_max = num_timeout_reqs
@@ -2055,21 +2409,16 @@ class Scheduler(
         for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
             req = self.grammar_queue[i]
             req.grammar.cancel()
+            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
             error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
             req.set_finish_with_abort(error_msg)
-
+
         num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max

-        self.
+        for req in self.grammar_queue[:num_ready_reqs]:
+            self._add_request_to_queue(req)
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]

-    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
-        if batch.next_batch_sampling_info:
-            if batch.next_batch_sampling_info.grammars is not None:
-                batch.next_batch_sampling_info.update_regex_vocab_mask()
-                self.current_stream.synchronize()
-                batch.next_batch_sampling_info.sampling_info_done.set()
-
     def watchdog_thread(self):
         """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
         self.watchdog_last_forward_ct = 0
@@ -2152,9 +2501,8 @@ class Scheduler(
         self.req_to_token_pool.clear()
         self.token_to_kv_pool_allocator.clear()

-        if
-            self.draft_worker.
-            self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()
+        if self.draft_worker:
+            self.draft_worker.clear_cache_pool()

         self.num_generated_tokens = 0
         self.forward_ct_decode = 0
@@ -2174,39 +2522,50 @@ class Scheduler(
             if_success = False
         return if_success

-    def get_load(self):
+    def get_load(self, recv_req: GetLoadReqInput = None) -> GetLoadReqOutput:
         # TODO(lsyin): use dynamically maintained num_waiting_tokens
+
         if self.is_hybrid:
-
+            num_tokens_full = (
                 self.full_tokens_per_layer
                 - self.token_to_kv_pool_allocator.full_available_size()
                 - self.tree_cache.full_evictable_size()
             )
-
+            num_tokens_swa = (
                 self.swa_tokens_per_layer
                 - self.token_to_kv_pool_allocator.swa_available_size()
                 - self.tree_cache.swa_evictable_size()
             )
-
+            num_tokens = max(num_tokens_full, num_tokens_swa)
         else:
-
+            num_tokens = (
                 self.max_total_num_tokens
                 - self.token_to_kv_pool_allocator.available_size()
                 - self.tree_cache.evictable_size()
             )
-
+
+        # Tokens in waiting queue, bootstrap queue, prealloc queue
+        num_tokens += sum(len(req.origin_input_ids) for req in self.waiting_queue)
+        num_waiting_reqs = len(self.waiting_queue)
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-
+            num_tokens += sum(
                 len(req.origin_input_ids)
                 for req in self.disagg_prefill_bootstrap_queue.queue
             )
+            num_waiting_reqs += len(self.disagg_prefill_bootstrap_queue.queue)
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
-
+            num_tokens += sum(
                 len(req.req.origin_input_ids)
                 for req in self.disagg_decode_prealloc_queue.queue
            )
+            num_waiting_reqs += len(self.disagg_decode_prealloc_queue.queue)

-        return
+        return GetLoadReqOutput(
+            dp_rank=self.dp_rank,
+            num_reqs=len(self.running_batch.reqs) + num_waiting_reqs,
+            num_waiting_reqs=num_waiting_reqs,
+            num_tokens=num_tokens,
+        )

     def get_internal_state(self, recv_req: GetInternalStateReq):
         ret = dict(global_server_args_dict)
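`get_load` now returns a structured `GetLoadReqOutput` and also counts tokens for requests still sitting in the waiting/bootstrap/prealloc queues. A toy sketch of the accounting with hypothetical inputs (disaggregation queues omitted):

```python
# Illustrative sketch only: the token accounting behind the reworked get_load.
from typing import Dict, List


def compute_load(
    max_total_num_tokens: int,
    kv_available: int,
    evictable: int,
    waiting_prompt_lens: List[int],
    running_reqs: int,
) -> Dict[str, int]:
    # Tokens currently pinned in the KV cache pool ...
    num_tokens = max_total_num_tokens - kv_available - evictable
    # ... plus tokens that queued-but-not-started requests will need.
    num_tokens += sum(waiting_prompt_lens)
    num_waiting_reqs = len(waiting_prompt_lens)
    return {
        "num_reqs": running_reqs + num_waiting_reqs,
        "num_waiting_reqs": num_waiting_reqs,
        "num_tokens": num_tokens,
    }


print(compute_load(100_000, 60_000, 10_000, [512, 1024], running_reqs=3))
# {'num_reqs': 5, 'num_waiting_reqs': 2, 'num_tokens': 31536}
```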
@@ -2221,10 +2580,9 @@ class Scheduler(
             "token_capacity": int(self.max_total_num_tokens),
         }

-
-
-
-        )
+        ret["memory_usage"]["graph"] = round(
+            self.tp_worker.worker.model_runner.graph_mem_usage, 2
+        )

         if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
             ret["avg_spec_accept_length"] = (
@@ -2233,15 +2591,13 @@ class Scheduler(
         if RECORD_STEP_TIME:
             ret["step_time_dict"] = self.step_time_dict

-        ret["load"] = self.get_load()
-
         return GetInternalStateReqOutput(internal_state=ret)

     def set_internal_state(self, recv_req: SetInternalStateReq):
         server_args_dict = recv_req.server_args
         args_allow_update = set(
             [
-                "
+                "pp_max_micro_batch_size",
                 "speculative_accept_threshold_single",
                 "speculative_accept_threshold_acc",
             ]
@@ -2252,7 +2608,7 @@ class Scheduler(
                 logging.warning(f"Updating {k} is not supported.")
                 if_success = False
                 break
-            elif k == "
+            elif k == "pp_max_micro_batch_size" and (
                 v > self.max_running_requests // self.pp_size or v < 1
             ):
                 logging.warning(
@@ -2310,7 +2666,7 @@ class Scheduler(
                 if self.enable_hicache_storage:
                     # to release prefetch events associated with the request
                     self.tree_cache.release_aborted_request(req.rid)
-                self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
+                self.send_to_tokenizer.send_pyobj(AbortReq(rid=req.rid))
                 # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
                 if self.disaggregation_mode == DisaggregationMode.DECODE:
                     self.tree_cache.cache_finished_req(req)
@@ -2331,31 +2687,31 @@ class Scheduler(
         # Delete requests not in the waiting queue when PD disaggregation is enabled
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             # Abort requests that have not yet been bootstrapped
-            for
-                logger.debug(f"Abort bootstrap queue request. {req.rid=}")
+            for req in self.disagg_prefill_bootstrap_queue.queue:
                 if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort bootstrap queue request. {req.rid=}")
                     if hasattr(req.disagg_kv_sender, "abort"):
                         req.disagg_kv_sender.abort()

             # Abort in-flight requests
-            for
-                logger.debug(f"Abort inflight queue request. {req.rid=}")
+            for req in self.disagg_prefill_inflight_queue:
                 if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort inflight queue request. {req.rid=}")
                     if hasattr(req.disagg_kv_sender, "abort"):
                         req.disagg_kv_sender.abort()

         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             # Abort requests that have not yet finished preallocation
-            for
-                logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
+            for decode_req in self.disagg_decode_prealloc_queue.queue:
                 if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
                     if hasattr(decode_req.kv_receiver, "abort"):
                         decode_req.kv_receiver.abort()

             # Abort requests waiting for kvcache to release tree cache
-            for
-                logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
+            for decode_req in self.disagg_decode_transfer_queue.queue:
                 if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
                     if hasattr(decode_req.kv_receiver, "abort"):
                         decode_req.kv_receiver.abort()

@@ -2398,6 +2754,22 @@ class Scheduler(
         self.send_to_detokenizer.send_pyobj(recv_req)
         return recv_req

+    def init_weights_send_group_for_remote_instance(
+        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
+    ):
+        """Init the seed and client instance communication group."""
+        success, message = self.tp_worker.init_weights_send_group_for_remote_instance(
+            recv_req
+        )
+        return InitWeightsSendGroupForRemoteInstanceReqOutput(success, message)
+
+    def send_weights_to_remote_instance(
+        self, recv_req: SendWeightsToRemoteInstanceReqInput
+    ):
+        """Send the seed instance weights to the destination instance."""
+        success, message = self.tp_worker.send_weights_to_remote_instance(recv_req)
+        return SendWeightsToRemoteInstanceReqOutput(success, message)
+
     def slow_down(self, recv_req: SlowDownReqInput):
         t = recv_req.forward_sleep_time
         if t is not None and t <= 0:
@@ -2406,11 +2778,12 @@ class Scheduler(
         return SlowDownReqOutput()

     def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
-
+        action = recv_req.action
+        if action == ExpertDistributionReqType.START_RECORD:
             get_global_expert_distribution_recorder().start_record()
-        elif
+        elif action == ExpertDistributionReqType.STOP_RECORD:
             get_global_expert_distribution_recorder().stop_record()
-        elif
+        elif action == ExpertDistributionReqType.DUMP_RECORD:
             get_global_expert_distribution_recorder().dump_record()
         else:
             raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
@@ -2493,7 +2866,8 @@ class IdleSleeper:


 def is_health_check_generate_req(recv_req):
-
+    rid = getattr(recv_req, "rid", None)
+    return rid is not None and rid.startswith("HEALTH_CHECK")


 def is_work_request(recv_req):
@@ -2517,10 +2891,12 @@ def run_scheduler_process(
     pp_rank: int,
     dp_rank: Optional[int],
     pipe_writer,
-    balance_meta: Optional[DPBalanceMeta] = None,
 ):
-    # Generate the prefix
+    # Generate the logger prefix
     prefix = ""
+    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
+        # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
+        dp_rank = int(os.environ["SGLANG_DP_RANK"])
     if dp_rank is not None:
         prefix += f" DP{dp_rank}"
     if server_args.tp_size > 1:
@@ -2536,10 +2912,6 @@ def run_scheduler_process(
     kill_itself_when_parent_died()
     parent_process = psutil.Process().parent()

-    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
-    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
-        dp_rank = int(os.environ["SGLANG_DP_RANK"])
-
     # Configure the logger
     configure_logger(server_args, prefix=prefix)
     suppress_other_loggers()
@@ -2547,6 +2919,15 @@ def run_scheduler_process(
     # Set cpu affinity to this gpu process
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+    if (numa_node := server_args.numa_node) is not None:
+        numa_bind_to_node(numa_node[gpu_id])
+
+    # Set up tracing
+    if server_args.enable_trace:
+        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
+        if server_args.disaggregation_mode == "null":
+            thread_label = "Scheduler"
+        trace_set_thread_info(thread_label, tp_rank, dp_rank)

     # Create a scheduler and run the event loop
     try:
@@ -2558,7 +2939,6 @@ def run_scheduler_process(
             moe_ep_rank,
             pp_rank,
             dp_rank,
-            dp_balance_meta=balance_meta,
         )
         pipe_writer.send(
             {