sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
|
@@ -24,16 +24,16 @@ from collections import deque
|
|
|
24
24
|
from concurrent import futures
|
|
25
25
|
from dataclasses import dataclass
|
|
26
26
|
from http import HTTPStatus
|
|
27
|
-
from
|
|
28
|
-
from typing import Dict, List, Optional, Tuple, Union
|
|
27
|
+
from typing import Deque, Dict, List, Optional, Tuple, Union
|
|
29
28
|
|
|
30
29
|
import psutil
|
|
31
30
|
import setproctitle
|
|
32
31
|
import torch
|
|
33
32
|
import zmq
|
|
33
|
+
from torch.cuda import Stream as CudaStream
|
|
34
|
+
from torch.cuda import StreamContext as CudaStreamContext
|
|
34
35
|
from torch.distributed import barrier
|
|
35
36
|
|
|
36
|
-
from sglang.global_config import global_config
|
|
37
37
|
from sglang.srt.configs.model_config import ModelConfig
|
|
38
38
|
from sglang.srt.constrained.base_grammar_backend import (
|
|
39
39
|
INVALID_GRAMMAR_OBJ,
|
|
@@ -59,12 +59,14 @@ from sglang.srt.disaggregation.utils import (
|
|
|
59
59
|
prepare_abort,
|
|
60
60
|
)
|
|
61
61
|
from sglang.srt.distributed import get_pp_group, get_world_group
|
|
62
|
+
from sglang.srt.environ import envs
|
|
62
63
|
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
|
63
64
|
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
|
|
64
|
-
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
|
65
65
|
from sglang.srt.layers.moe import initialize_moe_config
|
|
66
66
|
from sglang.srt.managers.io_struct import (
|
|
67
67
|
AbortReq,
|
|
68
|
+
BaseBatchReq,
|
|
69
|
+
BaseReq,
|
|
68
70
|
BatchTokenizedEmbeddingReqInput,
|
|
69
71
|
BatchTokenizedGenerateReqInput,
|
|
70
72
|
ClearHiCacheReqInput,
|
|
@@ -88,8 +90,6 @@ from sglang.srt.managers.io_struct import (
|
|
|
88
90
|
InitWeightsUpdateGroupReqInput,
|
|
89
91
|
LoadLoRAAdapterReqInput,
|
|
90
92
|
LoadLoRAAdapterReqOutput,
|
|
91
|
-
MultiTokenizerRegisterReq,
|
|
92
|
-
MultiTokenizerWrapper,
|
|
93
93
|
OpenSessionReqInput,
|
|
94
94
|
OpenSessionReqOutput,
|
|
95
95
|
ProfileReq,
|
|
@@ -109,16 +109,18 @@ from sglang.srt.managers.io_struct import (
|
|
|
109
109
|
UnloadLoRAAdapterReqOutput,
|
|
110
110
|
UpdateWeightFromDiskReqInput,
|
|
111
111
|
UpdateWeightsFromDistributedReqInput,
|
|
112
|
+
UpdateWeightsFromIPCReqInput,
|
|
112
113
|
UpdateWeightsFromTensorReqInput,
|
|
113
114
|
)
|
|
114
115
|
from sglang.srt.managers.mm_utils import init_embedding_cache
|
|
116
|
+
from sglang.srt.managers.overlap_utils import FutureMap
|
|
115
117
|
from sglang.srt.managers.schedule_batch import (
|
|
116
118
|
FINISH_ABORT,
|
|
119
|
+
ModelWorkerBatch,
|
|
117
120
|
MultimodalInputs,
|
|
118
121
|
Req,
|
|
119
122
|
RequestStage,
|
|
120
123
|
ScheduleBatch,
|
|
121
|
-
global_server_args_dict,
|
|
122
124
|
)
|
|
123
125
|
from sglang.srt.managers.schedule_policy import (
|
|
124
126
|
AddReqResult,
|
|
@@ -133,28 +135,25 @@ from sglang.srt.managers.scheduler_metrics_mixin import (
|
|
|
133
135
|
from sglang.srt.managers.scheduler_output_processor_mixin import (
|
|
134
136
|
SchedulerOutputProcessorMixin,
|
|
135
137
|
)
|
|
138
|
+
from sglang.srt.managers.scheduler_pp_mixin import SchedulerPPMixin
|
|
136
139
|
from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
|
|
137
140
|
from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
|
|
141
|
+
from sglang.srt.managers.scheduler_runtime_checker_mixin import (
|
|
142
|
+
SchedulerRuntimeCheckerMixin,
|
|
143
|
+
)
|
|
138
144
|
from sglang.srt.managers.scheduler_update_weights_mixin import (
|
|
139
145
|
SchedulerUpdateWeightsMixin,
|
|
140
146
|
)
|
|
141
147
|
from sglang.srt.managers.session_controller import Session
|
|
142
|
-
from sglang.srt.managers.
|
|
143
|
-
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
|
144
|
-
from sglang.srt.managers.utils import validate_input_length
|
|
148
|
+
from sglang.srt.managers.utils import GenerationBatchResult, validate_input_length
|
|
145
149
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
|
146
150
|
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
|
151
|
+
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
|
|
147
152
|
from sglang.srt.mem_cache.radix_cache import RadixCache
|
|
148
153
|
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
|
149
|
-
from sglang.srt.model_executor.forward_batch_info import (
|
|
150
|
-
ForwardBatchOutput,
|
|
151
|
-
ForwardMode,
|
|
152
|
-
PPProxyTensors,
|
|
153
|
-
)
|
|
154
154
|
from sglang.srt.parser.reasoning_parser import ReasoningParser
|
|
155
|
-
from sglang.srt.server_args import PortArgs, ServerArgs
|
|
155
|
+
from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
|
|
156
156
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
|
157
|
-
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
|
158
157
|
from sglang.srt.tracing.trace import (
|
|
159
158
|
process_tracing_init,
|
|
160
159
|
trace_set_proc_propagate_context,
|
|
@@ -190,64 +189,17 @@ from sglang.srt.utils.hf_transformers_utils import (
|
|
|
190
189
|
get_tokenizer,
|
|
191
190
|
get_tokenizer_from_processor,
|
|
192
191
|
)
|
|
192
|
+
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
|
193
193
|
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
|
194
194
|
|
|
195
195
|
logger = logging.getLogger(__name__)
|
|
196
196
|
|
|
197
197
|
# Test retract decode for debugging purposes
|
|
198
|
-
TEST_RETRACT =
|
|
198
|
+
TEST_RETRACT = envs.SGLANG_TEST_RETRACT.get()
|
|
199
|
+
TEST_RETRACT_INTERVAL = envs.SGLANG_TEST_RETRACT_INTERVAL.get()
|
|
199
200
|
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
|
|
200
201
|
|
|
201
202
|
|
|
202
|
-
@dataclass
|
|
203
|
-
class GenerationBatchResult:
|
|
204
|
-
logits_output: Optional[LogitsProcessorOutput]
|
|
205
|
-
pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
|
|
206
|
-
next_token_ids: Optional[List[int]]
|
|
207
|
-
can_run_cuda_graph: bool
|
|
208
|
-
|
|
209
|
-
# For output processing
|
|
210
|
-
extend_input_len_per_req: List[int]
|
|
211
|
-
extend_logprob_start_len_per_req: List[int]
|
|
212
|
-
|
|
213
|
-
@classmethod
|
|
214
|
-
def from_forward_batch_output(
|
|
215
|
-
cls,
|
|
216
|
-
forward_batch_output: ForwardBatchOutput,
|
|
217
|
-
extend_input_len_per_req: List[int],
|
|
218
|
-
extend_logprob_start_len_per_req: List[int],
|
|
219
|
-
):
|
|
220
|
-
# TODO(lsyin): remove this workaround logic and try to unify output classes
|
|
221
|
-
|
|
222
|
-
return cls(
|
|
223
|
-
logits_output=forward_batch_output.logits_output,
|
|
224
|
-
pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
|
|
225
|
-
next_token_ids=forward_batch_output.next_token_ids,
|
|
226
|
-
extend_input_len_per_req=extend_input_len_per_req,
|
|
227
|
-
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
|
|
228
|
-
can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
|
|
229
|
-
)
|
|
230
|
-
|
|
231
|
-
@classmethod
|
|
232
|
-
def from_pp_proxy(
|
|
233
|
-
cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
|
|
234
|
-
):
|
|
235
|
-
# TODO(lsyin): also simplify this logic
|
|
236
|
-
# Current PP implementation in scheduler is not compatible with ForwardBatchOutput
|
|
237
|
-
# Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
|
|
238
|
-
proxy_dict = next_pp_outputs.tensors
|
|
239
|
-
return cls(
|
|
240
|
-
logits_output=logits_output,
|
|
241
|
-
pp_hidden_states_proxy_tensors=None,
|
|
242
|
-
next_token_ids=next_pp_outputs["next_token_ids"],
|
|
243
|
-
extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
|
|
244
|
-
extend_logprob_start_len_per_req=proxy_dict.get(
|
|
245
|
-
"extend_logprob_start_len_per_req", None
|
|
246
|
-
),
|
|
247
|
-
can_run_cuda_graph=can_run_cuda_graph,
|
|
248
|
-
)
|
|
249
|
-
|
|
250
|
-
|
|
251
203
|
@dataclass
|
|
252
204
|
class EmbeddingBatchResult:
|
|
253
205
|
embeddings: torch.Tensor
|
|
@@ -260,6 +212,8 @@ class Scheduler(
|
|
|
260
212
|
SchedulerMetricsMixin,
|
|
261
213
|
SchedulerDisaggregationDecodeMixin,
|
|
262
214
|
SchedulerDisaggregationPrefillMixin,
|
|
215
|
+
SchedulerRuntimeCheckerMixin,
|
|
216
|
+
SchedulerPPMixin,
|
|
263
217
|
):
|
|
264
218
|
"""A scheduler that manages a tensor parallel GPU worker."""
|
|
265
219
|
|
|
@@ -285,6 +239,9 @@ class Scheduler(
|
|
|
285
239
|
self.dp_size = server_args.dp_size
|
|
286
240
|
self.schedule_policy = server_args.schedule_policy
|
|
287
241
|
self.enable_priority_scheduling = server_args.enable_priority_scheduling
|
|
242
|
+
self.abort_on_priority_when_disabled = (
|
|
243
|
+
server_args.abort_on_priority_when_disabled
|
|
244
|
+
)
|
|
288
245
|
self.schedule_low_priority_values_first = (
|
|
289
246
|
server_args.schedule_low_priority_values_first
|
|
290
247
|
)
|
|
@@ -325,47 +282,7 @@ class Scheduler(
|
|
|
325
282
|
self.model_config = ModelConfig.from_server_args(server_args)
|
|
326
283
|
|
|
327
284
|
# Init inter-process communication
|
|
328
|
-
|
|
329
|
-
self.idle_sleeper = None
|
|
330
|
-
if self.pp_rank == 0 and self.attn_tp_rank == 0:
|
|
331
|
-
self.recv_from_tokenizer = get_zmq_socket(
|
|
332
|
-
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
|
|
333
|
-
)
|
|
334
|
-
self.recv_from_rpc = get_zmq_socket(
|
|
335
|
-
context, zmq.DEALER, port_args.rpc_ipc_name, False
|
|
336
|
-
)
|
|
337
|
-
|
|
338
|
-
self.send_to_tokenizer = get_zmq_socket(
|
|
339
|
-
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
|
|
340
|
-
)
|
|
341
|
-
if server_args.skip_tokenizer_init:
|
|
342
|
-
# Directly send to the TokenizerManager
|
|
343
|
-
self.send_to_detokenizer = get_zmq_socket(
|
|
344
|
-
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
|
|
345
|
-
)
|
|
346
|
-
else:
|
|
347
|
-
# Send to the DetokenizerManager
|
|
348
|
-
self.send_to_detokenizer = get_zmq_socket(
|
|
349
|
-
context, zmq.PUSH, port_args.detokenizer_ipc_name, False
|
|
350
|
-
)
|
|
351
|
-
|
|
352
|
-
if self.server_args.sleep_on_idle:
|
|
353
|
-
self.idle_sleeper = IdleSleeper(
|
|
354
|
-
[
|
|
355
|
-
self.recv_from_tokenizer,
|
|
356
|
-
self.recv_from_rpc,
|
|
357
|
-
]
|
|
358
|
-
)
|
|
359
|
-
else:
|
|
360
|
-
self.recv_from_tokenizer = None
|
|
361
|
-
self.recv_from_rpc = None
|
|
362
|
-
self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
|
|
363
|
-
self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
|
|
364
|
-
|
|
365
|
-
if self.current_scheduler_metrics_enabled():
|
|
366
|
-
self.send_metrics_from_scheduler = get_zmq_socket(
|
|
367
|
-
context, zmq.PUSH, port_args.metrics_ipc_name, False
|
|
368
|
-
)
|
|
285
|
+
self.init_sockets(server_args, port_args)
|
|
369
286
|
|
|
370
287
|
# Init tokenizer
|
|
371
288
|
self.init_tokenizer()
|
|
@@ -388,12 +305,10 @@ class Scheduler(
|
|
|
388
305
|
logger.info("Overlap scheduler is disabled for embedding models.")
|
|
389
306
|
|
|
390
307
|
# Launch a tensor parallel worker
|
|
391
|
-
if self.enable_overlap:
|
|
392
|
-
TpWorkerClass = TpModelWorkerClient
|
|
393
|
-
else:
|
|
394
|
-
TpWorkerClass = TpModelWorker
|
|
395
308
|
|
|
396
|
-
|
|
309
|
+
from sglang.srt.managers.tp_worker import TpModelWorker
|
|
310
|
+
|
|
311
|
+
self.tp_worker = TpModelWorker(
|
|
397
312
|
server_args=server_args,
|
|
398
313
|
gpu_id=gpu_id,
|
|
399
314
|
tp_rank=tp_rank,
|
|
@@ -404,44 +319,10 @@ class Scheduler(
|
|
|
404
319
|
)
|
|
405
320
|
|
|
406
321
|
# Launch a draft worker for speculative decoding
|
|
407
|
-
if self.spec_algorithm.is_eagle():
|
|
408
|
-
from sglang.srt.speculative.eagle_worker import EAGLEWorker
|
|
409
322
|
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
moe_ep_rank=moe_ep_rank,
|
|
414
|
-
server_args=server_args,
|
|
415
|
-
nccl_port=port_args.nccl_port,
|
|
416
|
-
target_worker=self.tp_worker,
|
|
417
|
-
dp_rank=dp_rank,
|
|
418
|
-
)
|
|
419
|
-
elif self.spec_algorithm.is_standalone():
|
|
420
|
-
from sglang.srt.speculative.standalone_worker import StandaloneWorker
|
|
421
|
-
|
|
422
|
-
self.draft_worker = StandaloneWorker(
|
|
423
|
-
gpu_id=gpu_id,
|
|
424
|
-
tp_rank=tp_rank,
|
|
425
|
-
moe_ep_rank=moe_ep_rank,
|
|
426
|
-
server_args=server_args,
|
|
427
|
-
nccl_port=port_args.nccl_port,
|
|
428
|
-
target_worker=self.tp_worker,
|
|
429
|
-
dp_rank=dp_rank,
|
|
430
|
-
)
|
|
431
|
-
elif self.spec_algorithm.is_ngram():
|
|
432
|
-
from sglang.srt.speculative.ngram_worker import NGRAMWorker
|
|
433
|
-
|
|
434
|
-
self.draft_worker = NGRAMWorker(
|
|
435
|
-
gpu_id=gpu_id,
|
|
436
|
-
tp_rank=tp_rank,
|
|
437
|
-
moe_ep_rank=moe_ep_rank,
|
|
438
|
-
server_args=server_args,
|
|
439
|
-
nccl_port=port_args.nccl_port,
|
|
440
|
-
target_worker=self.tp_worker,
|
|
441
|
-
dp_rank=dp_rank,
|
|
442
|
-
)
|
|
443
|
-
else:
|
|
444
|
-
self.draft_worker = None
|
|
323
|
+
self.launch_draft_worker(
|
|
324
|
+
gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
|
|
325
|
+
)
|
|
445
326
|
|
|
446
327
|
# Dispatch the model worker
|
|
447
328
|
if self.spec_algorithm.is_none():
|
|
@@ -459,13 +340,12 @@ class Scheduler(
|
|
|
459
340
|
self.max_req_input_len,
|
|
460
341
|
self.random_seed,
|
|
461
342
|
self.device,
|
|
462
|
-
worker_global_server_args_dict,
|
|
463
343
|
_,
|
|
464
344
|
_,
|
|
465
345
|
_,
|
|
466
346
|
) = self.tp_worker.get_worker_info()
|
|
467
|
-
if
|
|
468
|
-
|
|
347
|
+
if get_global_server_args().pp_max_micro_batch_size is None:
|
|
348
|
+
get_global_server_args().pp_max_micro_batch_size = max(
|
|
469
349
|
self.max_running_requests // server_args.pp_size, 1
|
|
470
350
|
)
|
|
471
351
|
|
|
@@ -477,11 +357,12 @@ class Scheduler(
|
|
|
477
357
|
self.world_group = get_world_group()
|
|
478
358
|
|
|
479
359
|
self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
|
|
480
|
-
global_server_args_dict.update(worker_global_server_args_dict)
|
|
481
360
|
set_random_seed(self.random_seed)
|
|
482
361
|
|
|
483
362
|
# Hybrid memory pool
|
|
484
363
|
self.is_hybrid = self.tp_worker.is_hybrid
|
|
364
|
+
self.is_hybrid_gdn = self.tp_worker.model_runner.hybrid_gdn_config is not None
|
|
365
|
+
|
|
485
366
|
if self.is_hybrid:
|
|
486
367
|
self.sliding_window_size = self.tp_worker.sliding_window_size
|
|
487
368
|
self.full_tokens_per_layer, self.swa_tokens_per_layer = (
|
|
@@ -525,9 +406,11 @@ class Scheduler(
|
|
|
525
406
|
self.kv_transfer_speed_gb_s: float = 0.0
|
|
526
407
|
self.kv_transfer_latency_ms: float = 0.0
|
|
527
408
|
self.sessions: Dict[str, Session] = {}
|
|
528
|
-
self.
|
|
409
|
+
self.default_stream: CudaStream = torch.get_device_module(
|
|
410
|
+
self.device
|
|
411
|
+
).current_stream()
|
|
529
412
|
if self.device == "cpu":
|
|
530
|
-
self.
|
|
413
|
+
self.default_stream.synchronize = lambda: None # No-op for CPU
|
|
531
414
|
self.forward_sleep_time = None
|
|
532
415
|
|
|
533
416
|
# Init chunked prefill
|
|
@@ -566,18 +449,17 @@ class Scheduler(
|
|
|
566
449
|
server_args.schedule_conservativeness >= 0
|
|
567
450
|
), "Invalid schedule_conservativeness"
|
|
568
451
|
self.init_new_token_ratio = min(
|
|
569
|
-
|
|
452
|
+
envs.SGLANG_INIT_NEW_TOKEN_RATIO.get()
|
|
570
453
|
* server_args.schedule_conservativeness,
|
|
571
454
|
1.0,
|
|
572
455
|
)
|
|
573
456
|
self.min_new_token_ratio = min(
|
|
574
|
-
self.init_new_token_ratio
|
|
575
|
-
* global_config.default_min_new_token_ratio_factor,
|
|
457
|
+
self.init_new_token_ratio * envs.SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR.get(),
|
|
576
458
|
1.0,
|
|
577
459
|
)
|
|
578
460
|
self.new_token_ratio_decay = (
|
|
579
461
|
self.init_new_token_ratio - self.min_new_token_ratio
|
|
580
|
-
) /
|
|
462
|
+
) / envs.SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS.get()
|
|
581
463
|
self.new_token_ratio = self.init_new_token_ratio
|
|
582
464
|
|
|
583
465
|
# Init watchdog thread
|
|
@@ -618,6 +500,9 @@ class Scheduler(
|
|
|
618
500
|
# Init prefill kv split size when deterministic inference is enabled with various attention backends
|
|
619
501
|
self.init_deterministic_inference_config()
|
|
620
502
|
|
|
503
|
+
# Init overlap
|
|
504
|
+
self.init_overlap()
|
|
505
|
+
|
|
621
506
|
# Init request dispatcher
|
|
622
507
|
self._request_dispatcher = TypeBasedDispatcher(
|
|
623
508
|
[
|
|
@@ -646,6 +531,7 @@ class Scheduler(
|
|
|
646
531
|
self.update_weights_from_distributed,
|
|
647
532
|
),
|
|
648
533
|
(UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
|
|
534
|
+
(UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc),
|
|
649
535
|
(GetWeightsByNameReqInput, self.get_weights_by_name),
|
|
650
536
|
(ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
|
|
651
537
|
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
|
|
@@ -658,11 +544,130 @@ class Scheduler(
|
|
|
658
544
|
(ExpertDistributionReq, self.expert_distribution_handle),
|
|
659
545
|
(LoadLoRAAdapterReqInput, self.load_lora_adapter),
|
|
660
546
|
(UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
|
|
661
|
-
(MultiTokenizerRegisterReq, self.register_multi_tokenizer),
|
|
662
547
|
(GetLoadReqInput, self.get_load),
|
|
663
548
|
]
|
|
664
549
|
)
|
|
665
550
|
|
|
551
|
+
def launch_draft_worker(
|
|
552
|
+
self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
|
|
553
|
+
):
|
|
554
|
+
if server_args.speculative_draft_load_format is not None:
|
|
555
|
+
server_args.load_format = server_args.speculative_draft_load_format
|
|
556
|
+
logger.info(
|
|
557
|
+
f"Using draft model load_format: '{server_args.speculative_draft_load_format}'"
|
|
558
|
+
)
|
|
559
|
+
|
|
560
|
+
if self.spec_algorithm.is_eagle():
|
|
561
|
+
from sglang.srt.speculative.eagle_worker import EAGLEWorker
|
|
562
|
+
from sglang.srt.speculative.eagle_worker_v2 import EAGLEWorkerV2
|
|
563
|
+
|
|
564
|
+
WorkerClass = EAGLEWorkerV2 if self.enable_overlap else EAGLEWorker
|
|
565
|
+
|
|
566
|
+
self.draft_worker = WorkerClass(
|
|
567
|
+
gpu_id=gpu_id,
|
|
568
|
+
tp_rank=tp_rank,
|
|
569
|
+
moe_ep_rank=moe_ep_rank,
|
|
570
|
+
server_args=server_args,
|
|
571
|
+
nccl_port=port_args.nccl_port,
|
|
572
|
+
target_worker=self.tp_worker,
|
|
573
|
+
dp_rank=dp_rank,
|
|
574
|
+
)
|
|
575
|
+
elif self.spec_algorithm.is_standalone():
|
|
576
|
+
from sglang.srt.speculative.standalone_worker import StandaloneWorker
|
|
577
|
+
|
|
578
|
+
self.draft_worker = StandaloneWorker(
|
|
579
|
+
gpu_id=gpu_id,
|
|
580
|
+
tp_rank=tp_rank,
|
|
581
|
+
moe_ep_rank=moe_ep_rank,
|
|
582
|
+
server_args=server_args,
|
|
583
|
+
nccl_port=port_args.nccl_port,
|
|
584
|
+
target_worker=self.tp_worker,
|
|
585
|
+
dp_rank=dp_rank,
|
|
586
|
+
)
|
|
587
|
+
elif self.spec_algorithm.is_ngram():
|
|
588
|
+
from sglang.srt.speculative.ngram_worker import NGRAMWorker
|
|
589
|
+
|
|
590
|
+
self.draft_worker = NGRAMWorker(
|
|
591
|
+
gpu_id=gpu_id,
|
|
592
|
+
tp_rank=tp_rank,
|
|
593
|
+
moe_ep_rank=moe_ep_rank,
|
|
594
|
+
server_args=server_args,
|
|
595
|
+
nccl_port=port_args.nccl_port,
|
|
596
|
+
target_worker=self.tp_worker,
|
|
597
|
+
dp_rank=dp_rank,
|
|
598
|
+
)
|
|
599
|
+
else:
|
|
600
|
+
self.draft_worker = None
|
|
601
|
+
|
|
602
|
+
def init_sockets(self, server_args: ServerArgs, port_args: PortArgs):
|
|
603
|
+
context = zmq.Context(2)
|
|
604
|
+
self.idle_sleeper = None
|
|
605
|
+
|
|
606
|
+
class SenderWrapper:
|
|
607
|
+
def __init__(self, socket: zmq.Socket):
|
|
608
|
+
self.socket = socket
|
|
609
|
+
|
|
610
|
+
def send_output(
|
|
611
|
+
self,
|
|
612
|
+
output: Union[BaseReq, BaseBatchReq],
|
|
613
|
+
recv_obj: Optional[Union[BaseReq, BaseBatchReq]] = None,
|
|
614
|
+
):
|
|
615
|
+
if self.socket is None:
|
|
616
|
+
return
|
|
617
|
+
|
|
618
|
+
if (
|
|
619
|
+
isinstance(recv_obj, BaseReq)
|
|
620
|
+
and recv_obj.http_worker_ipc is not None
|
|
621
|
+
and output.http_worker_ipc is None
|
|
622
|
+
):
|
|
623
|
+
# handle communicator reqs for multi-http worker case
|
|
624
|
+
output.http_worker_ipc = recv_obj.http_worker_ipc
|
|
625
|
+
|
|
626
|
+
self.socket.send_pyobj(output)
|
|
627
|
+
|
|
628
|
+
if self.pp_rank == 0 and self.attn_tp_rank == 0:
|
|
629
|
+
self.recv_from_tokenizer = get_zmq_socket(
|
|
630
|
+
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
|
|
631
|
+
)
|
|
632
|
+
self.recv_from_rpc = get_zmq_socket(
|
|
633
|
+
context, zmq.DEALER, port_args.rpc_ipc_name, False
|
|
634
|
+
)
|
|
635
|
+
|
|
636
|
+
send_to_tokenizer = get_zmq_socket(
|
|
637
|
+
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
|
|
638
|
+
)
|
|
639
|
+
if server_args.skip_tokenizer_init:
|
|
640
|
+
# Directly send to the TokenizerManager
|
|
641
|
+
send_to_detokenizer = get_zmq_socket(
|
|
642
|
+
context, zmq.PUSH, port_args.tokenizer_ipc_name, False
|
|
643
|
+
)
|
|
644
|
+
else:
|
|
645
|
+
# Send to the DetokenizerManager
|
|
646
|
+
send_to_detokenizer = get_zmq_socket(
|
|
647
|
+
context, zmq.PUSH, port_args.detokenizer_ipc_name, False
|
|
648
|
+
)
|
|
649
|
+
|
|
650
|
+
self.send_to_tokenizer = SenderWrapper(send_to_tokenizer)
|
|
651
|
+
self.send_to_detokenizer = SenderWrapper(send_to_detokenizer)
|
|
652
|
+
|
|
653
|
+
if self.server_args.sleep_on_idle:
|
|
654
|
+
self.idle_sleeper = IdleSleeper(
|
|
655
|
+
[
|
|
656
|
+
self.recv_from_tokenizer,
|
|
657
|
+
self.recv_from_rpc,
|
|
658
|
+
]
|
|
659
|
+
)
|
|
660
|
+
else:
|
|
661
|
+
self.recv_from_tokenizer = None
|
|
662
|
+
self.recv_from_rpc = None
|
|
663
|
+
self.send_to_tokenizer = SenderWrapper(None)
|
|
664
|
+
self.send_to_detokenizer = SenderWrapper(None)
|
|
665
|
+
|
|
666
|
+
if self.current_scheduler_metrics_enabled():
|
|
667
|
+
self.send_metrics_from_scheduler = get_zmq_socket(
|
|
668
|
+
context, zmq.PUSH, port_args.metrics_ipc_name, False
|
|
669
|
+
)
|
|
670
|
+
|
|
666
671
|
def init_deterministic_inference_config(self):
|
|
667
672
|
"""Initialize deterministic inference configuration for different attention backends."""
|
|
668
673
|
if not self.server_args.enable_deterministic_inference:
|
|
@@ -768,15 +773,20 @@ class Scheduler(
|
|
|
768
773
|
self.tree_cache.cache_controller.layer_done_counter
|
|
769
774
|
)
|
|
770
775
|
elif self.is_hybrid:
|
|
771
|
-
assert (
|
|
772
|
-
self.server_args.disaggregation_mode == "null"
|
|
773
|
-
), "Hybrid mode does not support disaggregation yet"
|
|
774
776
|
self.tree_cache = SWARadixCache(
|
|
775
777
|
req_to_token_pool=self.req_to_token_pool,
|
|
776
778
|
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
|
777
779
|
sliding_window_size=self.sliding_window_size,
|
|
778
780
|
page_size=self.page_size,
|
|
779
781
|
disable=server_args.disable_radix_cache,
|
|
782
|
+
is_eagle=self.spec_algorithm.is_eagle(),
|
|
783
|
+
)
|
|
784
|
+
elif self.is_hybrid_gdn:
|
|
785
|
+
self.tree_cache = MambaRadixCache(
|
|
786
|
+
req_to_token_pool=self.req_to_token_pool,
|
|
787
|
+
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
|
788
|
+
page_size=self.page_size,
|
|
789
|
+
disable=server_args.disable_radix_cache,
|
|
780
790
|
)
|
|
781
791
|
elif server_args.enable_lmcache:
|
|
782
792
|
from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
|
|
@@ -931,6 +941,34 @@ class Scheduler(
|
|
|
931
941
|
# The prefill requests that are in the middle of kv sending
|
|
932
942
|
self.disagg_prefill_inflight_queue: List[Req] = []
|
|
933
943
|
|
|
944
|
+
def init_overlap(self):
|
|
945
|
+
if not self.enable_overlap:
|
|
946
|
+
return
|
|
947
|
+
|
|
948
|
+
self.forward_stream: CudaStream = torch.get_device_module(self.device).Stream()
|
|
949
|
+
self.forward_stream_ctx: CudaStreamContext = torch.get_device_module(
|
|
950
|
+
self.device
|
|
951
|
+
).stream(self.forward_stream)
|
|
952
|
+
self.copy_stream: CudaStream = torch.get_device_module(self.device).Stream()
|
|
953
|
+
self.copy_stream_ctx: CudaStreamContext = torch.get_device_module(
|
|
954
|
+
self.device
|
|
955
|
+
).stream(self.copy_stream)
|
|
956
|
+
|
|
957
|
+
self.future_map = FutureMap(
|
|
958
|
+
self.max_running_requests, self.device, self.spec_algorithm
|
|
959
|
+
)
|
|
960
|
+
self.batch_record_buf = [None] * 2
|
|
961
|
+
self.batch_record_ct = 0
|
|
962
|
+
|
|
963
|
+
def record_batch_in_overlap(self, model_worker_batch: ModelWorkerBatch):
|
|
964
|
+
# FIXME(lsyin): hacky way to keep a reference to avoid GPU tensors being freed by torch GC
|
|
965
|
+
# NOTE: More Reliable: record all tensors into the forward stream
|
|
966
|
+
# NOTE: - for all future tensors, we shall always read from future map
|
|
967
|
+
# - for all non-future tensors (produced only by schedule stream),
|
|
968
|
+
# we shall keep its reference not being release during all the forwarding pass
|
|
969
|
+
self.batch_record_ct = (self.batch_record_ct + 1) % 2
|
|
970
|
+
self.batch_record_buf[self.batch_record_ct] = model_worker_batch
|
|
971
|
+
|
|
934
972
|
def init_moe_config(self):
|
|
935
973
|
if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
|
|
936
974
|
initialize_moe_config(self.server_args)
|
|
@@ -957,7 +995,7 @@ class Scheduler(
|
|
|
957
995
|
@DynamicGradMode()
|
|
958
996
|
def event_loop_overlap(self):
|
|
959
997
|
"""A scheduler loop that overlaps the CPU processing and GPU computation."""
|
|
960
|
-
self.result_queue = deque()
|
|
998
|
+
self.result_queue: Deque[Tuple[ScheduleBatch, GenerationBatchResult]] = deque()
|
|
961
999
|
|
|
962
1000
|
while True:
|
|
963
1001
|
recv_reqs = self.recv_requests()
|
|
@@ -966,158 +1004,24 @@ class Scheduler(
|
|
|
966
1004
|
batch = self.get_next_batch_to_run()
|
|
967
1005
|
self.cur_batch = batch
|
|
968
1006
|
|
|
1007
|
+
batch_result = None
|
|
969
1008
|
if batch:
|
|
970
|
-
|
|
971
|
-
|
|
972
|
-
self.result_queue.append((batch.copy(), result))
|
|
973
|
-
|
|
974
|
-
if self.last_batch is None:
|
|
975
|
-
# Create a dummy first batch to start the pipeline for overlap schedule.
|
|
976
|
-
# It is now used for triggering the sampling_info_done event.
|
|
977
|
-
tmp_batch = ScheduleBatch(
|
|
978
|
-
reqs=None,
|
|
979
|
-
forward_mode=ForwardMode.DUMMY_FIRST,
|
|
980
|
-
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
|
|
981
|
-
)
|
|
982
|
-
self.process_batch_result(tmp_batch, None, batch.launch_done)
|
|
1009
|
+
batch_result = self.run_batch(batch)
|
|
1010
|
+
self.result_queue.append((batch.copy(), batch_result))
|
|
983
1011
|
|
|
984
1012
|
if self.last_batch:
|
|
985
1013
|
# Process the results of the last batch
|
|
986
1014
|
tmp_batch, tmp_result = self.result_queue.popleft()
|
|
987
|
-
tmp_batch
|
|
988
|
-
self.tp_worker.cur_sampling_info if batch else None
|
|
989
|
-
)
|
|
990
|
-
# NOTE: we should use current launched batch's launch_done event Instead of the last batch's
|
|
991
|
-
self.process_batch_result(
|
|
992
|
-
tmp_batch, tmp_result, batch.launch_done if batch else None
|
|
993
|
-
)
|
|
1015
|
+
self.process_batch_result(tmp_batch, tmp_result)
|
|
994
1016
|
elif batch is None:
|
|
995
1017
|
# When the server is idle, do self-check and re-init some states
|
|
996
1018
|
self.self_check_during_idle()
|
|
997
1019
|
|
|
1020
|
+
self.launch_batch_sample_if_needed(batch_result)
|
|
998
1021
|
self.last_batch = batch
|
|
999
1022
|
|
|
1000
|
-
|
|
1001
|
-
|
|
1002
|
-
"""A non-overlap scheduler loop for pipeline parallelism."""
|
|
1003
|
-
mbs = [None] * self.pp_size
|
|
1004
|
-
last_mbs = [None] * self.pp_size
|
|
1005
|
-
self.running_mbs = [
|
|
1006
|
-
ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
|
|
1007
|
-
]
|
|
1008
|
-
pp_outputs: Optional[PPProxyTensors] = None
|
|
1009
|
-
while True:
|
|
1010
|
-
server_is_idle = True
|
|
1011
|
-
for mb_id in range(self.pp_size):
|
|
1012
|
-
self.running_batch = self.running_mbs[mb_id]
|
|
1013
|
-
self.last_batch = last_mbs[mb_id]
|
|
1014
|
-
|
|
1015
|
-
recv_reqs = self.recv_requests()
|
|
1016
|
-
self.process_input_requests(recv_reqs)
|
|
1017
|
-
mbs[mb_id] = self.get_next_batch_to_run()
|
|
1018
|
-
self.running_mbs[mb_id] = self.running_batch
|
|
1019
|
-
|
|
1020
|
-
self.cur_batch = mbs[mb_id]
|
|
1021
|
-
if self.cur_batch:
|
|
1022
|
-
server_is_idle = False
|
|
1023
|
-
result = self.run_batch(self.cur_batch)
|
|
1024
|
-
|
|
1025
|
-
# (last rank) send the outputs to the next step
|
|
1026
|
-
if self.pp_group.is_last_rank:
|
|
1027
|
-
if self.cur_batch:
|
|
1028
|
-
next_token_ids = result.next_token_ids
|
|
1029
|
-
if self.cur_batch.return_logprob:
|
|
1030
|
-
pp_outputs = PPProxyTensors(
|
|
1031
|
-
{
|
|
1032
|
-
"next_token_ids": next_token_ids,
|
|
1033
|
-
"extend_input_len_per_req": result.extend_input_len_per_req,
|
|
1034
|
-
"extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
|
|
1035
|
-
}
|
|
1036
|
-
| (
|
|
1037
|
-
{
|
|
1038
|
-
f"logits_output.{k}": v
|
|
1039
|
-
for k, v in result.logits_output.__dict__.items()
|
|
1040
|
-
}
|
|
1041
|
-
if result.logits_output is not None
|
|
1042
|
-
else {}
|
|
1043
|
-
)
|
|
1044
|
-
)
|
|
1045
|
-
else:
|
|
1046
|
-
pp_outputs = PPProxyTensors(
|
|
1047
|
-
{
|
|
1048
|
-
"next_token_ids": next_token_ids,
|
|
1049
|
-
}
|
|
1050
|
-
)
|
|
1051
|
-
# send the output from the last round to let the next stage worker run post processing
|
|
1052
|
-
self.pp_group.send_tensor_dict(
|
|
1053
|
-
pp_outputs.tensors,
|
|
1054
|
-
all_gather_group=self.attn_tp_group,
|
|
1055
|
-
)
|
|
1056
|
-
|
|
1057
|
-
# receive outputs and post-process (filter finished reqs) the coming microbatch
|
|
1058
|
-
next_mb_id = (mb_id + 1) % self.pp_size
|
|
1059
|
-
next_pp_outputs = None
|
|
1060
|
-
if mbs[next_mb_id] is not None:
|
|
1061
|
-
next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
|
|
1062
|
-
self.pp_group.recv_tensor_dict(
|
|
1063
|
-
all_gather_group=self.attn_tp_group
|
|
1064
|
-
)
|
|
1065
|
-
)
|
|
1066
|
-
mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
|
|
1067
|
-
logits_output_args = {
|
|
1068
|
-
k[len("logits_output.") :]: v
|
|
1069
|
-
for k, v in next_pp_outputs.tensors.items()
|
|
1070
|
-
if k.startswith("logits_output.")
|
|
1071
|
-
}
|
|
1072
|
-
if len(logits_output_args) > 0:
|
|
1073
|
-
logits_output = LogitsProcessorOutput(**logits_output_args)
|
|
1074
|
-
else:
|
|
1075
|
-
logits_output = None
|
|
1076
|
-
|
|
1077
|
-
output_result = GenerationBatchResult.from_pp_proxy(
|
|
1078
|
-
logits_output=logits_output,
|
|
1079
|
-
next_pp_outputs=next_pp_outputs,
|
|
1080
|
-
can_run_cuda_graph=result.can_run_cuda_graph,
|
|
1081
|
-
)
|
|
1082
|
-
self.process_batch_result(mbs[next_mb_id], output_result)
|
|
1083
|
-
last_mbs[next_mb_id] = mbs[next_mb_id]
|
|
1084
|
-
|
|
1085
|
-
# (not last rank)
|
|
1086
|
-
if not self.pp_group.is_last_rank:
|
|
1087
|
-
# carry the outputs to the next stage
|
|
1088
|
-
# send the outputs from the last round to let the next stage worker run post processing
|
|
1089
|
-
if pp_outputs:
|
|
1090
|
-
self.pp_group.send_tensor_dict(
|
|
1091
|
-
pp_outputs.tensors,
|
|
1092
|
-
all_gather_group=self.attn_tp_group,
|
|
1093
|
-
)
|
|
1094
|
-
|
|
1095
|
-
# send out reqs to the next stage
|
|
1096
|
-
dp_offset = self.attn_dp_rank * self.attn_tp_size
|
|
1097
|
-
if self.attn_tp_rank == 0:
|
|
1098
|
-
point_to_point_pyobj(
|
|
1099
|
-
recv_reqs,
|
|
1100
|
-
self.pp_rank * self.tp_size + dp_offset,
|
|
1101
|
-
self.world_group.device_group,
|
|
1102
|
-
self.pp_rank * self.tp_size + dp_offset,
|
|
1103
|
-
(self.pp_rank + 1) * self.tp_size + dp_offset,
|
|
1104
|
-
)
|
|
1105
|
-
|
|
1106
|
-
# send out proxy tensors to the next stage
|
|
1107
|
-
if self.cur_batch:
|
|
1108
|
-
# FIXME(lsyin): remove this assert
|
|
1109
|
-
assert result.pp_hidden_states_proxy_tensors.tensors is not None
|
|
1110
|
-
self.pp_group.send_tensor_dict(
|
|
1111
|
-
result.pp_hidden_states_proxy_tensors.tensors,
|
|
1112
|
-
all_gather_group=self.attn_tp_group,
|
|
1113
|
-
)
|
|
1114
|
-
|
|
1115
|
-
pp_outputs = next_pp_outputs
|
|
1116
|
-
|
|
1117
|
-
# When the server is idle, self-check and re-init some states
|
|
1118
|
-
if server_is_idle:
|
|
1119
|
-
# When the server is idle, do self-check and re-init some states
|
|
1120
|
-
self.self_check_during_idle()
|
|
1023
|
+
if envs.SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK.get():
|
|
1024
|
+
self._check_runtime_mem_leak()
|
|
1121
1025
|
|
|
1122
1026
|
def recv_requests(self) -> List[Req]:
|
|
1123
1027
|
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
|
|
@@ -1240,23 +1144,13 @@ class Scheduler(
|
|
|
1240
1144
|
self.return_health_check_ct += 1
|
|
1241
1145
|
continue
|
|
1242
1146
|
|
|
1243
|
-
# If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
|
|
1244
|
-
if isinstance(recv_req, MultiTokenizerWrapper):
|
|
1245
|
-
worker_id = recv_req.worker_id
|
|
1246
|
-
recv_req = recv_req.obj
|
|
1247
|
-
output = self._request_dispatcher(recv_req)
|
|
1248
|
-
if output is not None:
|
|
1249
|
-
output = MultiTokenizerWrapper(worker_id, output)
|
|
1250
|
-
self.send_to_tokenizer.send_pyobj(output)
|
|
1251
|
-
continue
|
|
1252
|
-
|
|
1253
1147
|
output = self._request_dispatcher(recv_req)
|
|
1254
1148
|
if output is not None:
|
|
1255
1149
|
if isinstance(output, RpcReqOutput):
|
|
1256
1150
|
if self.recv_from_rpc is not None:
|
|
1257
1151
|
self.recv_from_rpc.send_pyobj(output)
|
|
1258
1152
|
else:
|
|
1259
|
-
self.send_to_tokenizer.
|
|
1153
|
+
self.send_to_tokenizer.send_output(output, recv_req)
|
|
1260
1154
|
|
|
1261
1155
|
def init_req_max_new_tokens(self, req):
|
|
1262
1156
|
req.sampling_params.max_new_tokens = min(
|
|
@@ -1312,6 +1206,7 @@ class Scheduler(
|
|
|
1312
1206
|
metrics_collector=(
|
|
1313
1207
|
self.metrics_collector if self.enable_metrics else None
|
|
1314
1208
|
),
|
|
1209
|
+
http_worker_ipc=recv_req.http_worker_ipc,
|
|
1315
1210
|
)
|
|
1316
1211
|
req.tokenizer = self.tokenizer
|
|
1317
1212
|
|
|
@@ -1410,26 +1305,29 @@ class Scheduler(
             or req.sampling_params.ebnf is not None
             or req.sampling_params.structural_tag is not None
         ):
-
-
-
-            elif req.sampling_params.regex is not None:
-                key = ("regex", req.sampling_params.regex)
-            elif req.sampling_params.ebnf is not None:
-                key = ("ebnf", req.sampling_params.ebnf)
-            elif req.sampling_params.structural_tag:
-                key = ("structural_tag", req.sampling_params.structural_tag)
-
-            value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
-            req.grammar = value
-
-            if not cache_hit:
-                req.grammar_key = key
-                add_to_grammar_queue = True
+            if self.grammar_backend is None:
+                error_msg = "Grammar-based generation (json_schema, regex, ebnf, structural_tag) is not supported when the server is launched with --grammar-backend none"
+                req.set_finish_with_abort(error_msg)
             else:
-                if
-
-
+                if req.sampling_params.json_schema is not None:
+                    key = ("json", req.sampling_params.json_schema)
+                elif req.sampling_params.regex is not None:
+                    key = ("regex", req.sampling_params.regex)
+                elif req.sampling_params.ebnf is not None:
+                    key = ("ebnf", req.sampling_params.ebnf)
+                elif req.sampling_params.structural_tag:
+                    key = ("structural_tag", req.sampling_params.structural_tag)
+
+                value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
+                req.grammar = value
+
+                if not cache_hit:
+                    req.grammar_key = key
+                    add_to_grammar_queue = True
+                else:
+                    if value is INVALID_GRAMMAR_OBJ:  # We hit a cached invalid grammar.
+                        error_msg = f"Invalid grammar request with cache hit: {key=}"
+                        req.set_finish_with_abort(error_msg)

         if add_to_grammar_queue:
             self.grammar_queue.append(req)
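The hunk above keys compiled grammars by the constraint type plus its spec string and distinguishes a cache hit from a future value still being compiled. A small illustrative sketch of that lookup pattern follows; the cache class and sentinel are hypothetical simplifications, not the sglang grammar backend.

    from types import SimpleNamespace
    from typing import Any, Dict, Optional, Tuple

    INVALID_GRAMMAR = object()  # sentinel a hit may return when the spec failed to compile


    class TinyGrammarCache:
        def __init__(self) -> None:
            self._cache: Dict[Tuple[str, str], Any] = {}

        def get_cached_or_future_value(self, key: Tuple[str, str]) -> Tuple[Any, bool]:
            if key in self._cache:
                return self._cache[key], True   # cache hit (possibly INVALID_GRAMMAR)
            return None, False                  # miss: caller queues async compilation

        def set_value(self, key: Tuple[str, str], value: Any) -> None:
            self._cache[key] = value


    def make_key(sampling_params) -> Optional[Tuple[str, str]]:
        # Mirrors the precedence in the hunk above: json_schema, regex, ebnf, structural_tag.
        if getattr(sampling_params, "json_schema", None) is not None:
            return ("json", sampling_params.json_schema)
        if getattr(sampling_params, "regex", None) is not None:
            return ("regex", sampling_params.regex)
        if getattr(sampling_params, "ebnf", None) is not None:
            return ("ebnf", sampling_params.ebnf)
        if getattr(sampling_params, "structural_tag", None):
            return ("structural_tag", sampling_params.structural_tag)
        return None


    key = make_key(SimpleNamespace(json_schema='{"type": "object"}'))
    assert key == ("json", '{"type": "object"}')
    assert TinyGrammarCache().get_cached_or_future_value(key) == (None, False)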
@@ -1456,8 +1354,18 @@ class Scheduler(
             last_hash = req.last_host_node.get_last_hash_value()
             matched_len = len(req.prefix_indices) + req.host_hit_length
             new_input_tokens = req.fill_ids[matched_len:]
+
+            prefix_keys = (
+                req.last_node.get_prefix_hash_values(req.last_node.parent)
+                if self.tree_cache.hicache_storage_pass_prefix_keys
+                else None
+            )
             self.tree_cache.prefetch_from_storage(
-                req.rid,
+                req.rid,
+                req.last_host_node,
+                new_input_tokens,
+                last_hash,
+                prefix_keys,
             )

     def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
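The new prefix_keys argument lets hierarchical-cache storage be addressed by the chain of hash values accumulated along the matched prefix. A sketch of one way such prefix keys could be produced is below, assuming each tree node chains its hash from its parent; the Node class is hypothetical and not the sglang radix-tree implementation.

    import hashlib
    from typing import List, Optional


    class Node:
        def __init__(self, token_bytes: bytes, parent: Optional["Node"] = None):
            self.parent = parent
            prev = parent.hash_value if parent else ""
            # Each node's key covers the whole prefix ending at this node.
            self.hash_value = hashlib.sha256(prev.encode() + token_bytes).hexdigest()


    def get_prefix_hash_values(node: Node, stop: Optional[Node]) -> List[str]:
        keys: List[str] = []
        cur: Optional[Node] = node
        while cur is not None and cur is not stop:
            keys.append(cur.hash_value)
            cur = cur.parent
        keys.reverse()  # return in root-to-leaf order
        return keys


    root = Node(b"The")
    child = Node(b" cat", parent=root)
    assert get_prefix_hash_values(child, None) == [root.hash_value, child.hash_value]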
@@ -1489,7 +1397,11 @@ class Scheduler(
                 req.priority = sys.maxsize
             else:
                 req.priority = -sys.maxsize - 1
-        elif
+        elif (
+            not self.enable_priority_scheduling
+            and req.priority is not None
+            and self.abort_on_priority_when_disabled
+        ):
             abort_req = AbortReq(
                 finished_reason={
                     "type": "abort",
@@ -1498,7 +1410,7 @@ class Scheduler(
                 },
                 rid=req.rid,
             )
-            self.send_to_tokenizer.
+            self.send_to_tokenizer.send_output(abort_req, req)

     def _abort_on_queued_limit(self, recv_req: Req) -> bool:
         """Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted."""
@@ -1530,7 +1442,7 @@ class Scheduler(
             req_to_abort = candidate_req
             message = "The request is aborted by a higher priority request."

-        self.send_to_tokenizer.
+        self.send_to_tokenizer.send_output(
             AbortReq(
                 finished_reason={
                     "type": "abort",
@@ -1538,7 +1450,8 @@ class Scheduler(
                     "message": message,
                 },
                 rid=req_to_abort.rid,
-            )
+            ),
+            req_to_abort,
         )
         return req_to_abort.rid == recv_req.rid

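The hunks above tighten the priority path: requests carrying a priority are aborted when priority scheduling is disabled (if abort-on-priority is configured), and a full waiting queue either rejects the newcomer or evicts the lowest-priority queued request. A compact sketch of that eviction decision follows; the types and the larger-is-more-important convention are assumptions for illustration, not the Scheduler's exact tie-breaking rules.

    from dataclasses import dataclass
    from typing import List, Tuple


    @dataclass
    class PendingReq:
        rid: str
        priority: int  # assumed convention: larger value = more important


    def pick_abort_victim(queue: List[PendingReq], incoming: PendingReq) -> Tuple[PendingReq, bool]:
        """Return (victim, incoming_was_aborted) when the waiting queue is full."""
        lowest = min(queue, key=lambda r: r.priority)
        if incoming.priority > lowest.priority:
            return lowest, False      # evict an existing request to make room
        return incoming, True         # newcomer loses; abort it instead


    queue = [PendingReq("a", 5), PendingReq("b", 1)]
    victim, incoming_aborted = pick_abort_victim(queue, PendingReq("c", 3))
    assert victim.rid == "b" and not incoming_aborted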
@@ -1553,6 +1466,7 @@ class Scheduler(
             recv_req.sampling_params,
             token_type_ids=recv_req.token_type_ids,
             priority=recv_req.priority,
+            http_worker_ipc=recv_req.http_worker_ipc,
         )
         req.tokenizer = self.tokenizer

@@ -1602,109 +1516,6 @@ class Scheduler(
         for tokenized_req in recv_req:
             self.handle_embedding_request(tokenized_req)

-    def self_check_during_idle(self):
-        self.check_memory()
-        self.check_tree_cache()
-        self.new_token_ratio = self.init_new_token_ratio
-        self.maybe_sleep_on_idle()
-
-    def check_memory(self):
-        if self.is_hybrid:
-            (
-                full_num_used,
-                swa_num_used,
-                _,
-                _,
-                full_available_size,
-                full_evictable_size,
-                swa_available_size,
-                swa_evictable_size,
-            ) = self._get_swa_token_info()
-            memory_leak = full_num_used != 0 or swa_num_used != 0
-            token_msg = (
-                f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
-                f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
-            )
-        else:
-            _, _, available_size, evictable_size = self._get_token_info()
-            protected_size = self.tree_cache.protected_size()
-            memory_leak = (available_size + evictable_size) != (
-                # self.max_total_num_tokens
-                # if not self.enable_hierarchical_cache
-                # else self.max_total_num_tokens - protected_size
-                self.max_total_num_tokens
-                - protected_size
-            )
-            token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
-
-        if memory_leak:
-            msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
-            raise ValueError(msg)
-
-        if self.disaggregation_mode == DisaggregationMode.DECODE:
-            req_total_size = (
-                self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
-            )
-        else:
-            req_total_size = self.req_to_token_pool.size
-
-        if len(self.req_to_token_pool.free_slots) != req_total_size:
-            msg = (
-                "req_to_token_pool memory leak detected!"
-                f"available_size={len(self.req_to_token_pool.free_slots)}, "
-                f"total_size={self.req_to_token_pool.size}\n"
-            )
-            raise ValueError(msg)
-
-        if (
-            self.enable_metrics
-            and self.current_scheduler_metrics_enabled()
-            and time.perf_counter() > self.metrics_collector.last_log_time + 30
-        ):
-            # During idle time, also collect metrics every 30 seconds.
-            if self.is_hybrid:
-                (
-                    full_num_used,
-                    swa_num_used,
-                    full_token_usage,
-                    swa_token_usage,
-                    _,
-                    _,
-                    _,
-                    _,
-                ) = self._get_swa_token_info()
-                num_used = max(full_num_used, swa_num_used)
-                token_usage = max(full_token_usage, swa_token_usage)
-            else:
-                num_used, token_usage, _, _ = self._get_token_info()
-            num_running_reqs = len(self.running_batch.reqs)
-            self.stats.num_running_reqs = num_running_reqs
-            self.stats.num_used_tokens = num_used
-            self.stats.token_usage = round(token_usage, 2)
-            self.stats.gen_throughput = 0
-            self.stats.num_queue_reqs = len(self.waiting_queue)
-            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
-            if self.disaggregation_mode == DisaggregationMode.PREFILL:
-                self.stats.num_prefill_prealloc_queue_reqs = len(
-                    self.disagg_prefill_bootstrap_queue.queue
-                )
-                self.stats.num_prefill_inflight_queue_reqs = len(
-                    self.disagg_prefill_inflight_queue
-                )
-            if self.disaggregation_mode == DisaggregationMode.DECODE:
-                self.stats.num_decode_prealloc_queue_reqs = len(
-                    self.disagg_decode_prealloc_queue.queue
-                )
-                self.stats.num_decode_transfer_queue_reqs = len(
-                    self.disagg_decode_transfer_queue.queue
-                )
-            self.metrics_collector.log_stats(self.stats)
-        self._publish_kv_events()
-
-    def check_tree_cache(self):
-        if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache):
-            self.tree_cache.sanity_check()
-
     def _get_token_info(self):
         available_size = self.token_to_kv_pool_allocator.available_size()
         evictable_size = self.tree_cache.evictable_size()
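The removed check_memory enforced a simple accounting invariant while the server is idle: every KV slot must be either free in the allocator or evictable from the prefix cache, apart from slots the cache still protects. A minimal sketch of that invariant, with illustrative numbers only:

    def has_token_pool_leak(
        max_total_num_tokens: int,
        available_size: int,
        evictable_size: int,
        protected_size: int,
    ) -> bool:
        # Free + evictable should cover everything except protected slots when idle.
        return (available_size + evictable_size) != (max_total_num_tokens - protected_size)


    # Example: 1000 total slots, 10 protected by the tree cache, 700 free, 290 evictable.
    assert not has_token_pool_leak(1000, 700, 290, 10)
    # If 5 slots were never returned to the allocator, the invariant breaks.
    assert has_token_pool_leak(1000, 695, 290, 10)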
@@ -1712,6 +1523,35 @@ class Scheduler(
         token_usage = num_used / self.max_total_num_tokens
         return num_used, token_usage, available_size, evictable_size

+    def _get_mamba_token_info(self):
+        is_radix_tree = isinstance(self.tree_cache, MambaRadixCache)
+        full_available_size = self.token_to_kv_pool_allocator.available_size()
+        full_evictable_size = (
+            self.tree_cache.full_evictable_size() if is_radix_tree else 0
+        )
+        mamba_available_size = self.req_to_token_pool.mamba_pool.available_size()
+        mamba_evictable_size = (
+            self.tree_cache.mamba_evictable_size() if is_radix_tree else 0
+        )
+        full_num_used = self.token_to_kv_pool_allocator.size - (
+            full_available_size + full_evictable_size
+        )
+        mamba_num_used = self.req_to_token_pool.mamba_pool.size - (
+            mamba_available_size + mamba_evictable_size
+        )
+        full_token_usage = full_num_used / self.token_to_kv_pool_allocator.size
+        mamba_usage = mamba_num_used / self.req_to_token_pool.mamba_pool.size
+        return (
+            full_num_used,
+            mamba_num_used,
+            full_token_usage,
+            mamba_usage,
+            full_available_size,
+            full_evictable_size,
+            mamba_available_size,
+            mamba_evictable_size,
+        )
+
     def _get_swa_token_info(self):
         full_available_size = self.token_to_kv_pool_allocator.full_available_size()
         full_evictable_size = self.tree_cache.full_evictable_size()
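The usage math in _get_mamba_token_info treats "used" as whatever is neither free in the allocator nor evictable from the radix cache, computed independently for the full-attention KV pool and the Mamba state pool. A tiny worked example of that arithmetic, with made-up sizes:

    def pool_usage(pool_size: int, available: int, evictable: int) -> float:
        used = pool_size - (available + evictable)
        return used / pool_size


    full_usage = pool_usage(pool_size=8192, available=2048, evictable=1024)   # 0.625
    mamba_usage = pool_usage(pool_size=256, available=64, evictable=0)        # 0.75
    assert abs(full_usage - 0.625) < 1e-9 and abs(mamba_usage - 0.75) < 1e-9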
@@ -1745,7 +1585,7 @@ class Scheduler(
             chunked_req_to_exclude.add(self.chunked_req)
             self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             # chunked request keeps its rid but will get a new req_pool_idx
-            if self.tp_worker.
+            if self.tp_worker.model_runner.mambaish_config is not None:
                 self.req_to_token_pool.free(
                     self.chunked_req.req_pool_idx, free_mamba_cache=False
                 )
@@ -1802,7 +1642,7 @@ class Scheduler(
         return ret

     def get_num_allocatable_reqs(self, running_bs):
-        res =
+        res = get_global_server_args().pp_max_micro_batch_size - running_bs
         if self.pp_size > 1:
             res = min(res, self.req_to_token_pool.available_size())
         return res
@@ -1999,7 +1839,7 @@ class Scheduler(

         # Check if decode out of memory
         if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
-            TEST_RETRACT and
+            TEST_RETRACT and self.forward_ct % TEST_RETRACT_INTERVAL == 0
         ):
             old_ratio = self.new_token_ratio
             retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
@@ -2008,8 +1848,8 @@ class Scheduler(
             self.num_retracted_reqs = len(retracted_reqs)
             self.new_token_ratio = new_token_ratio
             for req in reqs_to_abort:
-                self.send_to_tokenizer.
-                    AbortReq(abort_reason=req.to_abort_message, rid=req.rid)
+                self.send_to_tokenizer.send_output(
+                    AbortReq(abort_reason=req.to_abort_message, rid=req.rid), req
                 )

             logger.info(
@@ -2034,6 +1874,12 @@ class Scheduler(
         batch.prepare_for_decode()
         return batch

+    # placeholder for override
+    def update_cache_from_scheduler(
+        self, schedule_batch: ScheduleBatch, batch_result: GenerationBatchResult
+    ):
+        pass
+
     def run_batch(
         self, batch: ScheduleBatch
     ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
@@ -2051,22 +1897,72 @@ class Scheduler(

         batch_or_worker_batch = batch

-        if self.spec_algorithm.is_none():
+        if self.enable_overlap or self.spec_algorithm.is_none():
             # FIXME(lsyin): remove this if and finally unify the abstraction
             batch_or_worker_batch = batch.get_model_worker_batch()

-
-
-
+        if self.enable_overlap:
+            # FIXME: remove this assert
+            assert isinstance(batch_or_worker_batch, ModelWorkerBatch)
+            model_worker_batch = batch_or_worker_batch
+            self.record_batch_in_overlap(model_worker_batch)
+
+            # Sampling info will be modified during forward
+            model_worker_batch.sampling_info = (
+                model_worker_batch.sampling_info.copy_for_forward()
+            )
+
+            bs = len(model_worker_batch.seq_lens)
+            future_indices = self.future_map.alloc_future_indices(bs)

-
-
-
-
+            with self.forward_stream_ctx:
+                self.forward_stream.wait_stream(self.default_stream)
+                self.future_map.resolve_future(model_worker_batch)
+                batch_result = self.model_worker.forward_batch_generation(
+                    model_worker_batch
+                )
+                # FIXME(lsyin): maybe move this to forward_batch_generation
+                batch_result.copy_done = torch.get_device_module(
+                    self.device
+                ).Event()
+                if batch_result.delay_sample_func is None:
+                    self.future_map.store_to_map(future_indices, batch_result)
+                    batch_result.copy_to_cpu()
+                else:
+                    batch_result.future_indices = future_indices
+
+            # FIXME(lsyin): move this assignment elsewhere
+            future_indices_or_next_token_ids = -future_indices.indices
+
+            if batch.is_v2_eagle:
+                # FIXME(lsyin): tmp code for eagle v2
+                # We only keep future indices for next draft input
+
+                batch.spec_info = batch_result.next_draft_input
+                batch.spec_info.future_indices = future_indices
+
+                # batch.spec_info = EagleDraftInput(
+                #     future_indices=future_indices,
+                #     verify_done=batch_result.next_draft_input.verify_done,
+                #     # FIXME(lsyin): remove the allocate_lens in EagleDraftInput
+                #     allocate_lens=batch_result.next_draft_input.allocate_lens,
+                # )
+
+                # The future value, usually for next batch preparation
+                # Current implementation strictly synchronizes the seq_lens
+                batch.seq_lens = batch_result.next_draft_input.new_seq_lens
+        else:
+            batch_result = self.model_worker.forward_batch_generation(
+                batch_or_worker_batch
            )
+            future_indices_or_next_token_ids = batch_result.next_token_ids
+        self.update_cache_from_scheduler(batch, batch_result)

-        #
-
+        # NOTE: future_indices_or_next_token_ids is used in ScheduleBatch,
+        # which can probably be replaced by future_indices later [TODO(lsyin)].
+        # we shall still keep the original outputs, e.g. next_token_ids
+        # in the GenerationBatchOutput for processing after copy_done.
+        batch.output_ids = future_indices_or_next_token_ids

         # These 2 values are needed for processing the output, but the values can be
         # modified by overlap schedule. So we have to copy them here so that
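The overlap path above hands the next batch negative "future indices" in place of real token ids and resolves them from a side table once the GPU forward that produces them has completed. A minimal CPU-only sketch of that indirection is below; the simplified list-based FutureMap is hypothetical, while the real one operates on torch tensors and CUDA streams.

    from typing import Dict, List


    class TinyFutureMap:
        def __init__(self) -> None:
            self._next = 1
            self._table: Dict[int, List[int]] = {}

        def alloc_future_indices(self, bs: int) -> List[int]:
            idx = list(range(self._next, self._next + bs))
            self._next += bs
            return idx

        def store_to_map(self, indices: List[int], token_ids: List[int]) -> None:
            for i, tok in zip(indices, token_ids):
                self._table[i] = [tok]

        def resolve(self, maybe_future: List[int]) -> List[int]:
            # Negative entries are placeholders (-index); positive entries are real ids.
            return [self._table[-v][0] if v < 0 else v for v in maybe_future]


    fm = TinyFutureMap()
    futs = fm.alloc_future_indices(2)            # e.g. [1, 2]
    placeholders = [-i for i in futs]            # handed to the next batch as output_ids
    fm.store_to_map(futs, [42, 7])               # filled in after the forward completes
    assert fm.resolve(placeholders) == [42, 7]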
@@ -2083,39 +1979,51 @@ class Scheduler(
             else:
                 extend_logprob_start_len_per_req = None

-
-
-
-                extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
+            batch_result.extend_input_len_per_req = extend_input_len_per_req
+            batch_result.extend_logprob_start_len_per_req = (
+                extend_logprob_start_len_per_req
             )
+            return batch_result
         else:  # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
             ret = EmbeddingBatchResult(embeddings=embeddings)
             return ret

+    def launch_batch_sample_if_needed(
+        self, batch_result: GenerationBatchResult
+    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
+        # TODO(lsyin): make the delayed sample a default behavior after
+        # unifying the forward_batch_generation interface (related to spec V2).
+        if batch_result is None or batch_result.delay_sample_func is None:
+            return
+
+        with self.forward_stream_ctx:
+            self.forward_stream.wait_stream(self.default_stream)
+            _batch_result = batch_result.delay_sample_func()
+            assert _batch_result is batch_result
+            self.future_map.store_to_map(batch_result.future_indices, batch_result)
+            batch_result.copy_to_cpu()
+
     def process_batch_result(
         self,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
-        launch_done: Optional[threading.Event] = None,
     ):
         if batch.forward_mode.is_decode():
-            self.process_batch_result_decode(batch, result
+            self.process_batch_result_decode(batch, result)
             if self.enable_trace:
                 trace_slice_batch("decode loop", batch.reqs)

         elif batch.forward_mode.is_extend():
-            self.process_batch_result_prefill(batch, result
+            self.process_batch_result_prefill(batch, result)
             if self.enable_trace:
                 trace_slice_batch("prefill", batch.reqs)

         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
-
-
-        elif batch.forward_mode.is_dummy_first():
-            self.set_next_batch_sampling_info_done(batch)
+                if result.copy_done is not None:
+                    result.copy_done.synchronize()

         self.maybe_send_health_check_signal()

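launch_batch_sample_if_needed defers sampling behind a callable attached to the batch result, and process_batch_result waits on a completion event before reading host-side outputs. The following sketch captures that control flow with plain Python primitives; threading.Event stands in for the device-specific copy_done event, and the types are simplified assumptions rather than sglang's GenerationBatchResult.

    import threading
    from dataclasses import dataclass, field
    from typing import Callable, List, Optional


    @dataclass
    class BatchResult:
        next_token_ids: Optional[List[int]] = None
        delay_sample_func: Optional[Callable[[], "BatchResult"]] = None
        copy_done: threading.Event = field(default_factory=threading.Event)


    def launch_sample_if_needed(result: BatchResult) -> None:
        if result.delay_sample_func is None:
            return
        sampled = result.delay_sample_func()   # runs the deferred sampling step
        assert sampled is result
        result.copy_done.set()                 # device-to-host copy finished


    def process_result(result: BatchResult) -> List[int]:
        result.copy_done.wait()                # mirrors result.copy_done.synchronize()
        return result.next_token_ids


    res = BatchResult()

    def _sample() -> BatchResult:
        res.next_token_ids = [17, 3]
        return res

    res.delay_sample_func = _sample
    launch_sample_if_needed(res)
    assert process_result(res) == [17, 3]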
@@ -2125,7 +2033,7 @@ class Scheduler(
             # This is used to prevent the health check signal being blocked by long context prefill.
             # However, one minor issue is that this code path does not check the status of detokenizer manager.
             self.return_health_check_ct -= 1
-            self.send_to_tokenizer.
+            self.send_to_tokenizer.send_output(HealthCheckOutput())

     def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
         return self.prepare_mlp_sync_batch_raw(
@@ -2139,6 +2047,7 @@ class Scheduler(
             speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
             require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
             disable_overlap_schedule=self.server_args.disable_overlap_schedule,
+            offload_tags=self.offload_tags,
         )

     @staticmethod
@@ -2153,6 +2062,7 @@ class Scheduler(
         speculative_num_draft_tokens,
         require_mlp_tp_gather: bool,
         disable_overlap_schedule: bool,
+        offload_tags: set[str],
     ):
         # Check if other DP workers have running batches
         if local_batch is None:
@@ -2183,7 +2093,7 @@ class Scheduler(
         )

         tbo_preparer = TboDPAttentionPreparer()
-        if disable_overlap_schedule:
+        if len(offload_tags) == 0 and disable_overlap_schedule:
             group = tp_group.device_group
             device = tp_group.device
         else:
@@ -2325,13 +2235,6 @@ class Scheduler(
             self._add_request_to_queue(req)
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]

-    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
-        if batch.next_batch_sampling_info:
-            if batch.next_batch_sampling_info.grammars is not None:
-                batch.next_batch_sampling_info.update_regex_vocab_mask()
-                self.current_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
-
     def watchdog_thread(self):
         """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
         self.watchdog_last_forward_ct = 0
@@ -2481,12 +2384,10 @@ class Scheduler(
         )

     def get_internal_state(self, recv_req: GetInternalStateReq):
-        ret =
+        ret = vars(get_global_server_args())
         ret["last_gen_throughput"] = self.last_gen_throughput
         ret["memory_usage"] = {
-            "weight": round(
-                self.tp_worker.worker.model_runner.weight_load_mem_usage, 2
-            ),
+            "weight": round(self.tp_worker.model_runner.weight_load_mem_usage, 2),
             "kvcache": round(
                 self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
             ),
@@ -2494,7 +2395,7 @@ class Scheduler(
         }

         ret["memory_usage"]["graph"] = round(
-            self.tp_worker.
+            self.tp_worker.model_runner.graph_mem_usage, 2
         )

         if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
@@ -2510,7 +2411,7 @@ class Scheduler(
         server_args_dict = recv_req.server_args
         args_allow_update = set(
             [
-                "
+                "pp_max_micro_batch_size",
                 "speculative_accept_threshold_single",
                 "speculative_accept_threshold_acc",
             ]
@@ -2521,7 +2422,7 @@ class Scheduler(
                 logging.warning(f"Updating {k} is not supported.")
                 if_success = False
                 break
-            elif k == "
+            elif k == "pp_max_micro_batch_size" and (
                 v > self.max_running_requests // self.pp_size or v < 1
             ):
                 logging.warning(
@@ -2537,11 +2438,11 @@ class Scheduler(
             logger.info(f"{avg_spec_accept_length=}")
             self.cum_spec_accept_length = self.cum_spec_accept_count = 0
         for k, v in server_args_dict.items():
-
-        logger.info(f"Global server args updated! {
+            setattr(get_global_server_args(), k, v)
+        logger.info(f"Global server args updated! {get_global_server_args()=}")
         return SetInternalStateReqOutput(
             updated=True,
-            server_args=
+            server_args=vars(get_global_server_args()),
         )

     def handle_rpc_request(self, recv_req: RpcReqInput):
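set_internal_state only accepts whitelisted keys at runtime, and pp_max_micro_batch_size is additionally range-checked against the running configuration before being applied with setattr. A compact sketch of that allow-list-and-validate pattern is below; SimpleNamespace and the helper names are simplified stand-ins.

    from types import SimpleNamespace

    ARGS_ALLOW_UPDATE = {
        "pp_max_micro_batch_size",
        "speculative_accept_threshold_single",
        "speculative_accept_threshold_acc",
    }


    def apply_updates(server_args, updates: dict, max_running_requests: int, pp_size: int) -> bool:
        for k, v in updates.items():
            if k not in ARGS_ALLOW_UPDATE:
                return False  # reject unknown keys instead of silently ignoring them
            if k == "pp_max_micro_batch_size" and not (1 <= v <= max_running_requests // pp_size):
                return False  # out of the valid range for this deployment
        for k, v in updates.items():
            setattr(server_args, k, v)
        return True


    args = SimpleNamespace(pp_max_micro_batch_size=8)
    assert apply_updates(args, {"pp_max_micro_batch_size": 4}, max_running_requests=64, pp_size=2)
    assert args.pp_max_micro_batch_size == 4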
@@ -2579,7 +2480,7 @@ class Scheduler(
                 if self.enable_hicache_storage:
                     # to release prefetch events associated with the request
                     self.tree_cache.release_aborted_request(req.rid)
-                self.send_to_tokenizer.
+                self.send_to_tokenizer.send_output(AbortReq(rid=req.rid), req)
                 # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
                 if self.disaggregation_mode == DisaggregationMode.DECODE:
                     self.tree_cache.cache_finished_req(req)
@@ -2663,10 +2564,6 @@ class Scheduler(
         result = self.tp_worker.unload_lora_adapter(recv_req)
         return result

-    def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
-        self.send_to_detokenizer.send_pyobj(recv_req)
-        return recv_req
-
     def init_weights_send_group_for_remote_instance(
         self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
     ):
@@ -2745,7 +2642,7 @@ class Scheduler(
     def handle_freeze_gc(self, recv_req: FreezeGCReq):
         """Handle freeze_gc request: freeze scheduler's GC and forward to detokenizer."""
         freeze_gc("Scheduler")
-        self.send_to_detokenizer.
+        self.send_to_detokenizer.send_output(recv_req, recv_req)
         return None


@@ -2767,12 +2664,13 @@ class IdleSleeper:
         for s in sockets:
             self.poller.register(s, zmq.POLLIN)

+        self.empty_cache_interval = envs.SGLANG_EMPTY_CACHE_INTERVAL.get()
+
     def maybe_sleep(self):
         self.poller.poll(1000)
         if (
-
-            and time.time() - self.last_empty_time
-            > global_config.torch_empty_cache_interval
+            self.empty_cache_interval > 0
+            and time.time() - self.last_empty_time > self.empty_cache_interval
         ):
             self.last_empty_time = time.time()
             torch.cuda.empty_cache()
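The IdleSleeper now reads its flush interval from SGLANG_EMPTY_CACHE_INTERVAL instead of the removed global_config: a non-positive interval disables the flush, otherwise torch.cuda.empty_cache() runs at most once per interval while idle. A minimal sketch, assuming the variable is read directly from os.environ rather than through sglang's envs helper:

    import os
    import time


    class TinyIdleSleeper:
        def __init__(self) -> None:
            self.empty_cache_interval = float(os.environ.get("SGLANG_EMPTY_CACHE_INTERVAL", "0"))
            self.last_empty_time = time.time()

        def maybe_empty_cache(self) -> bool:
            if (
                self.empty_cache_interval > 0
                and time.time() - self.last_empty_time > self.empty_cache_interval
            ):
                self.last_empty_time = time.time()
                # torch.cuda.empty_cache() would go here; skipped to stay CPU-only.
                return True
            return False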
@@ -2831,7 +2729,9 @@ def run_scheduler_process(

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
-        set_gpu_proc_affinity(
+        set_gpu_proc_affinity(
+            server_args.pp_size, server_args.tp_size, server_args.nnodes, gpu_id
+        )
    if (numa_node := server_args.numa_node) is not None:
        numa_bind_to_node(numa_node[gpu_id])

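set_gpu_proc_affinity now receives pp_size, tp_size, nnodes, and the gpu_id so it can work out which CPU cores belong to this scheduler process. The sketch below shows the general idea of pinning a process to a per-GPU slice of cores; the even-split rule is illustrative only and does not reproduce the real helper's pp/tp/NUMA-aware partitioning.

    import os


    def bind_process_to_gpu_slice(gpu_id: int, gpus_per_node: int) -> None:
        total_cores = os.cpu_count() or 1
        cores_per_gpu = max(1, total_cores // gpus_per_node)
        start = (gpu_id % gpus_per_node) * cores_per_gpu
        cores = set(range(start, min(start + cores_per_gpu, total_cores)))
        if hasattr(os, "sched_setaffinity"):   # Linux-only API
            os.sched_setaffinity(0, cores)     # pid 0 means "this process"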