sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
```diff
@@ -129,10 +129,10 @@ from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.managers.utils import validate_input_length
-from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
+from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
 from sglang.srt.reasoning_parser import ReasoningParser
@@ -247,11 +247,14 @@ class Scheduler(
         self.pp_size = server_args.pp_size
         self.dp_size = server_args.dp_size
         self.schedule_policy = server_args.schedule_policy
-        self.
+        self.enable_lora = server_args.enable_lora
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
+        self.enable_metrics_for_all_schedulers = (
+            server_args.enable_metrics_for_all_schedulers
+        )
         self.enable_kv_cache_events = server_args.kv_events_config is not None
         self.stream_interval = server_args.stream_interval
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
@@ -259,6 +262,7 @@ class Scheduler(
         )
         self.gpu_id = gpu_id
         self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
+        self.enable_hicache_storage = server_args.hicache_storage_backend is not None
         self.page_size = server_args.page_size
         self.dp_size = server_args.dp_size
         self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
@@ -281,9 +285,6 @@ class Scheduler(
             self.send_to_tokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.tokenizer_ipc_name, False
             )
-            self.send_metrics_from_scheduler = get_zmq_socket(
-                context, zmq.PUSH, port_args.metrics_ipc_name, False
-            )
 
             if server_args.skip_tokenizer_init:
                 # Directly send to the TokenizerManager
@@ -309,10 +310,14 @@ class Scheduler(
         else:
             self.recv_from_tokenizer = None
             self.recv_from_rpc = None
-            self.send_metrics_from_scheduler = None
             self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
 
+        if self.current_scheduler_metrics_enabled():
+            self.send_metrics_from_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.metrics_ipc_name, False
+            )
+
         # Init tokenizer
         self.init_tokenizer()
 
@@ -390,6 +395,14 @@ class Scheduler(
         global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)
 
+        # Hybrid
+        self.is_hybrid = self.tp_worker.is_hybrid
+        if self.is_hybrid:
+            self.sliding_window_size = self.tp_worker.sliding_window_size
+            self.full_tokens_per_layer, self.swa_tokens_per_layer = (
+                self.tp_worker.get_tokens_per_layer_info()
+            )
+
         # Print debug info
         if tp_rank == 0:
             avail_mem = get_available_gpu_memory(
@@ -487,7 +500,7 @@ class Scheduler(
         self.init_profier()
 
         # Init metrics stats
-        self.init_metrics()
+        self.init_metrics(tp_rank, pp_rank, dp_rank)
         self.init_kv_events(server_args.kv_events_config)
 
         # Init request dispatcher
@@ -529,6 +542,9 @@ class Scheduler(
         if get_bool_env_var("SGLANG_GC_LOG"):
             configure_gc_logger()
 
+    def current_scheduler_metrics_enabled(self):
+        return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers
+
     def maybe_sleep_on_idle(self):
         if self.idle_sleeper is not None:
             self.idle_sleeper.maybe_sleep()
@@ -570,7 +586,7 @@ class Scheduler(
             server_args.chunked_prefill_size is not None
             and server_args.disable_radix_cache
         ):
-            if self.
+            if self.is_hybrid:
                 ChunkCacheClass = SWAChunkCache
             else:
                 ChunkCacheClass = ChunkCache
@@ -599,10 +615,22 @@ class Scheduler(
                     == "fa3"  # hot fix for incompatibility
                    else server_args.hicache_io_backend
                 ),
+                hicache_storage_backend=server_args.hicache_storage_backend,
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
                 self.tree_cache.cache_controller.layer_done_counter
             )
+        elif self.is_hybrid:
+            assert (
+                self.server_args.disaggregation_mode == "null"
+            ), "Hybrid mode does not support disaggregation yet"
+            self.tree_cache = SWARadixCache(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                sliding_window_size=self.sliding_window_size,
+                page_size=self.page_size,
+                disable=server_args.disable_radix_cache,
+            )
 
         else:
             self.tree_cache = RadixCache(
@@ -625,6 +653,9 @@ class Scheduler(
             )
         )
 
+        embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
+        init_embedding_cache(embedding_cache_size * 1024 * 1024)
+
     def init_profier(self):
         self.torch_profiler = None
         self.torch_profiler_output_dir: Optional[str] = None
@@ -641,7 +672,7 @@ class Scheduler(
         self.profile_in_progress: bool = False
         self.rpd_profiler = None
 
-    def init_metrics(self):
+    def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
         self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
@@ -649,15 +680,19 @@ class Scheduler(
         self.spec_num_total_forward_ct = 0
         self.cum_spec_accept_length = 0
         self.cum_spec_accept_count = 0
+        self.total_retracted_reqs = 0
         self.stats = SchedulerStats()
         if self.enable_metrics:
             engine_type = "unified"
-
-
-
-
-
-
+            labels = {
+                "model_name": self.server_args.served_model_name,
+                "engine_type": engine_type,
+                "tp_rank": tp_rank,
+                "pp_rank": pp_rank,
+            }
+            if dp_rank is not None:
+                labels["dp_rank"] = dp_rank
+            self.metrics_collector = SchedulerMetricsCollector(labels=labels)
 
     def init_kv_events(self, kv_events_config: Optional[str]):
         if self.enable_kv_cache_events:
@@ -774,6 +809,7 @@ class Scheduler(
             else:
                 # When the server is idle, do self-check and re-init some states
                 self.check_memory()
+                self.check_tree_cache()
                 self.new_token_ratio = self.init_new_token_ratio
                 self.maybe_sleep_on_idle()
 
@@ -819,6 +855,7 @@ class Scheduler(
             elif batch is None:
                 # When the server is idle, do self-check and re-init some states
                 self.check_memory()
+                self.check_tree_cache()
                 self.new_token_ratio = self.init_new_token_ratio
                 self.maybe_sleep_on_idle()
 
@@ -955,6 +992,7 @@ class Scheduler(
             # When the server is idle, self-check and re-init some states
             if server_is_idle:
                 self.check_memory()
+                self.check_tree_cache()
                 self.new_token_ratio = self.init_new_token_ratio
                 self.maybe_sleep_on_idle()
 
@@ -1091,6 +1129,7 @@ class Scheduler(
                 bootstrap_port=recv_req.bootstrap_port,
                 bootstrap_room=recv_req.bootstrap_room,
                 data_parallel_rank=recv_req.data_parallel_rank,
+                vocab_size=self.model_config.vocab_size,
             )
             req.tokenizer = self.tokenizer
 
@@ -1220,6 +1259,15 @@ class Scheduler(
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.disagg_decode_prealloc_queue.add(req)
         else:
+            if self.enable_hicache_storage:
+                req.init_next_round_input(self.tree_cache)
+                last_hash = req.last_host_node.get_last_hash_value()
+                matched_len = len(req.prefix_indices) + req.host_hit_length
+                if (matched_len > 0 and last_hash is not None) or matched_len == 0:
+                    new_input_tokens = req.fill_ids[matched_len:]
+                    self.tree_cache.prefetch_from_storage(
+                        req.rid, req.last_host_node, new_input_tokens, last_hash
+                    )
             self.waiting_queue.append(req)
 
     def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
@@ -1306,9 +1354,26 @@ class Scheduler(
         self.last_input_throughput = self.last_prefill_tokens / gap_latency
         self.last_prefill_tokens = adder.log_input_tokens
 
-
-
-
+        if self.is_hybrid:
+            (
+                full_num_used,
+                swa_num_used,
+                full_token_usage,
+                swa_token_usage,
+                _,
+                _,
+                _,
+                _,
+            ) = self._get_swa_token_info()
+            num_used = max(full_num_used, swa_num_used)
+            token_usage = max(full_token_usage, swa_token_usage)
+            token_msg = (
+                f"full token usage: {full_token_usage:.2f}, "
+                f"swa token usage: {swa_token_usage:.2f}, "
+            )
+        else:
+            num_used, token_usage, _, _ = self._get_token_info()
+            token_msg = f"token usage: {token_usage:.2f}, "
 
         num_new_seq = len(can_run_list)
         f = (
@@ -1316,7 +1381,7 @@ class Scheduler(
             f"#new-seq: {num_new_seq}, "
             f"#new-token: {adder.log_input_tokens}, "
             f"#cached-token: {adder.log_hit_tokens}, "
-            f"{
+            f"{token_msg}"
         )
 
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
@@ -1328,17 +1393,17 @@ class Scheduler(
             f += f"#running-req: {running_bs}, "
             f += f"#queue-req: {len(self.waiting_queue)}, "
 
-        f += f"timestamp: {datetime.datetime.now().isoformat()}"
-
         logger.info(f)
 
         if self.enable_metrics:
-
-
+            total_tokens = adder.log_input_tokens + adder.log_hit_tokens
+
+            cache_hit_rate = (
+                adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
             )
             self.stats.num_running_reqs = running_bs
             self.stats.num_used_tokens = num_used
-            self.stats.token_usage = round(
+            self.stats.token_usage = round(token_usage, 2)
             self.stats.num_queue_reqs = len(self.waiting_queue)
             self.stats.cache_hit_rate = cache_hit_rate
 
@@ -1361,16 +1426,35 @@ class Scheduler(
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
         self.num_generated_tokens = 0
         num_running_reqs = len(batch.reqs)
-
-
-
+        if self.is_hybrid:
+            (
+                full_num_used,
+                swa_num_used,
+                full_token_usage,
+                swa_token_usage,
+                _,
+                _,
+                _,
+                _,
+            ) = self._get_swa_token_info()
+            num_used = max(full_num_used, swa_num_used)
+            token_usage = max(full_token_usage, swa_token_usage)
+            token_msg = (
+                f"#full token: {full_num_used}, "
+                f"full token usage: {full_token_usage:.2f}, "
+                f"#swa token: {swa_num_used}, "
+                f"swa token usage: {swa_token_usage:.2f}, "
+            )
+        else:
+            num_used, token_usage, _, _ = self._get_token_info()
+            token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
 
         if RECORD_STEP_TIME:
             self.step_time_dict[num_running_reqs].append(
                 gap_latency / self.server_args.decode_log_interval
             )
 
-        msg = f"Decode batch.
+        msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
 
         if self.spec_algorithm.is_none():
             spec_accept_length = 0
@@ -1391,42 +1475,52 @@ class Scheduler(
             f"cuda graph: {can_run_cuda_graph}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
             f"#queue-req: {len(self.waiting_queue)}, "
-            f"timestamp: {datetime.datetime.now().isoformat()}"
         )
 
         logger.info(msg)
         if self.enable_metrics:
             self.stats.num_running_reqs = num_running_reqs
             self.stats.num_used_tokens = num_used
-            self.stats.token_usage =
+            self.stats.token_usage = round(token_usage, 2)
             self.stats.cache_hit_rate = 0.0
             self.stats.gen_throughput = self.last_gen_throughput
             self.stats.num_queue_reqs = len(self.waiting_queue)
             self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.stats.spec_accept_length = spec_accept_length
+            self.stats.total_retracted_reqs = self.total_retracted_reqs
             self.metrics_collector.log_stats(self.stats)
             self._emit_kv_metrics()
         self._publish_kv_events()
 
     def check_memory(self):
-        if
-
+        if self.is_hybrid:
+            (
+                full_num_used,
+                swa_num_used,
+                _,
+                _,
+                full_available_size,
+                full_evictable_size,
+                swa_available_size,
+                swa_evictable_size,
+            ) = self._get_swa_token_info()
+            memory_leak = full_num_used != 0 or swa_num_used != 0
+            token_msg = (
+                f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
+                f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
+            )
         else:
-
-
-
-
-
-
-            else self.max_total_num_tokens - protected_size
-        )
-        if memory_leak:
-            msg = (
-                "token_to_kv_pool_allocator memory leak detected! "
-                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
-                f"{available_token_size=}\n"
-                f"{self.tree_cache.evictable_size()=}\n"
+            _, _, available_size, evictable_size = self._get_token_info()
+            protected_size = self.tree_cache.protected_size()
+            memory_leak = (available_size + evictable_size) != (
+                self.max_total_num_tokens
+                if not self.enable_hierarchical_cache
+                else self.max_total_num_tokens - protected_size
            )
+            token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
+
+        if memory_leak:
+            msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
             raise ValueError(msg)
 
         if self.disaggregation_mode == DisaggregationMode.DECODE:
@@ -1446,24 +1540,70 @@ class Scheduler(
 
         if (
             self.enable_metrics
-            and self.
+            and self.current_scheduler_metrics_enabled()
            and time.perf_counter() > self.metrics_collector.last_log_time + 30
         ):
             # During idle time, also collect metrics every 30 seconds.
-
-
-
-
+            if self.is_hybrid:
+                (
+                    full_num_used,
+                    swa_num_used,
+                    full_token_usage,
+                    swa_token_usage,
+                    _,
+                    _,
+                    _,
+                    _,
+                ) = self._get_swa_token_info()
+                num_used = max(full_num_used, swa_num_used)
+                token_usage = max(full_token_usage, swa_token_usage)
+            else:
+                num_used, token_usage, _, _ = self._get_token_info()
             num_running_reqs = len(self.running_batch.reqs)
             self.stats.num_running_reqs = num_running_reqs
             self.stats.num_used_tokens = num_used
-            self.stats.token_usage =
+            self.stats.token_usage = round(token_usage, 2)
             self.stats.gen_throughput = 0
             self.stats.num_queue_reqs = len(self.waiting_queue)
             self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.metrics_collector.log_stats(self.stats)
             self._publish_kv_events()
 
+    def check_tree_cache(self):
+        if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache):
+            self.tree_cache.sanity_check()
+
+    def _get_token_info(self):
+        available_size = self.token_to_kv_pool_allocator.available_size()
+        evictable_size = self.tree_cache.evictable_size()
+        num_used = self.max_total_num_tokens - (available_size + evictable_size)
+        token_usage = num_used / self.max_total_num_tokens
+        return num_used, token_usage, available_size, evictable_size
+
+    def _get_swa_token_info(self):
+        full_available_size = self.token_to_kv_pool_allocator.full_available_size()
+        full_evictable_size = self.tree_cache.full_evictable_size()
+        swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
+        swa_evictable_size = self.tree_cache.swa_evictable_size()
+        full_num_used = self.full_tokens_per_layer - (
+            full_available_size + full_evictable_size
+        )
+        swa_num_used = self.swa_tokens_per_layer - (
+            swa_available_size + swa_evictable_size
+        )
+        full_token_usage = full_num_used / self.full_tokens_per_layer
+        swa_token_usage = swa_num_used / self.swa_tokens_per_layer
+        return (
+            full_num_used,
+            swa_num_used,
+            full_token_usage,
+            swa_token_usage,
+            full_available_size,
+            full_evictable_size,
+            swa_available_size,
+            swa_evictable_size,
+        )
+
     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
         chunked_req_to_exclude = set()
@@ -1572,13 +1712,13 @@ class Scheduler(
             self.chunked_req.init_next_round_input()
             self.chunked_req = adder.add_chunked_req(self.chunked_req)
 
-        if self.
+        if self.enable_lora:
             lora_set = set([req.lora_path for req in self.running_batch.reqs])
 
         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
-                self.
+                self.enable_lora
                 and len(
                     lora_set
                     | set([req.lora_path for req in adder.can_run_list])
@@ -1600,6 +1740,9 @@ class Scheduler(
                 self.running_batch.batch_is_full = True
                 break
 
+            if self.enable_hicache_storage:
+                self.tree_cache.check_prefetch_progress(req.rid)
+
             req.init_next_round_input(self.tree_cache)
             res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
 
@@ -1636,7 +1779,7 @@ class Scheduler(
             self.chunked_req.is_chunked += 1
 
         # Print stats
-        if self.
+        if self.current_scheduler_metrics_enabled():
             self.log_prefill_stats(adder, can_run_list, running_bs)
 
         # Create a new batch
@@ -1695,14 +1838,17 @@ class Scheduler(
             old_ratio = self.new_token_ratio
 
             retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
+            num_retracted_reqs = len(retracted_reqs)
             self.new_token_ratio = new_token_ratio
 
             logger.info(
                 "KV cache pool is full. Retract requests. "
-                f"#retracted_reqs: {
+                f"#retracted_reqs: {num_retracted_reqs}, "
                 f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
             )
+
             self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
+            self.total_retracted_reqs += num_retracted_reqs
         else:
             self.new_token_ratio = max(
                 self.new_token_ratio - self.new_token_ratio_decay,
@@ -1826,7 +1972,7 @@ class Scheduler(
             local_batch,
             dp_size=self.server_args.dp_size,
             attn_tp_size=self.attn_tp_size,
-
+            tp_group=self.tp_group,
             get_idle_batch=self.get_idle_batch,
             disable_cuda_graph=self.server_args.disable_cuda_graph,
             spec_algorithm=self.spec_algorithm,
@@ -1835,6 +1981,7 @@ class Scheduler(
             enable_deepep_moe=self.server_args.enable_deepep_moe,
             deepep_mode=DeepEPMode[self.server_args.deepep_mode],
             require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
+            disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )
 
     @staticmethod
@@ -1842,7 +1989,7 @@ class Scheduler(
         local_batch: ScheduleBatch,
         dp_size,
         attn_tp_size: int,
-
+        tp_group,
         get_idle_batch,
         disable_cuda_graph: bool,
         spec_algorithm,
@@ -1851,6 +1998,7 @@ class Scheduler(
         enable_deepep_moe: bool,
         deepep_mode: DeepEPMode,
         require_mlp_tp_gather: bool,
+        disable_overlap_schedule: bool,
     ):
         # Check if other DP workers have running batches
         if local_batch is None:
@@ -1881,6 +2029,12 @@ class Scheduler(
         )
 
         tbo_preparer = TboDPAttentionPreparer()
+        if disable_overlap_schedule:
+            group = tp_group.device_group
+            device = tp_group.device
+        else:
+            group = tp_group.cpu_group
+            device = "cpu"
 
         local_info = torch.tensor(
             [
@@ -1896,15 +2050,17 @@ class Scheduler(
                 ),
             ],
             dtype=torch.int64,
+            device=device,
         )
         global_info = torch.empty(
             (dp_size, attn_tp_size, 6),
             dtype=torch.int64,
+            device=device,
         )
         torch.distributed.all_gather_into_tensor(
             global_info.flatten(),
             local_info,
-            group=
+            group=group,
         )
         global_num_tokens = global_info[:, 0, 0].tolist()
         can_cuda_graph = min(global_info[:, 0, 1].tolist())
@@ -2042,11 +2198,30 @@ class Scheduler(
 
         if not disable_request_logging():
             # Print batch size and memory pool info to check whether there are de-sync issues.
+            if self.is_hybrid:
+                (
+                    _,
+                    _,
+                    _,
+                    _,
+                    full_available_size,
+                    full_evictable_size,
+                    swa_available_size,
+                    swa_evictable_size,
+                ) = self._get_swa_token_info()
+                info_msg = (
+                    f"{full_available_size=}, "
+                    f"{full_evictable_size=}, "
+                    f"{swa_available_size=}, "
+                    f"{swa_evictable_size=}, "
+                )
+            else:
+                _, _, available_size, evictable_size = self._get_token_info()
+                info_msg = f"{available_size=}, " f"{evictable_size=}, "
             logger.error(
                 f"{self.cur_batch.batch_size()=}, "
                 f"{self.cur_batch.reqs=}, "
-                f"{
-                f"{self.tree_cache.evictable_size()=}, "
+                f"{info_msg}"
             )
 
        pyspy_dump_schedulers()
@@ -2101,11 +2276,24 @@ class Scheduler(
 
     def get_load(self):
         # TODO(lsyin): use dynamically maintained num_waiting_tokens
-
-
-
-
-
+        if self.is_hybrid:
+            load_full = (
+                self.full_tokens_per_layer
+                - self.token_to_kv_pool_allocator.full_available_size()
+                - self.tree_cache.full_evictable_size()
+            )
+            load_swa = (
+                self.swa_tokens_per_layer
+                - self.token_to_kv_pool_allocator.swa_available_size()
+                - self.tree_cache.swa_evictable_size()
+            )
+            load = max(load_full, load_swa)
+        else:
+            load = (
+                self.max_total_num_tokens
+                - self.token_to_kv_pool_allocator.available_size()
+                - self.tree_cache.evictable_size()
+            )
         load += sum(len(req.origin_input_ids) for req in self.waiting_queue)
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             load += sum(
@@ -2284,12 +2472,6 @@ class Scheduler(
         """In-place loading a new lora adapter from disk or huggingface."""
 
         result = self.tp_worker.load_lora_adapter(recv_req)
-
-        if result.success:
-            flush_cache_success = self.flush_cache()
-            assert flush_cache_success, "Cache flush failed after loading lora adapter."
-        else:
-            logger.error(result.error_message)
         return result
 
     def unload_lora_adapter(
@@ -2298,14 +2480,6 @@ class Scheduler(
         """Unload the lora adapter."""
 
         result = self.tp_worker.unload_lora_adapter(recv_req)
-
-        if result.success:
-            flush_cache_success = self.flush_cache()
-            assert (
-                flush_cache_success
-            ), "Cache flush failed after unloading LoRA weights"
-        else:
-            logger.error(result.error_message)
         return result
 
     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
@@ -2727,9 +2901,9 @@ def run_scheduler_process(
         prefix += f" PP{pp_rank}"
 
     # Config the process
-    kill_itself_when_parent_died()
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
+    kill_itself_when_parent_died()
    parent_process = psutil.Process().parent()
 
    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
@@ -2744,10 +2918,6 @@ def run_scheduler_process(
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
 
-    embedding_cache_size = 100
-    if "SGLANG_VLM_CACHE_SIZE_MB" in os.environ:
-        embedding_cache_size = int(os.environ["SGLANG_VLM_CACHE_SIZE_MB"])
-    init_embedding_cache(embedding_cache_size * 1024 * 1024)
    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
@@ -2758,8 +2928,8 @@ def run_scheduler_process(
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )
-        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
 
+        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
        if disaggregation_mode == DisaggregationMode.NULL:
            if server_args.pp_size > 1:
                scheduler.event_loop_pp()
```