sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
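For readers who want to reproduce a per-file comparison like the one rendered below, here is a minimal sketch using only the Python standard library. It assumes both wheel files have already been saved locally (for example fetched with `pip download`); the filenames and the target path are illustrative values taken from this page, not part of any sglang API.

```python
# Minimal sketch: diff one file across two locally downloaded wheels.
# Wheel filenames below are assumptions about local paths; adjust as needed.
import difflib
import zipfile

OLD_WHEEL = "sglang-0.5.3rc0-py3-none-any.whl"
NEW_WHEEL = "sglang-0.5.3rc2-py3-none-any.whl"
TARGET = "sglang/srt/managers/scheduler.py"


def read_member(wheel_path: str, member: str) -> list[str]:
    """Return the member's text as a list of lines, or [] if it is absent."""
    with zipfile.ZipFile(wheel_path) as zf:
        if member not in zf.namelist():
            return []
        return zf.read(member).decode("utf-8").splitlines(keepends=True)


old_lines = read_member(OLD_WHEEL, TARGET)
new_lines = read_member(NEW_WHEEL, TARGET)
for line in difflib.unified_diff(old_lines, new_lines, fromfile=TARGET, tofile=TARGET):
    print(line, end="")
```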
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -44,6 +44,9 @@ from sglang.srt.disaggregation.decode import (
     DecodeTransferQueue,
     SchedulerDisaggregationDecodeMixin,
 )
+from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
+    DecodeKVCacheOffloadManager,
+)
 from sglang.srt.disaggregation.prefill import (
     PrefillBootstrapQueue,
     SchedulerDisaggregationPrefillMixin,
@@ -57,11 +60,6 @@ from sglang.srt.disaggregation.utils import (
 )
 from sglang.srt.distributed import get_pp_group, get_world_group
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
-from sglang.srt.hf_transformers_utils import (
-    get_processor,
-    get_tokenizer,
-    get_tokenizer_from_processor,
-)
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe import initialize_moe_config
@@ -72,8 +70,10 @@ from sglang.srt.managers.io_struct import (
     ClearHiCacheReqInput,
     ClearHiCacheReqOutput,
     CloseSessionReqInput,
+    DestroyWeightsUpdateGroupReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
+    ExpertDistributionReqType,
     FlushCacheReqInput,
     FlushCacheReqOutput,
     FreezeGCReq,
@@ -116,6 +116,7 @@ from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     MultimodalInputs,
     Req,
+    RequestStage,
     ScheduleBatch,
     global_server_args_dict,
 )
@@ -140,23 +141,25 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
-from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
+from sglang.srt.managers.utils import validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
-from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatchOutput,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.tracing.trace import (
     process_tracing_init,
-    trace_event,
     trace_set_proc_propagate_context,
     trace_set_thread_info,
-    trace_slice,
+    trace_slice_batch,
     trace_slice_end,
     trace_slice_start,
 )
@@ -170,8 +173,8 @@ from sglang.srt.utils import (
     freeze_gc,
     get_available_gpu_memory,
     get_bool_env_var,
+    get_int_env_var,
     get_zmq_socket,
-    is_cpu,
     kill_itself_when_parent_died,
     numa_bind_to_node,
     point_to_point_pyobj,
@@ -182,6 +185,11 @@ from sglang.srt.utils import (
     set_random_seed,
     suppress_other_loggers,
 )
+from sglang.srt.utils.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -190,24 +198,59 @@ logger = logging.getLogger(__name__)
 TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
 GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
 
-_is_cpu = is_cpu()
-
 
 @dataclass
 class GenerationBatchResult:
     logits_output: Optional[LogitsProcessorOutput]
-    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
+    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors]
     next_token_ids: Optional[List[int]]
+    can_run_cuda_graph: bool
+
+    # For output processing
     extend_input_len_per_req: List[int]
     extend_logprob_start_len_per_req: List[int]
-    bid: int
-    can_run_cuda_graph: bool
+
+    @classmethod
+    def from_forward_batch_output(
+        cls,
+        forward_batch_output: ForwardBatchOutput,
+        extend_input_len_per_req: List[int],
+        extend_logprob_start_len_per_req: List[int],
+    ):
+        # TODO(lsyin): remove this workaround logic and try to unify output classes
+
+        return cls(
+            logits_output=forward_batch_output.logits_output,
+            pp_hidden_states_proxy_tensors=forward_batch_output.pp_proxy_tensors,
+            next_token_ids=forward_batch_output.next_token_ids,
+            extend_input_len_per_req=extend_input_len_per_req,
+            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
+            can_run_cuda_graph=forward_batch_output.can_run_cuda_graph,
+        )
+
+    @classmethod
+    def from_pp_proxy(
+        cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
+    ):
+        # TODO(lsyin): also simplify this logic
+        # Current PP implementation in scheduler is not compatible with ForwardBatchOutput
+        # Maybe introduce a ProxyBatchOutput for PP and the original ForwardBatchOutput for TP
+        proxy_dict = next_pp_outputs.tensors
+        return cls(
+            logits_output=logits_output,
+            pp_hidden_states_proxy_tensors=None,
+            next_token_ids=next_pp_outputs["next_token_ids"],
+            extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
+            extend_logprob_start_len_per_req=proxy_dict.get(
+                "extend_logprob_start_len_per_req", None
+            ),
+            can_run_cuda_graph=can_run_cuda_graph,
+        )
 
 
 @dataclass
 class EmbeddingBatchResult:
     embeddings: torch.Tensor
-    bid: int
 
 
 class Scheduler(
@@ -229,7 +272,6 @@ class Scheduler(
         moe_ep_rank: int,
         pp_rank: int,
         dp_rank: Optional[int],
-        dp_balance_meta: Optional[DPBalanceMeta] = None,
     ):
         # Parse args
         self.server_args = server_args
@@ -242,6 +284,13 @@ class Scheduler(
         self.pp_size = server_args.pp_size
         self.dp_size = server_args.dp_size
         self.schedule_policy = server_args.schedule_policy
+        self.enable_priority_scheduling = server_args.enable_priority_scheduling
+        self.schedule_low_priority_values_first = (
+            server_args.schedule_low_priority_values_first
+        )
+        self.priority_scheduling_preemption_threshold = (
+            server_args.priority_scheduling_preemption_threshold
+        )
         self.enable_lora = server_args.enable_lora
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = not server_args.disable_overlap_schedule
@@ -250,7 +299,10 @@ class Scheduler(
         self.enable_metrics_for_all_schedulers = (
             server_args.enable_metrics_for_all_schedulers
         )
-        self.enable_kv_cache_events = server_args.kv_events_config and tp_rank == 0
+        self.enable_kv_cache_events = bool(
+            server_args.kv_events_config and tp_rank == 0
+        )
+        self.enable_trace = server_args.enable_trace
         self.stream_interval = server_args.stream_interval
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
@@ -376,9 +428,27 @@ class Scheduler(
                 target_worker=self.tp_worker,
                 dp_rank=dp_rank,
             )
+        elif self.spec_algorithm.is_ngram():
+            from sglang.srt.speculative.ngram_worker import NGRAMWorker
+
+            self.draft_worker = NGRAMWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
         else:
             self.draft_worker = None
 
+        # Dispatch the model worker
+        if self.spec_algorithm.is_none():
+            self.model_worker = self.tp_worker
+        else:
+            self.model_worker = self.draft_worker
+
         # Get token and memory info from the model worker
         (
             self.max_total_num_tokens,
@@ -486,7 +556,12 @@ class Scheduler(
             self.schedule_policy,
             self.tree_cache,
             self.enable_hierarchical_cache,
+            self.enable_priority_scheduling,
+            self.schedule_low_priority_values_first,
         )
+        # Enable preemption for priority scheduling.
+        self.try_preemption = self.enable_priority_scheduling
+
         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
@@ -527,8 +602,9 @@ class Scheduler(
 
         # Init metrics stats
         self.init_metrics(tp_rank, pp_rank, dp_rank)
-
-        self.init_kv_events(server_args.kv_events_config)
+
+        if self.enable_kv_cache_events:
+            self.init_kv_events(server_args.kv_events_config)
 
         # Init disaggregation
         self.disaggregation_mode = DisaggregationMode(
@@ -539,6 +615,9 @@ class Scheduler(
         if get_bool_env_var("SGLANG_GC_LOG"):
             configure_gc_logger()
 
+        # Init prefill kv split size when deterministic inference is enabled with various attention backends
+        self.init_deterministic_inference_config()
+
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
             [
@@ -553,6 +632,7 @@ class Scheduler(
                 (CloseSessionReqInput, self.close_session),
                 (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                 (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
+                (DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group),
                 (
                     InitWeightsSendGroupForRemoteInstanceReqInput,
                     self.init_weights_send_group_for_remote_instance,
@@ -583,6 +663,23 @@ class Scheduler(
             ]
         )
 
+    def init_deterministic_inference_config(self):
+        """Initialize deterministic inference configuration for different attention backends."""
+        if not self.server_args.enable_deterministic_inference:
+            self.truncation_align_size = None
+            return
+
+        backend_sizes = {
+            "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
+            "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
+        }
+        env_var, default_size = backend_sizes.get(
+            self.server_args.attention_backend, (None, None)
+        )
+        self.truncation_align_size = (
+            get_int_env_var(env_var, default_size) if env_var else None
+        )
+
     def init_tokenizer(self):
         server_args = self.server_args
         self.is_generation = self.model_config.is_generation
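Note: with deterministic inference enabled, the block above resolves a per-backend prefill split size from an environment variable, falling back to 4096. Below is a standalone sketch of that resolution with an override; the variable names and default come from the diff, while the `get_int_env_var` helper here is a simplified stand-in for sglang's utility of the same name, shown only for illustration.

```python
# Sketch: resolve the prefill truncation/split size the same way as above.
import os


def get_int_env_var(name: str, default: int) -> int:
    # Simplified stand-in for sglang's helper: read an int env var with a default.
    value = os.environ.get(name)
    return int(value) if value not in (None, "") else default


backend_sizes = {
    "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
    "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
}

attention_backend = "triton"  # illustrative value
os.environ["SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE"] = "8192"  # operator override
env_var, default_size = backend_sizes.get(attention_backend, (None, None))
truncation_align_size = get_int_env_var(env_var, default_size) if env_var else None
print(truncation_align_size)  # 8192; without the override it would be 4096
```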
@@ -654,6 +751,7 @@ class Scheduler(
                     else self.tp_cpu_group
                 ),
                 page_size=self.page_size,
+                eviction_policy=server_args.radix_eviction_policy,
                 hicache_ratio=server_args.hicache_ratio,
                 hicache_size=server_args.hicache_size,
                 hicache_write_policy=server_args.hicache_write_policy,
@@ -664,6 +762,7 @@ class Scheduler(
                 hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
                 model_name=server_args.served_model_name,
                 storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
+                is_eagle=self.spec_algorithm.is_eagle(),
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
                 self.tree_cache.cache_controller.layer_done_counter
@@ -679,19 +778,6 @@ class Scheduler(
                 page_size=self.page_size,
                 disable=server_args.disable_radix_cache,
             )
-        elif self.enable_lora:
-            assert (
-                not self.enable_hierarchical_cache
-            ), "LoRA radix cache doesn't support hierarchical cache"
-            assert (
-                self.schedule_policy == "fcfs"
-            ), "LoRA radix cache only supports FCFS policy"
-            self.tree_cache = LoRARadixCache(
-                req_to_token_pool=self.req_to_token_pool,
-                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-                page_size=self.page_size,
-                disable=server_args.disable_radix_cache,
-            )
         elif server_args.enable_lmcache:
             from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
                 LMCRadixCache,
@@ -706,6 +792,7 @@ class Scheduler(
                 tp_size=self.tp_size,
                 rank=self.tp_rank,
                 tp_group=self.tp_group,
+                eviction_policy=server_args.radix_eviction_policy,
             )
         else:
             self.tree_cache = RadixCache(
@@ -714,16 +801,36 @@ class Scheduler(
                 page_size=self.page_size,
                 disable=server_args.disable_radix_cache,
                 enable_kv_cache_events=self.enable_kv_cache_events,
+                eviction_policy=server_args.radix_eviction_policy,
+                is_eagle=self.spec_algorithm.is_eagle(),
             )
 
+        if (
+            server_args.disaggregation_mode == "decode"
+            and server_args.disaggregation_decode_enable_offload_kvcache
+        ):
+            self.decode_offload_manager = DecodeKVCacheOffloadManager(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                tp_group=(
+                    self.attn_tp_cpu_group
+                    if self.server_args.enable_dp_attention
+                    else self.tp_cpu_group
+                ),
+                tree_cache=self.tree_cache,
+                server_args=self.server_args,
+            )
+        else:
+            self.decode_offload_manager = None
+
         self.decode_mem_cache_buf_multiplier = (
             1
             if self.spec_algorithm.is_none()
             else (
                 server_args.speculative_num_draft_tokens
                 + (
-                    server_args.speculative_eagle_topk
-                    * server_args.speculative_num_steps
+                    (server_args.speculative_eagle_topk or 1)
+                    * (server_args.speculative_num_steps or 1)
                 )
             )
         )
@@ -746,7 +853,7 @@ class Scheduler(
             self.disagg_metadata_buffers = MetadataBuffers(
                 buffer_size,
                 hidden_size=self.model_config.hf_text_config.hidden_size,
-                dtype=self.model_config.dtype,
+                hidden_states_dtype=self.model_config.dtype,
                 custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
             )
 
@@ -766,7 +873,7 @@ class Scheduler(
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 draft_token_to_kv_pool=(
                     None
-                    if self.draft_worker is None
+                    if self.draft_worker is None or self.spec_algorithm.is_ngram()
                     else self.draft_worker.model_runner.token_to_kv_pool
                 ),
                 req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -795,7 +902,7 @@ class Scheduler(
             self.disagg_metadata_buffers = MetadataBuffers(
                 buffer_size,
                 hidden_size=self.model_config.hf_text_config.hidden_size,
-                dtype=self.model_config.dtype,
+                hidden_states_dtype=self.model_config.dtype,
                 custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
             )
 
@@ -803,7 +910,7 @@ class Scheduler(
                 token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                 draft_token_to_kv_pool=(
                     None
-                    if self.draft_worker is None
+                    if self.draft_worker is None or self.spec_algorithm.is_ngram()
                     else self.draft_worker.model_runner.token_to_kv_pool
                 ),
                 req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -838,10 +945,6 @@ class Scheduler(
             batch = self.get_next_batch_to_run()
             self.cur_batch = batch
 
-            if batch:
-                for req in batch.reqs:
-                    trace_event("schedule", req.rid)
-
             if batch:
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
@@ -863,10 +966,6 @@ class Scheduler(
             batch = self.get_next_batch_to_run()
             self.cur_batch = batch
 
-            if batch:
-                for req in batch.reqs:
-                    trace_event("schedule", req.rid)
-
             if batch:
                 batch.launch_done = threading.Event()
                 result = self.run_batch(batch)
@@ -906,7 +1005,6 @@ class Scheduler(
         self.running_mbs = [
             ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
         ]
-        bids = [None] * self.pp_size
         pp_outputs: Optional[PPProxyTensors] = None
         while True:
             server_is_idle = True
@@ -927,10 +1025,7 @@ class Scheduler(
                 # (last rank) send the outputs to the next step
                 if self.pp_group.is_last_rank:
                     if self.cur_batch:
-                        next_token_ids, bids[mb_id] = (
-                            result.next_token_ids,
-                            result.bid,
-                        )
+                        next_token_ids = result.next_token_ids
                         if self.cur_batch.return_logprob:
                             pp_outputs = PPProxyTensors(
                                 {
@@ -978,17 +1073,10 @@ class Scheduler(
                         logits_output = LogitsProcessorOutput(**logits_output_args)
                     else:
                         logits_output = None
-                    output_result = GenerationBatchResult(
+
+                    output_result = GenerationBatchResult.from_pp_proxy(
                         logits_output=logits_output,
-                        pp_hidden_states_proxy_tensors=None,
-                        next_token_ids=next_pp_outputs["next_token_ids"],
-                        extend_input_len_per_req=next_pp_outputs.tensors.get(
-                            "extend_input_len_per_req", None
-                        ),
-                        extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
-                            "extend_logprob_start_len_per_req", None
-                        ),
-                        bid=bids[next_mb_id],
+                        next_pp_outputs=next_pp_outputs,
                         can_run_cuda_graph=result.can_run_cuda_graph,
                     )
                     self.process_batch_result(mbs[next_mb_id], output_result)
@@ -996,8 +1084,6 @@ class Scheduler(
 
                 # (not last rank)
                 if not self.pp_group.is_last_rank:
-                    if self.cur_batch:
-                        bids[mb_id] = result.bid
                     # carry the outputs to the next stage
                     # send the outputs from the last round to let the next stage worker run post processing
                     if pp_outputs:
@@ -1019,8 +1105,10 @@ class Scheduler(
 
                     # send out proxy tensors to the next stage
                     if self.cur_batch:
+                        # FIXME(lsyin): remove this assert
+                        assert result.pp_hidden_states_proxy_tensors.tensors is not None
                         self.pp_group.send_tensor_dict(
-                            result.pp_hidden_states_proxy_tensors,
+                            result.pp_hidden_states_proxy_tensors.tensors,
                             all_gather_group=self.attn_tp_group,
                         )
 
@@ -1131,10 +1219,13 @@ class Scheduler(
                     src=self.tp_group.ranks[0],
                 )
 
-        for req in recv_reqs:
-            if isinstance(req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)):
-                trace_set_proc_propagate_context(req.rid, req.trace_context)
-                trace_slice_start("", req.rid, anonymous=True)
+        if self.enable_trace:
+            for req in recv_reqs:
+                if isinstance(
+                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                ):
+                    trace_set_proc_propagate_context(req.rid, req.trace_context)
+                    trace_slice_start("", req.rid, anonymous=True)
 
         return recv_reqs
 
@@ -1149,20 +1240,6 @@ class Scheduler(
                 self.return_health_check_ct += 1
                 continue
 
-            # If it is a work request, accept or reject the request based on the request queue size.
-            if is_work_request(recv_req):
-                if len(self.waiting_queue) + 1 > self.max_queued_requests:
-                    abort_req = AbortReq(
-                        recv_req.rid,
-                        finished_reason={
-                            "type": "abort",
-                            "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
-                            "message": "The request queue is full.",
-                        },
-                    )
-                    self.send_to_tokenizer.send_pyobj(abort_req)
-                    continue
-
             # If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
             if isinstance(recv_req, MultiTokenizerWrapper):
                 worker_id = recv_req.worker_id
@@ -1195,8 +1272,6 @@ class Scheduler(
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
-        self.maybe_update_dp_balance_data(recv_req)
-
         # Create a new request
         if (
             recv_req.session_params is None
@@ -1230,8 +1305,13 @@ class Scheduler(
                 bootstrap_host=recv_req.bootstrap_host,
                 bootstrap_port=recv_req.bootstrap_port,
                 bootstrap_room=recv_req.bootstrap_room,
+                disagg_mode=self.disaggregation_mode,
                 data_parallel_rank=recv_req.data_parallel_rank,
                 vocab_size=self.model_config.vocab_size,
+                priority=recv_req.priority,
+                metrics_collector=(
+                    self.metrics_collector if self.enable_metrics else None
+                ),
             )
             req.tokenizer = self.tokenizer
 
@@ -1352,7 +1432,6 @@ class Scheduler(
             req.set_finish_with_abort(error_msg)
 
         if add_to_grammar_queue:
-            req.queue_time_start = time.perf_counter()
            self.grammar_queue.append(req)
         else:
             self._add_request_to_queue(req)
@@ -1368,20 +1447,6 @@ class Scheduler(
         for tokenized_req in recv_req:
             self.handle_generate_request(tokenized_req)
 
-    def _add_request_to_queue(self, req: Req):
-        req.queue_time_start = time.perf_counter()
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self._prefetch_kvcache(req)
-            self.disagg_prefill_bootstrap_queue.add(
-                req, self.model_config.num_key_value_heads
-            )
-        elif self.disaggregation_mode == DisaggregationMode.DECODE:
-            self.disagg_decode_prealloc_queue.add(req)
-        else:
-            self._prefetch_kvcache(req)
-            self.waiting_queue.append(req)
-        trace_slice_end("process req", req.rid, auto_next_anon=True)
-
     def _prefetch_kvcache(self, req: Req):
         if self.enable_hicache_storage:
             req.init_next_round_input(self.tree_cache)
@@ -1395,16 +1460,87 @@ class Scheduler(
                 req.rid, req.last_host_node, new_input_tokens, last_hash
             )
 
-    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self.disagg_prefill_bootstrap_queue.extend(
-                reqs, self.model_config.num_key_value_heads
+    def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
+        if self.disaggregation_mode == DisaggregationMode.NULL:
+            self._set_or_validate_priority(req)
+            if self._abort_on_queued_limit(req):
+                return
+            self._prefetch_kvcache(req)
+            self.waiting_queue.append(req)
+            req.time_stats.wait_queue_entry_time = time.perf_counter()
+            trace_slice_end("process req", req.rid, auto_next_anon=True)
+        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self._prefetch_kvcache(req)
+            self.disagg_prefill_bootstrap_queue.add(
+                req, self.model_config.num_key_value_heads
             )
+            req.time_stats.prefill_bootstrap_queue_entry_time = time.perf_counter()
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
-            # If this is a decode server, we put the request to the decode pending prealloc queue
-            self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
+            self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted)
+            if not is_retracted:
+                req.time_stats.decode_prealloc_queue_entry_time = time.perf_counter()
         else:
-            self.waiting_queue.extend(reqs)
+            raise ValueError(f"Invalid {self.disaggregation_mode=}")
+
+    def _set_or_validate_priority(self, req: Req):
+        """Set the default priority value, or abort the request based on the priority scheduling mode."""
+        if self.enable_priority_scheduling and req.priority is None:
+            if self.schedule_low_priority_values_first:
+                req.priority = sys.maxsize
+            else:
+                req.priority = -sys.maxsize - 1
+        elif not self.enable_priority_scheduling and req.priority is not None:
+            abort_req = AbortReq(
+                finished_reason={
+                    "type": "abort",
+                    "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
+                    "message": "Using priority is disabled for this server. Please send a new request without a priority.",
+                },
+                rid=req.rid,
+            )
+            self.send_to_tokenizer.send_pyobj(abort_req)
+
+    def _abort_on_queued_limit(self, recv_req: Req) -> bool:
+        """Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted."""
+        if (
+            self.max_queued_requests is None
+            or len(self.waiting_queue) + 1 <= self.max_queued_requests
+        ):
+            return False
+
+        # Reject the incoming request by default.
+        req_to_abort = recv_req
+        message = "The request queue is full."
+        if self.enable_priority_scheduling:
+            # With priority scheduling, consider aboritng an existing request based on the priority.
+            # direction = 1 => smaller number = higher priority; -1 => larger number = higher priority.
+            # max(...) + (direction * priority, queue_time_start) picks the least-preferred request.
+            # Tie: later queue_time_start (newer) is evicted first. Preempt only if strictly better.
+            direction = 1 if self.schedule_low_priority_values_first else -1
+            key_fn = lambda item: (
+                direction * item[1].priority,
+                item[1].time_stats.wait_queue_entry_time,
+            )
+            idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn)
+            abort_existing_req = (
+                direction * recv_req.priority < direction * candidate_req.priority
+            )
+            if abort_existing_req:
+                self.waiting_queue.pop(idx)
+                req_to_abort = candidate_req
+                message = "The request is aborted by a higher priority request."
+
+        self.send_to_tokenizer.send_pyobj(
+            AbortReq(
+                finished_reason={
+                    "type": "abort",
+                    "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
+                    "message": message,
+                },
+                rid=req_to_abort.rid,
+            )
+        )
+        return req_to_abort.rid == recv_req.rid
 
     def handle_embedding_request(
         self,
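Note: the `_abort_on_queued_limit` helper added above evicts the least-preferred queued request when the waiting queue is full and priority scheduling is enabled. The selection rule can be subtle, so here is a self-contained toy sketch of the same logic; the `Item` class and field names are illustrative stand-ins, not the scheduler's own types.

```python
# Toy illustration of the eviction choice used when the waiting queue is full.
from dataclasses import dataclass


@dataclass
class Item:
    rid: str
    priority: int
    wait_queue_entry_time: float


def pick_eviction_candidate(queue, low_values_first: bool) -> Item:
    # direction = 1 => smaller number = higher priority; -1 => larger number = higher priority.
    direction = 1 if low_values_first else -1
    # max() over (direction * priority, entry_time) returns the least-preferred,
    # most recently enqueued request.
    key_fn = lambda item: (direction * item.priority, item.wait_queue_entry_time)
    return max(queue, key=key_fn)


queue = [
    Item("a", priority=0, wait_queue_entry_time=1.0),
    Item("b", priority=5, wait_queue_entry_time=2.0),
    Item("c", priority=5, wait_queue_entry_time=3.0),
]
# With low priority values scheduled first, "c" (worst priority, newest) is the candidate.
assert pick_eviction_candidate(queue, low_values_first=True).rid == "c"
```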
@@ -1416,6 +1552,7 @@ class Scheduler(
             recv_req.input_ids,
             recv_req.sampling_params,
             token_type_ids=recv_req.token_type_ids,
+            priority=recv_req.priority,
         )
         req.tokenizer = self.tokenizer
 
@@ -1660,7 +1797,6 @@ class Scheduler(
 
         # Handle DP attention
         if need_dp_attn_preparation:
-            self.maybe_handle_dp_balance_data()
             ret = self.prepare_mlp_sync_batch(ret)
 
         return ret
@@ -1676,6 +1812,10 @@ class Scheduler(
         if self.grammar_queue:
             self.move_ready_grammar_requests()
 
+        if self.try_preemption:
+            # Reset batch_is_full to try preemption with a prefill adder.
+            self.running_batch.batch_is_full = False
+
         # Handle the cases where prefill is not allowed
         if (
             self.running_batch.batch_is_full or len(self.waiting_queue) == 0
@@ -1688,7 +1828,11 @@ class Scheduler(
         # as the space for the chunked request has just been released.
         # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
         # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
-        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
+        if (
+            self.get_num_allocatable_reqs(running_bs) <= 0
+            and not self.chunked_req
+            and not self.try_preemption
+        ):
             self.running_batch.batch_is_full = True
             return None
 
@@ -1708,6 +1852,7 @@ class Scheduler(
             self.max_prefill_tokens,
             self.chunked_prefill_size,
             running_bs if self.is_mixed_chunk else 0,
+            self.priority_scheduling_preemption_threshold,
         )
 
         if self.chunked_req is not None:
@@ -1728,15 +1873,19 @@ class Scheduler(
                 self.running_batch.batch_is_full = True
                 break
 
+            running_bs = len(self.running_batch.reqs) - len(adder.preempt_list)
             if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                 self.running_batch.batch_is_full = True
-                break
-
             if self.disaggregation_mode == DisaggregationMode.PREFILL:
                 # In prefill mode, prealloc queue and transfer queue can also take memory,
                 # so we need to check if the available size for the actual available size.
                 if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                     self.running_batch.batch_is_full = True
+
+            if self.running_batch.batch_is_full:
+                if not self.try_preemption:
+                    break
+                if not adder.preempt_to_schedule(req, self.server_args):
                     break
 
             if self.enable_hicache_storage:
@@ -1746,7 +1895,11 @@ class Scheduler(
                     continue
 
             req.init_next_round_input(self.tree_cache)
-            res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
+            res = adder.add_one_req(
+                req,
+                has_chunked_req=(self.chunked_req is not None),
+                truncation_align_size=self.truncation_align_size,
+            )
 
             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
@@ -1767,11 +1920,14 @@ class Scheduler(
         if self.enable_metrics:
             # only record queue time when enable_metrics is True to avoid overhead
             for req in can_run_list:
-                req.queue_time_end = time.perf_counter()
+                req.add_latency(RequestStage.PREFILL_WAITING)
 
         self.waiting_queue = [
             x for x in self.waiting_queue if x not in set(can_run_list)
         ]
+        if adder.preempt_list:
+            for req in adder.preempt_list:
+                self._add_request_to_queue(req)
 
         if adder.new_chunked_req is not None:
             assert self.chunked_req is None
@@ -1782,7 +1938,16 @@ class Scheduler(
 
         # Print stats
         if self.current_scheduler_metrics_enabled():
-            self.log_prefill_stats(adder, can_run_list, running_bs)
+            self.log_prefill_stats(adder, can_run_list, running_bs, 0)
+
+        for req in can_run_list:
+            if req.time_stats.forward_entry_time == 0:
+                # Avoid update chunked request many times
+                req.time_stats.forward_entry_time = time.perf_counter()
+                if self.enable_metrics:
+                    self.metrics_collector.observe_queue_time(
+                        req.time_stats.get_queueing_time(),
+                    )
 
         # Create a new batch
         new_batch = ScheduleBatch.init_new(
@@ -1837,19 +2002,25 @@ class Scheduler(
             TEST_RETRACT and batch.batch_size() > 10
         ):
             old_ratio = self.new_token_ratio
-
-            retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
-            num_retracted_reqs = len(retracted_reqs)
+            retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
+                self.server_args
+            )
+            self.num_retracted_reqs = len(retracted_reqs)
             self.new_token_ratio = new_token_ratio
+            for req in reqs_to_abort:
+                self.send_to_tokenizer.send_pyobj(
+                    AbortReq(abort_reason=req.to_abort_message, rid=req.rid)
+                )
 
             logger.info(
                 "KV cache pool is full. Retract requests. "
-                f"#retracted_reqs: {num_retracted_reqs}, "
-                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
+                f"#retracted_reqs: {len(retracted_reqs)}, "
+                f"#aborted_retracted_reqs: {len(reqs_to_abort)}, "
+                f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
             )
 
-            self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
-
+            for req in retracted_reqs:
+                self._add_request_to_queue(req, is_retracted=True)
         else:
             self.new_token_ratio = max(
                 self.new_token_ratio - self.new_token_ratio_decay,
@@ -1877,33 +2048,25 @@ class Scheduler(
 
         # Run forward
         if self.is_generation:
+
+            batch_or_worker_batch = batch
+
            if self.spec_algorithm.is_none():
-                model_worker_batch = batch.get_model_worker_batch()
+                # FIXME(lsyin): remove this if and finally unify the abstraction
+                batch_or_worker_batch = batch.get_model_worker_batch()
 
-                # update the consumer index of hicache to the running batch
-                self.tp_worker.set_hicache_consumer(
-                    model_worker_batch.hicache_consumer_index
-                )
-
-                logits_output, next_token_ids, can_run_cuda_graph = (
-                    self.tp_worker.forward_batch_generation(model_worker_batch)
-                )
-                bid = model_worker_batch.bid
-            else:
-                (
-                    logits_output,
-                    next_token_ids,
-                    bid,
-                    num_accepted_tokens,
-                    can_run_cuda_graph,
-                ) = self.draft_worker.forward_batch_speculative_generation(batch)
-                bs = batch.batch_size()
-                self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
-                self.spec_num_total_forward_ct += bs
-                self.num_generated_tokens += num_accepted_tokens
-
-            if self.pp_group.is_last_rank:
-                batch.output_ids = next_token_ids
+            forward_batch_output = self.model_worker.forward_batch_generation(
+                batch_or_worker_batch
+            )
+
+            if not self.spec_algorithm.is_none():
+                # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
+                self.udpate_spec_metrics(
+                    batch.batch_size(), forward_batch_output.num_accepted_tokens
+                )
+
+            # update batch's output ids
+            batch.output_ids = forward_batch_output.next_token_ids
 
             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that
@@ -1912,6 +2075,7 @@ class Scheduler(
                 extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
             else:
                 extend_input_len_per_req = None
+
             if batch.return_logprob:
                 extend_logprob_start_len_per_req = [
                     req.extend_logprob_start_len for req in batch.reqs
@@ -1919,25 +2083,15 @@ class Scheduler(
             else:
                 extend_logprob_start_len_per_req = None
 
-            ret = GenerationBatchResult(
-                logits_output=logits_output if self.pp_group.is_last_rank else None,
-                pp_hidden_states_proxy_tensors=(
-                    pp_hidden_states_proxy_tensors
-                    if not self.pp_group.is_last_rank
-                    else None
-                ),
-                next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
+            return GenerationBatchResult.from_forward_batch_output(
+                forward_batch_output=forward_batch_output,
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
-                bid=bid,
-                can_run_cuda_graph=can_run_cuda_graph,
             )
         else:  # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            ret = EmbeddingBatchResult(
-                embeddings=embeddings, bid=model_worker_batch.bid
-            )
+            ret = EmbeddingBatchResult(embeddings=embeddings)
         return ret
 
     def process_batch_result(
@@ -1948,23 +2102,14 @@ class Scheduler(
     ):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result, launch_done)
-            for req in batch.reqs:
-                trace_slice(
-                    "decode loop",
-                    req.rid,
-                    auto_next_anon=not req.finished(),
-                    thread_finish_flag=req.finished(),
-                )
+            if self.enable_trace:
+                trace_slice_batch("decode loop", batch.reqs)
 
         elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result, launch_done)
-            for req in batch.reqs:
-                trace_slice(
-                    "prefill",
-                    req.rid,
-                    auto_next_anon=not req.finished(),
-                    thread_finish_flag=req.finished(),
-                )
+            if self.enable_trace:
+                trace_slice_batch("prefill", batch.reqs)
+
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
                 self.tp_worker.resolve_last_batch_result(launch_done)
@@ -2123,12 +2268,13 @@ class Scheduler(
                 if req.finished():  # It is aborted by AbortReq
                     num_ready_reqs += 1
                     continue
+
                 req.grammar = req.grammar.result(timeout=0.03)
                 self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                 if req.grammar is INVALID_GRAMMAR_OBJ:
-                    req.set_finish_with_abort(
-                        f"Invalid grammar request: {req.grammar_key=}"
-                    )
+                    error_msg = f"Invalid grammar request: {req.grammar_key=}"
+                    req.set_finish_with_abort(error_msg)
+
                 num_ready_reqs += 1
             except futures._base.TimeoutError:
                 req.grammar_wait_ct += 1
@@ -2160,9 +2306,8 @@ class Scheduler(
                 req.grammar = req.grammar.result()
                 self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                 if req.grammar is INVALID_GRAMMAR_OBJ:
-                    req.set_finish_with_abort(
-                        f"Invalid grammar request: {req.grammar_key=}"
-                    )
+                    error_msg = f"Invalid grammar request: {req.grammar_key=}"
+                    req.set_finish_with_abort(error_msg)
         else:
             num_ready_reqs_max = num_ready_reqs
             num_timeout_reqs_max = num_timeout_reqs
@@ -2170,12 +2315,14 @@ class Scheduler(
         for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
             req = self.grammar_queue[i]
             req.grammar.cancel()
+            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
             error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
             req.set_finish_with_abort(error_msg)
-
+
         num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max
 
-        self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
+        for req in self.grammar_queue[:num_ready_reqs]:
+            self._add_request_to_queue(req)
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]
 
     def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
@@ -2267,9 +2414,8 @@ class Scheduler(
         self.req_to_token_pool.clear()
         self.token_to_kv_pool_allocator.clear()
 
-        if not self.spec_algorithm.is_none():
-            self.draft_worker.model_runner.req_to_token_pool.clear()
-            self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()
+        if self.draft_worker:
+            self.draft_worker.clear_cache_pool()
 
         self.num_generated_tokens = 0
         self.forward_ct_decode = 0
@@ -2433,7 +2579,7 @@ class Scheduler(
             if self.enable_hicache_storage:
                 # to release prefetch events associated with the request
                 self.tree_cache.release_aborted_request(req.rid)
-            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
+            self.send_to_tokenizer.send_pyobj(AbortReq(rid=req.rid))
             # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
             if self.disaggregation_mode == DisaggregationMode.DECODE:
                 self.tree_cache.cache_finished_req(req)
@@ -2454,31 +2600,31 @@ class Scheduler(
         # Delete requests not in the waiting queue when PD disaggregation is enabled
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             # Abort requests that have not yet been bootstrapped
-            for i, req in enumerate(self.disagg_prefill_bootstrap_queue.queue):
-                logger.debug(f"Abort bootstrap queue request. {req.rid=}")
+            for req in self.disagg_prefill_bootstrap_queue.queue:
                 if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort bootstrap queue request. {req.rid=}")
                     if hasattr(req.disagg_kv_sender, "abort"):
                         req.disagg_kv_sender.abort()
 
             # Abort in-flight requests
-            for i, req in enumerate(self.disagg_prefill_inflight_queue):
-                logger.debug(f"Abort inflight queue request. {req.rid=}")
+            for req in self.disagg_prefill_inflight_queue:
                 if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort inflight queue request. {req.rid=}")
                     if hasattr(req.disagg_kv_sender, "abort"):
                         req.disagg_kv_sender.abort()
 
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             # Abort requests that have not yet finished preallocation
-            for i, decode_req in enumerate(self.disagg_decode_prealloc_queue.queue):
-                logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
+            for decode_req in self.disagg_decode_prealloc_queue.queue:
                 if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
                     if hasattr(decode_req.kv_receiver, "abort"):
                         decode_req.kv_receiver.abort()
 
             # Abort requests waiting for kvcache to release tree cache
-            for i, decode_req in enumerate(self.disagg_decode_transfer_queue.queue):
-                logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
+            for decode_req in self.disagg_decode_transfer_queue.queue:
                 if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+                    logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
                    if hasattr(decode_req.kv_receiver, "abort"):
                         decode_req.kv_receiver.abort()
 
@@ -2545,11 +2691,12 @@ class Scheduler(
         return SlowDownReqOutput()
 
     def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
-        if recv_req == ExpertDistributionReq.START_RECORD:
+        action = recv_req.action
+        if action == ExpertDistributionReqType.START_RECORD:
             get_global_expert_distribution_recorder().start_record()
-        elif recv_req == ExpertDistributionReq.STOP_RECORD:
+        elif action == ExpertDistributionReqType.STOP_RECORD:
             get_global_expert_distribution_recorder().stop_record()
-        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
+        elif action == ExpertDistributionReqType.DUMP_RECORD:
             get_global_expert_distribution_recorder().dump_record()
         else:
             raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
@@ -2632,7 +2779,8 @@ class IdleSleeper:
 
 
 def is_health_check_generate_req(recv_req):
-    return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
+    rid = getattr(recv_req, "rid", None)
+    return rid is not None and rid.startswith("HEALTH_CHECK")
 
 
 def is_work_request(recv_req):
@@ -2656,19 +2804,12 @@ def run_scheduler_process(
     pp_rank: int,
     dp_rank: Optional[int],
     pipe_writer,
-    balance_meta: Optional[DPBalanceMeta] = None,
 ):
-
-    process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
-    if server_args.disaggregation_mode == "null":
-        thread_label = "Scheduler"
-        trace_set_thread_info(thread_label, tp_rank, dp_rank)
-
-    if (numa_node := server_args.numa_node) is not None:
-        numa_bind_to_node(numa_node[gpu_id])
-
-    # Generate the prefix
+    # Generate the logger prefix
     prefix = ""
+    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
+        # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
+        dp_rank = int(os.environ["SGLANG_DP_RANK"])
     if dp_rank is not None:
         prefix += f" DP{dp_rank}"
     if server_args.tp_size > 1:
@@ -2684,10 +2825,6 @@ def run_scheduler_process(
     kill_itself_when_parent_died()
    parent_process = psutil.Process().parent()
 
-    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
-    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
-        dp_rank = int(os.environ["SGLANG_DP_RANK"])
-
     # Configure the logger
     configure_logger(server_args, prefix=prefix)
     suppress_other_loggers()
@@ -2695,6 +2832,15 @@ def run_scheduler_process(
     # Set cpu affinity to this gpu process
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+    if (numa_node := server_args.numa_node) is not None:
+        numa_bind_to_node(numa_node[gpu_id])
+
+    # Set up tracing
+    if server_args.enable_trace:
+        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
+        if server_args.disaggregation_mode == "null":
+            thread_label = "Scheduler"
+            trace_set_thread_info(thread_label, tp_rank, dp_rank)
 
     # Create a scheduler and run the event loop
     try:
@@ -2706,7 +2852,6 @@ def run_scheduler_process(
             moe_ep_rank,
             pp_rank,
             dp_rank,
-            dp_balance_meta=balance_meta,
         )
         pipe_writer.send(
            {