sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -27,7 +27,8 @@ import torch
|
|
27
27
|
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
28
28
|
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
29
29
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
30
|
-
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
30
|
+
from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
|
31
|
+
from sglang.srt.server_args import ServerArgs
|
31
32
|
|
32
33
|
if TYPE_CHECKING:
|
33
34
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
@@ -82,10 +83,14 @@ class SchedulePolicy:
|
|
82
83
|
policy: str,
|
83
84
|
tree_cache: BasePrefixCache,
|
84
85
|
enable_hierarchical_cache: bool,
|
86
|
+
enable_priority_scheduling: bool,
|
87
|
+
schedule_low_priority_values_first: bool,
|
85
88
|
):
|
86
89
|
self.policy = self._validate_and_adjust_policy(policy, tree_cache)
|
87
90
|
self.tree_cache = tree_cache
|
88
91
|
self.enable_hierarchical_cache = enable_hierarchical_cache
|
92
|
+
self.enable_priority_scheduling = enable_priority_scheduling
|
93
|
+
self.schedule_low_priority_values_first = schedule_low_priority_values_first
|
89
94
|
|
90
95
|
# It is used to find the matching prefix for in-batch prefix caching.
|
91
96
|
self.waiting_queue_radix_tree = RadixCache(
|
@@ -97,7 +102,10 @@ class SchedulePolicy:
|
|
97
102
|
|
98
103
|
def calc_priority(self, waiting_queue: List[Req]) -> bool:
|
99
104
|
if self.policy == CacheAgnosticPolicy.FCFS:
|
100
|
-
|
105
|
+
if self.enable_priority_scheduling:
|
106
|
+
SchedulePolicy._sort_by_priority_and_fcfs(
|
107
|
+
waiting_queue, self.schedule_low_priority_values_first
|
108
|
+
)
|
101
109
|
return False
|
102
110
|
|
103
111
|
policy = self._determine_active_policy(waiting_queue)
|
@@ -120,12 +128,15 @@ class SchedulePolicy:
|
|
120
128
|
if policy == CacheAgnosticPolicy.FCFS:
|
121
129
|
pass
|
122
130
|
elif policy == CacheAgnosticPolicy.LOF:
|
123
|
-
SchedulePolicy._sort_by_longest_output(
|
131
|
+
SchedulePolicy._sort_by_longest_output(
|
132
|
+
waiting_queue,
|
133
|
+
self.enable_priority_scheduling,
|
134
|
+
self.schedule_low_priority_values_first,
|
135
|
+
)
|
124
136
|
elif policy == CacheAgnosticPolicy.RANDOM:
|
125
137
|
SchedulePolicy._sort_randomly(waiting_queue)
|
126
138
|
else:
|
127
139
|
raise ValueError(f"Unknown CacheAgnostic Policy: {policy=}")
|
128
|
-
|
129
140
|
return prefix_computed
|
130
141
|
|
131
142
|
def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
|
@@ -164,10 +175,13 @@ class SchedulePolicy:
|
|
164
175
|
|
165
176
|
for r in waiting_queue:
|
166
177
|
prefix_ids = r.adjust_max_prefix_ids()
|
178
|
+
extra_key = r.extra_key
|
167
179
|
|
168
180
|
# NOTE: the prefix_indices must always be aligned with last_node
|
169
181
|
r.prefix_indices, r.last_node, r.last_host_node, r.host_hit_length = (
|
170
|
-
self.tree_cache.match_prefix(
|
182
|
+
self.tree_cache.match_prefix(
|
183
|
+
rid=r.rid, key=RadixKey(token_ids=prefix_ids, extra_key=extra_key)
|
184
|
+
)
|
171
185
|
)
|
172
186
|
|
173
187
|
# NOTE(sang): This logic is for in-batch prefix caching;
|
@@ -180,7 +194,8 @@ class SchedulePolicy:
|
|
180
194
|
if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
|
181
195
|
in_batch_matching_prefixes, _, _, _ = (
|
182
196
|
self.waiting_queue_radix_tree.match_prefix(
|
183
|
-
rid=r.rid,
|
197
|
+
rid=r.rid,
|
198
|
+
key=RadixKey(token_ids=prefix_ids, extra_key=extra_key),
|
184
199
|
)
|
185
200
|
)
|
186
201
|
if (
|
@@ -191,7 +206,8 @@ class SchedulePolicy:
|
|
191
206
|
else:
|
192
207
|
# Insert with a dummy key
|
193
208
|
self.waiting_queue_radix_tree.insert(
|
194
|
-
|
209
|
+
RadixKey(token_ids=prefix_ids, extra_key=extra_key),
|
210
|
+
torch.empty(len(prefix_ids), dtype=torch.bool),
|
195
211
|
)
|
196
212
|
return temporary_deprioritized
|
197
213
|
|
@@ -231,15 +247,43 @@ class SchedulePolicy:
|
|
231
247
|
)
|
232
248
|
|
233
249
|
@staticmethod
|
234
|
-
def _sort_by_longest_output(
|
235
|
-
|
236
|
-
|
250
|
+
def _sort_by_longest_output(
|
251
|
+
waiting_queue: List[Req],
|
252
|
+
enable_priority_scheduling: bool,
|
253
|
+
schedule_low_priority_values_first: bool,
|
254
|
+
) -> None:
|
255
|
+
"""Sorts the waiting queue based on the longest output (max_new_tokens). If using priority scheduling, sort by priority first."""
|
256
|
+
if enable_priority_scheduling:
|
257
|
+
if schedule_low_priority_values_first:
|
258
|
+
waiting_queue.sort(
|
259
|
+
key=lambda x: (x.priority, -x.sampling_params.max_new_tokens)
|
260
|
+
)
|
261
|
+
else:
|
262
|
+
waiting_queue.sort(
|
263
|
+
key=lambda x: (-x.priority, -x.sampling_params.max_new_tokens)
|
264
|
+
)
|
265
|
+
else:
|
266
|
+
waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
|
237
267
|
|
238
268
|
@staticmethod
|
239
269
|
def _sort_randomly(waiting_queue: List[Req]) -> None:
|
240
270
|
"""Shuffles the waiting queue randomly."""
|
241
271
|
random.shuffle(waiting_queue)
|
242
272
|
|
273
|
+
@staticmethod
|
274
|
+
def _sort_by_priority_and_fcfs(
|
275
|
+
waiting_queue: List[Req], schedule_low_priority_values_first: bool
|
276
|
+
) -> None:
|
277
|
+
"""Sorts the waiting queue based on the request priority then received timestamp."""
|
278
|
+
if schedule_low_priority_values_first:
|
279
|
+
waiting_queue.sort(
|
280
|
+
key=lambda x: (x.priority, x.time_stats.wait_queue_entry_time)
|
281
|
+
)
|
282
|
+
else:
|
283
|
+
waiting_queue.sort(
|
284
|
+
key=lambda x: (-x.priority, x.time_stats.wait_queue_entry_time)
|
285
|
+
)
|
286
|
+
|
243
287
|
@staticmethod
|
244
288
|
def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
|
245
289
|
for child in cur_node.children.values():
|
@@ -279,6 +323,7 @@ class PrefillAdder:
|
|
279
323
|
rem_input_tokens: int,
|
280
324
|
rem_chunk_tokens: Optional[int],
|
281
325
|
mixed_with_decode_tokens: int = 0,
|
326
|
+
priority_scheduling_preemption_threshold: int = 0,
|
282
327
|
):
|
283
328
|
self.page_size = page_size
|
284
329
|
self.tree_cache = tree_cache
|
@@ -295,6 +340,7 @@ class PrefillAdder:
|
|
295
340
|
|
296
341
|
self.req_states = None
|
297
342
|
self.can_run_list = []
|
343
|
+
self.preempt_list = []
|
298
344
|
self.new_chunked_req = None
|
299
345
|
self.log_hit_tokens = 0
|
300
346
|
# TODO(lsyin): report the real input tokens excluding page alignment
|
@@ -303,11 +349,7 @@ class PrefillAdder:
|
|
303
349
|
if running_batch is not None:
|
304
350
|
self.rem_total_token_offset += sum(
|
305
351
|
[
|
306
|
-
|
307
|
-
(r.sampling_params.max_new_tokens - len(r.output_ids)),
|
308
|
-
CLIP_MAX_NEW_TOKENS,
|
309
|
-
)
|
310
|
-
* self.new_token_ratio
|
352
|
+
self._get_running_request_total_token_offset(r)
|
311
353
|
for r in running_batch.reqs
|
312
354
|
]
|
313
355
|
)
|
@@ -316,6 +358,19 @@ class PrefillAdder:
|
|
316
358
|
self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
|
317
359
|
)
|
318
360
|
|
361
|
+
self.priority_scheduling_preemption_threshold = (
|
362
|
+
priority_scheduling_preemption_threshold
|
363
|
+
)
|
364
|
+
|
365
|
+
def _get_running_request_total_token_offset(self, req: Req) -> int:
|
366
|
+
return (
|
367
|
+
min(
|
368
|
+
(req.sampling_params.max_new_tokens - len(req.output_ids)),
|
369
|
+
CLIP_MAX_NEW_TOKENS,
|
370
|
+
)
|
371
|
+
* self.new_token_ratio
|
372
|
+
)
|
373
|
+
|
319
374
|
@property
|
320
375
|
def rem_total_tokens(self):
|
321
376
|
if self.is_hybrid:
|
@@ -495,7 +550,9 @@ class PrefillAdder:
|
|
495
550
|
|
496
551
|
return self.budget_state()
|
497
552
|
|
498
|
-
def add_one_req(
|
553
|
+
def add_one_req(
|
554
|
+
self, req: Req, has_chunked_req: bool, truncation_align_size: Optional[int]
|
555
|
+
):
|
499
556
|
if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
|
500
557
|
return self.add_one_req_ignore_eos(req, has_chunked_req)
|
501
558
|
|
@@ -526,6 +583,7 @@ class PrefillAdder:
|
|
526
583
|
req.prefix_indices = torch.cat([req.prefix_indices, new_indices])
|
527
584
|
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
|
528
585
|
prefix_len = len(req.prefix_indices)
|
586
|
+
req.last_matched_prefix_len = prefix_len
|
529
587
|
|
530
588
|
input_tokens = self.ceil_paged_tokens(req.extend_input_len)
|
531
589
|
|
@@ -554,6 +612,17 @@ class PrefillAdder:
|
|
554
612
|
if trunc_len <= 0:
|
555
613
|
return AddReqResult.OTHER
|
556
614
|
|
615
|
+
# When truncation align size is set, we want to assert that the prefill prefix length is a multiple of the truncation align size
|
616
|
+
# A typical use case is when deterministic inference is enabled with flashinfer attention backend,
|
617
|
+
# we need the prefill prefix length to be a multiple of the attention split size
|
618
|
+
if truncation_align_size is not None:
|
619
|
+
if trunc_len < truncation_align_size:
|
620
|
+
return AddReqResult.OTHER
|
621
|
+
else:
|
622
|
+
trunc_len = truncation_align_size * (
|
623
|
+
trunc_len // truncation_align_size
|
624
|
+
)
|
625
|
+
|
557
626
|
# Chunked prefill
|
558
627
|
req.extend_input_len = trunc_len
|
559
628
|
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
|
@@ -568,3 +637,61 @@ class PrefillAdder:
|
|
568
637
|
self._update_prefill_budget(prefix_len, trunc_len, 0)
|
569
638
|
|
570
639
|
return self.budget_state()
|
640
|
+
|
641
|
+
def preempt_to_schedule(self, req: Req, server_args: ServerArgs) -> bool:
|
642
|
+
"""
|
643
|
+
Preempt running requests to serve the new request if the priority threshold is met and token count sum is verified.
|
644
|
+
Returns True if preemption was committed, and the new request can be scheduled.
|
645
|
+
"""
|
646
|
+
# Iterate running requests to find preemptible requests
|
647
|
+
if server_args.schedule_low_priority_values_first:
|
648
|
+
sorted_running_reqs = sorted(
|
649
|
+
self.running_batch.reqs,
|
650
|
+
key=lambda x: (-x.priority, -x.time_stats.wait_queue_entry_time),
|
651
|
+
)
|
652
|
+
else:
|
653
|
+
sorted_running_reqs = sorted(
|
654
|
+
self.running_batch.reqs,
|
655
|
+
key=lambda x: (x.priority, -x.time_stats.wait_queue_entry_time),
|
656
|
+
)
|
657
|
+
preemptible_reqs = []
|
658
|
+
min_tokens_to_remove = (
|
659
|
+
req.extend_input_len
|
660
|
+
+ min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
|
661
|
+
- self.rem_total_tokens
|
662
|
+
)
|
663
|
+
for running_req in sorted_running_reqs:
|
664
|
+
if running_req in self.preempt_list:
|
665
|
+
continue
|
666
|
+
# Priority difference needs to meet the threshold to be preemptible.
|
667
|
+
priority_diff = req.priority - running_req.priority
|
668
|
+
if server_args.schedule_low_priority_values_first:
|
669
|
+
priority_diff *= -1
|
670
|
+
if priority_diff > self.priority_scheduling_preemption_threshold:
|
671
|
+
preemptible_reqs.append(running_req)
|
672
|
+
min_tokens_to_remove -= self._get_running_request_total_token_offset(
|
673
|
+
running_req
|
674
|
+
)
|
675
|
+
|
676
|
+
# Check max token count limit can be met
|
677
|
+
if len(preemptible_reqs) == 0 or min_tokens_to_remove > 0:
|
678
|
+
return False
|
679
|
+
|
680
|
+
# Preempt running requests. Release allocated resources for immediate usage.
|
681
|
+
preemptible_reqs = set(preemptible_reqs)
|
682
|
+
keep_indices = []
|
683
|
+
release_counter = 0
|
684
|
+
for i, running_req in enumerate(self.running_batch.reqs):
|
685
|
+
if running_req in preemptible_reqs:
|
686
|
+
self.rem_total_token_offset -= (
|
687
|
+
self._get_running_request_total_token_offset(req)
|
688
|
+
)
|
689
|
+
release_counter += 1
|
690
|
+
self.running_batch.release_req(
|
691
|
+
i, len(self.running_batch.reqs) - release_counter, server_args
|
692
|
+
)
|
693
|
+
else:
|
694
|
+
keep_indices.append(i)
|
695
|
+
self.running_batch.filter_batch(keep_indices=keep_indices)
|
696
|
+
self.preempt_list.extend(preemptible_reqs)
|
697
|
+
return True
|