sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +23 -2
- sglang/bench_serving.py +6 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +80 -11
- sglang/srt/disaggregation/mini_lb.py +58 -123
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +585 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
- sglang/srt/disaggregation/prefill.py +82 -22
- sglang/srt/disaggregation/utils.py +46 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +42 -13
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +430 -257
- sglang/srt/layers/attention/flashinfer_backend.py +18 -9
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +18 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +63 -45
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +131 -136
- sglang/srt/layers/quantization/fp8_kernel.py +328 -46
- sglang/srt/layers/quantization/fp8_utils.py +206 -253
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
- sglang/srt/layers/quantization/w8a8_int8.py +8 -7
- sglang/srt/layers/radix_attention.py +28 -1
- sglang/srt/layers/rotary_embedding.py +15 -3
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +255 -97
- sglang/srt/managers/mm_utils.py +7 -5
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +64 -25
- sglang/srt/managers/scheduler.py +80 -82
- sglang/srt/managers/tokenizer_manager.py +18 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +21 -3
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +9 -6
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +67 -35
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +494 -366
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +6 -5
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +30 -200
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +5 -1
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +15 -13
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +55 -19
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +10 -9
- sglang/srt/utils.py +136 -10
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +224 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/disaggregation/conn.py +0 -81
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -49,6 +49,7 @@ from sglang.srt.disaggregation.prefill import (
 from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
     ReqToMetadataIdxAllocator,
+    TransferBackend,
 )
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
@@ -113,6 +114,7 @@ from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -232,6 +234,15 @@ class Scheduler(
         # Init tokenizer
         self.init_tokenizer()

+        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
+        if self.server_args.reasoning_parser and self.tokenizer:
+            reasoning_parser = ReasoningParser(
+                model_type=self.server_args.reasoning_parser, stream_reasoning=False
+            )
+            self.tokenizer.think_end_id = self.tokenizer.encode(
+                reasoning_parser.detector.think_end_token, add_special_tokens=False
+            )[0]
+
         # Check whether overlap can be enabled
         if not self.is_generation:
             self.enable_overlap = False
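The hunk above resolves the reasoning parser's closing tag to a single token id once at startup, so later stages can compare raw token ids instead of decoding text. A minimal sketch of the same lookup, assuming a Hugging Face tokenizer and a `</think>` closing tag (the real tag comes from `reasoning_parser.detector.think_end_token`; the model name here is only an example):

```python
# Hedged sketch; not the sglang implementation.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1")  # assumed model
think_end_token = "</think>"  # assumed closing tag

# encode() without special tokens; [0] assumes the tag maps to a single token id
think_end_id = tokenizer.encode(think_end_token, add_special_tokens=False)[0]

def is_end_of_reasoning(token_id: int) -> bool:
    # Downstream code can now detect the end of the "thinking" span by id.
    return token_id == think_end_id
```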
@@ -380,6 +391,7 @@ class Scheduler(
         self.torch_profiler = None
         self.torch_profiler_output_dir: Optional[str] = None
         self.profiler_activities: Optional[List[str]] = None
+        self.profiler_id: Optional[str] = None
         self.profiler_target_forward_ct: Optional[int] = None

         # Init metrics stats
@@ -427,6 +439,7 @@ class Scheduler(
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            enable_multimodal=server_args.enable_multimodal,
             dtype=server_args.dtype,
             quantization=server_args.quantization,
         )
@@ -441,6 +454,7 @@ class Scheduler(
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
                 revision=server_args.revision,
+                use_fast=not server_args.disable_fast_image_processor,
             )
             self.tokenizer = self.processor.tokenizer
         else:
@@ -471,7 +485,7 @@ class Scheduler(
             self.tree_cache = HiRadixCache(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-                tp_cache_group=self.
+                tp_cache_group=self.tp_cpu_group,
                 page_size=self.page_size,
                 hicache_ratio=server_args.hicache_ratio,
             )
@@ -518,6 +532,10 @@ class Scheduler(
         )

     def init_disaggregation(self):
+        self.transfer_backend = TransferBackend(
+            self.server_args.disaggregation_transfer_backend
+        )
+
         if (
             self.disaggregation_mode == DisaggregationMode.DECODE
         ): # *2 for the headroom.
@@ -536,7 +554,7 @@ class Scheduler(

             # The decode requests polling kv cache
             self.disagg_decode_transfer_queue = DecodeTransferQueue(
-                gloo_group=self.
+                gloo_group=self.attn_tp_cpu_group,
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                 metadata_buffers=metadata_buffers,
             )
@@ -551,10 +569,11 @@ class Scheduler(
                 scheduler=self,
                 transfer_queue=self.disagg_decode_transfer_queue,
                 tree_cache=self.tree_cache,
-                gloo_group=self.
+                gloo_group=self.attn_tp_cpu_group,
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+                transfer_backend=self.transfer_backend,
             )
         elif self.disaggregation_mode == DisaggregationMode.PREFILL:
             # *2 for the headroom.
@@ -579,10 +598,12 @@ class Scheduler(
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
-                gloo_group=self.
+                gloo_group=self.attn_tp_cpu_group,
+                transfer_backend=self.transfer_backend,
+                scheduler=self,
             )
             # The prefill requests that are in the middle of kv sending
-            self.
+            self.disagg_prefill_inflight_queue: List[Req] = []

     @DynamicGradMode()
     def event_loop_normal(self):
@@ -644,70 +665,6 @@ class Scheduler(

         self.last_batch = batch

-    @torch.no_grad()
-    def event_loop_normal_disagg_prefill(self):
-        """A normal scheduler loop for prefill worker in disaggregation mode."""
-
-        while True:
-            recv_reqs = self.recv_requests()
-            self.process_input_requests(recv_reqs)
-            self.waiting_queue.extend(
-                self.disagg_prefill_pending_queue.pop_bootstrapped()
-            )
-            self.process_prefill_chunk()
-            batch = self.get_new_batch_prefill()
-            self.cur_batch = batch
-
-            if batch:
-                result = self.run_batch(batch)
-                self.process_batch_result_disagg_prefill(batch, result)
-
-            if len(self.disagg_prefill_infight_queue) > 0:
-                self.process_disagg_prefill_infight_queue()
-
-            if batch is None and len(self.disagg_prefill_infight_queue) == 0:
-                self.check_memory()
-                self.new_token_ratio = self.init_new_token_ratio
-
-            self.last_batch = batch
-            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
-            # Otherwise, it hangs under high concurrency
-            self.running_batch.batch_is_full = False
-
-    @torch.no_grad()
-    def event_loop_normal_disagg_decode(self):
-        """A normal scheduler loop for decode worker in disaggregation mode."""
-
-        while True:
-            recv_reqs = self.recv_requests()
-            self.process_input_requests(recv_reqs)
-            # polling and allocating kv cache
-            self.process_decode_queue()
-            batch = self.get_next_disagg_decode_batch_to_run()
-            self.cur_batch = batch
-
-            if batch:
-                # Generate fake extend output.
-                if batch.forward_mode.is_extend():
-                    # Note: Logprobs should be handled on the prefill engine.
-                    self.stream_output(
-                        batch.reqs, [False for _ in range(len(batch.reqs))]
-                    )
-                else:
-                    result = self.run_batch(batch)
-                    self.process_batch_result(batch, result)
-
-            if batch is None and (
-                len(self.disagg_decode_transfer_queue.queue)
-                + len(self.disagg_decode_prealloc_queue.queue)
-                == 0
-            ):
-                # When the server is idle, do self-check and re-init some states
-                self.check_memory()
-                self.new_token_ratio = self.init_new_token_ratio
-
-            self.last_batch = batch
-
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
         if self.attn_tp_rank == 0:
@@ -826,6 +783,8 @@ class Scheduler(
             custom_logit_processor=custom_logit_processor,
             return_hidden_states=recv_req.return_hidden_states,
             eos_token_ids=self.model_config.hf_eos_token_id,
+            bootstrap_host=recv_req.bootstrap_host,
+            bootstrap_room=recv_req.bootstrap_room,
         )
         req.tokenizer = self.tokenizer

@@ -937,12 +896,11 @@ class Scheduler(
             self._add_request_to_queue(req)

     def _add_request_to_queue(self, req: Req):
+        req.queue_time_start = time.time()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             self.disagg_prefill_pending_queue.add(req)
-
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.disagg_decode_prealloc_queue.add(req)
-
         else:
             self.waiting_queue.append(req)

@@ -985,6 +943,7 @@ class Scheduler(
             req.finished_reason = FINISH_ABORT(
                 error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
             )
+            req.queue_time_start = time.time()
             self.waiting_queue.append(req)
             return

@@ -1021,9 +980,10 @@ class Scheduler(
             self._largest_prefill_len, adder.log_input_tokens
         )

+        num_new_seq = len(can_run_list)
         f = (
             f"Prefill batch. "
-            f"#new-seq: {
+            f"#new-seq: {num_new_seq}, "
             f"#new-token: {adder.log_input_tokens}, "
             f"#cached-token: {adder.log_hit_tokens}, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
@@ -1041,6 +1001,12 @@ class Scheduler(
             self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
             self.stats.num_queue_reqs = len(self.waiting_queue)
             self.stats.cache_hit_rate = cache_hit_rate
+
+            total_queue_latency = 0
+            for req in can_run_list:
+                total_queue_latency += req.queue_time_end - req.queue_time_start
+            self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
+
             self.metrics_collector.log_stats(self.stats)

     def log_decode_stats(self):
@@ -1277,6 +1243,12 @@ class Scheduler(
         can_run_list: List[Req] = adder.can_run_list
         if len(can_run_list) == 0:
             return None
+
+        if self.enable_metrics:
+            # only record queue time when enable_metrics is True to avoid overhead
+            for req in can_run_list:
+                req.queue_time_end = time.time()
+
         self.waiting_queue = [
             x for x in self.waiting_queue if x not in set(can_run_list)
         ]
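Together with the `queue_time_start` stamps added in `_add_request_to_queue` and in the abort path, the hunk above gives every scheduled request a start/end pair, and the earlier prefill-stats hunk averages the difference over the batch before exporting it. A standalone sketch of that bookkeeping (the `Req` class here is an illustrative stand-in, not sglang's):

```python
import time
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Req:  # illustrative stand-in
    queue_time_start: Optional[float] = None
    queue_time_end: Optional[float] = None


def add_request_to_queue(waiting_queue: List[Req], req: Req) -> None:
    req.queue_time_start = time.time()  # stamped on enqueue
    waiting_queue.append(req)


def take_prefill_batch(waiting_queue: List[Req], n: int) -> List[Req]:
    can_run_list = waiting_queue[:n]
    for req in can_run_list:
        req.queue_time_end = time.time()  # stamped when the batch is formed
    del waiting_queue[:n]
    return can_run_list


def avg_request_queue_latency(can_run_list: List[Req]) -> float:
    total = sum(r.queue_time_end - r.queue_time_start for r in can_run_list)
    return total / len(can_run_list)  # the value logged to the new gauge
```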
@@ -1456,14 +1428,36 @@ class Scheduler(
             self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

     def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+        return self.prepare_dp_attn_batch_raw(
+            local_batch,
+            dp_size=self.server_args.dp_size,
+            attn_tp_size=self.attn_tp_size,
+            tp_cpu_group=self.tp_cpu_group,
+            get_idle_batch=self.get_idle_batch,
+            disable_cuda_graph=self.server_args.disable_cuda_graph,
+            spec_algorithm=self.spec_algorithm,
+            speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
+        )
+
+    @staticmethod
+    def prepare_dp_attn_batch_raw(
+        local_batch: ScheduleBatch,
+        dp_size,
+        attn_tp_size: int,
+        tp_cpu_group,
+        get_idle_batch,
+        disable_cuda_graph: bool,
+        spec_algorithm,
+        speculative_num_draft_tokens,
+    ):
         # Check if other DP workers have running batches
         if local_batch is None:
             num_tokens = 0
             global_num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
-            if not
-            num_tokens = num_tokens *
+            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
+                num_tokens = num_tokens * speculative_num_draft_tokens
             global_num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
@@ -1482,7 +1476,7 @@ class Scheduler(
         else:
             can_cuda_graph = 0

-        if not
+        if not spec_algorithm.is_none():
             # TODO(sang): Support cuda graph when idle batch is there.
             if local_batch is None or local_batch.forward_mode.is_idle():
                 can_cuda_graph = 0
@@ -1500,13 +1494,13 @@ class Scheduler(
             dtype=torch.int64,
         )
         global_info = torch.empty(
-            (
+            (dp_size, attn_tp_size, 4),
             dtype=torch.int64,
         )
         torch.distributed.all_gather_into_tensor(
             global_info.flatten(),
             local_info,
-            group=
+            group=tp_cpu_group,
         )
         global_num_tokens = global_info[:, 0, 0].tolist()
         can_cuda_graph = min(global_info[:, 0, 1].tolist())
@@ -1514,14 +1508,14 @@ class Scheduler(
         is_extend_in_batch = global_info[:, 0, 3].tolist()

         if local_batch is None and max(global_num_tokens) > 0:
-            local_batch =
+            local_batch = get_idle_batch()

         if local_batch is not None:
             local_batch.global_num_tokens = global_num_tokens
             local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob

             # Check forward mode for cuda graph
-            if not
+            if not disable_cuda_graph:
                 local_batch.can_run_dp_cuda_graph = can_cuda_graph

         return local_batch, any(is_extend_in_batch)
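Turning the body into a `@staticmethod` with explicit `dp_size`, `tp_cpu_group`, and related arguments makes the DP-attention gather logic callable without a live `Scheduler`; the instance method above becomes a thin wrapper that forwards its own state. A simplified illustration of the pattern (hypothetical fields, not the real signature):

```python
# Sketch of the wrapper-plus-staticmethod refactor; fields are hypothetical.
class Scheduler:
    def __init__(self, dp_size: int):
        self.dp_size = dp_size

    def prepare_dp_attn_batch(self, local_batch):
        # Thin wrapper: forwards instance state as explicit arguments.
        return Scheduler.prepare_dp_attn_batch_raw(local_batch, dp_size=self.dp_size)

    @staticmethod
    def prepare_dp_attn_batch_raw(local_batch, dp_size: int):
        # Pure function of its inputs; usable from other engines or unit tests.
        num_tokens = 0 if local_batch is None else len(local_batch)
        return num_tokens, dp_size


# Both entry points yield the same result:
assert Scheduler(dp_size=2).prepare_dp_attn_batch([1, 2, 3]) == \
    Scheduler.prepare_dp_attn_batch_raw([1, 2, 3], dp_size=2)
```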
@@ -1812,6 +1806,7 @@ class Scheduler(
                 recv_req.activities,
                 recv_req.with_stack,
                 recv_req.record_shapes,
+                recv_req.profile_id,
             )
         else:
             return self.stop_profile()
@@ -1823,6 +1818,7 @@ class Scheduler(
         activities: Optional[List[str]],
         with_stack: Optional[bool],
         record_shapes: Optional[bool],
+        profile_id: Optional[str],
     ) -> None:
         if self.profiler_activities:
             return ProfileReqOutput(
@@ -1837,9 +1833,11 @@ class Scheduler(

         self.torch_profiler_output_dir = output_dir
         self.profiler_activities = activities
+        self.profiler_id = profile_id
         logger.info(
-            "Profiling starts. Traces will be saved to: %s",
+            "Profiling starts. Traces will be saved to: %s (with id %s)",
             self.torch_profiler_output_dir,
+            self.profiler_id,
         )

         activity_map = {
@@ -1881,14 +1879,14 @@ class Scheduler(
             self.torch_profiler.export_chrome_trace(
                 os.path.join(
                     self.torch_profiler_output_dir,
-
+                    self.profiler_id + f"-TP-{self.tp_rank}" + ".trace.json.gz",
                 )
             )

         if "MEM" in self.profiler_activities:
             memory_profile_path = os.path.join(
                 self.torch_profiler_output_dir,
-
+                self.profiler_id + f"-TP-{self.tp_rank}-memory" + ".pickle",
             )
             torch.cuda.memory._dump_snapshot(memory_profile_path)
             torch.cuda.memory._record_memory_history(enabled=None)
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -48,8 +48,12 @@ from fastapi import BackgroundTasks

 from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.disaggregation.
-
+from sglang.srt.disaggregation.utils import (
+    DisaggregationMode,
+    KVClassType,
+    TransferBackend,
+    get_kv_class,
+)
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import (
     AbortReq,
@@ -163,6 +167,7 @@ class TokenizerManager:
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            enable_multimodal=server_args.enable_multimodal,
             dtype=server_args.dtype,
             quantization=server_args.quantization,
         )
@@ -179,6 +184,7 @@ class TokenizerManager:
             tokenizer_mode=server_args.tokenizer_mode,
             trust_remote_code=server_args.trust_remote_code,
             revision=server_args.revision,
+            use_fast=not server_args.disable_fast_image_processor,
         )

         # We want to parallelize the image pre-processing so we create an executor for it
@@ -327,10 +333,16 @@ class TokenizerManager:
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
+        self.transfer_backend = TransferBackend(
+            self.server_args.disaggregation_transfer_backend
+        )
         # for disaggregtion, start kv boostrap server on prefill
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             # only start bootstrap server on prefill tm
-
+            kv_bootstrap_server_class = get_kv_class(
+                self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
+            )
+            self.bootstrap_server = kv_bootstrap_server_class(
                 self.server_args.disaggregation_bootstrap_port
             )

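The bootstrap server is now obtained from `get_kv_class(transfer_backend, KVClassType.BOOTSTRAP_SERVER)` rather than a hard-coded class, so supporting another transfer backend only means registering its classes. A hedged sketch of what such a factory can look like; the enum members and class below are illustrative, not the actual contents of `sglang/srt/disaggregation/utils.py`:

```python
from enum import Enum


class TransferBackend(Enum):
    MOONCAKE = "mooncake"  # illustrative member names


class KVClassType(Enum):
    BOOTSTRAP_SERVER = "bootstrap_server"
    SENDER = "sender"
    RECEIVER = "receiver"


class MooncakeBootstrapServer:  # stand-in for the backend's real class
    def __init__(self, port: int):
        self.port = port


_REGISTRY = {
    (TransferBackend.MOONCAKE, KVClassType.BOOTSTRAP_SERVER): MooncakeBootstrapServer,
}


def get_kv_class(backend: TransferBackend, class_type: KVClassType):
    # Resolve the concrete class for (backend, role); the caller instantiates it.
    return _REGISTRY[(backend, class_type)]


# Usage mirroring the diff: resolve the class, then construct it with the bootstrap port.
bootstrap_server_cls = get_kv_class(TransferBackend.MOONCAKE, KVClassType.BOOTSTRAP_SERVER)
server = bootstrap_server_cls(8998)
```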
@@ -452,6 +464,8 @@ class TokenizerManager:
             top_logprobs_num,
             token_ids_logprob,
             obj.stream,
+            bootstrap_host=obj.bootstrap_host,
+            bootstrap_room=obj.bootstrap_room,
             lora_path=obj.lora_path,
             input_embeds=input_embeds,
             session_params=session_params,
@@ -636,6 +650,7 @@ class TokenizerManager:
             output_dir=output_dir,
             num_steps=num_steps,
             activities=activities,
+            profile_id=str(time.time()),
         )
         result = (await self.start_profile_communicator(req))[0]
         if not result.success:
sglang/srt/managers/tp_worker.py
CHANGED
@@ -68,6 +68,7 @@ class TpModelWorker:
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            enable_multimodal=server_args.enable_multimodal,
             dtype=server_args.dtype,
             quantization=server_args.quantization,
         )
sglang/srt/mem_cache/hiradix_cache.py
CHANGED
@@ -92,7 +92,7 @@ class HiRadixCache(RadixCache):
             self.ongoing_write_through[node.id] = node
             self.inc_lock_ref(node)
         else:
-            return
+            return 0

         return len(host_indices)

@@ -153,6 +153,7 @@ class HiRadixCache(RadixCache):
             if x.host_value is None:
                 if self.cache_controller.write_policy == "write_back":
                     num_evicted += self.write_backup(x)
+                    pending_nodes.append(x)
                 elif self.cache_controller.write_policy == "write_through_selective":
                     num_evicted += self._evict_write_through_selective(x)
                 else:
@@ -177,6 +178,9 @@ class HiRadixCache(RadixCache):
         while len(self.ongoing_write_through) > 0:
             self.writing_check()
             time.sleep(0.1)
+        for node in pending_nodes:
+            assert node.host_value is not None
+            self._evict_write_through(node)

     def _evict_write_through(self, node: TreeNode):
         # evict a node already written to host
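The `return` to `return 0` change in the first hunk matters because the caller shown in the second hunk accumulates the result (`num_evicted += self.write_backup(x)`): a bare `return` yields `None`, and adding `None` to an int raises `TypeError`. A minimal demonstration of the failure mode and the fix:

```python
def write_backup_broken(wrote: bool):
    if wrote:
        return 4   # number of host indices written
    return         # implicitly returns None


def write_backup_fixed(wrote: bool):
    if wrote:
        return 4
    return 0       # nothing written, but still an int


num_evicted = 0
try:
    num_evicted += write_backup_broken(False)
except TypeError as e:
    print("broken:", e)  # unsupported operand type(s) for +=: 'int' and 'NoneType'

num_evicted += write_backup_fixed(False)  # fine, num_evicted stays 0
```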
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -286,8 +286,12 @@ class MHATokenToKVPool(KVCache):
             self.get_key_buffer(i).nbytes for i in range(self.layer_num)
         ] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
         kv_item_lens = [
-            self.get_key_buffer(i)[0].nbytes
-
+            self.get_key_buffer(i)[0].nbytes * self.page_size
+            for i in range(self.layer_num)
+        ] + [
+            self.get_value_buffer(i)[0].nbytes * self.page_size
+            for i in range(self.layer_num)
+        ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens

         # Todo: different memory layout
@@ -414,6 +418,7 @@ class MLATokenToKVPool(KVCache):
         enable_memory_saver: bool,
     ):
         self.size = size
+        self.page_size = page_size
         self.dtype = dtype
         self.device = device
         if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
@@ -442,6 +447,14 @@ class MLATokenToKVPool(KVCache):

         self.layer_transfer_counter = None

+    # for disagg
+    def get_contiguous_buf_infos(self):
+        # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
+        kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
+        kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
+        kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
+        return kv_data_ptrs, kv_data_lens, kv_item_lens
+
     def get_key_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id)
@@ -866,7 +879,12 @@ class MLATokenToKVPoolHost(HostKVCache):
         self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
         self.layer_num = self.device_pool.layer_num

-        return (
+        return (
+            (self.kv_lora_rank + self.qk_rope_head_dim)
+            * 1
+            * self.dtype.itemsize
+            * self.layer_num
+        )

     def init_kv_buffer(self):
         return torch.empty(
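The reconstructed return value in the last hunk reads as `(kv_lora_rank + qk_rope_head_dim) * 1 * dtype.itemsize * layer_num`: MLA keeps a single compressed latent plus the rope dimensions per token per layer, with no per-head factor, and the MHA `kv_item_lens` in the first hunk are now multiplied by `page_size`, which suggests items are counted per page rather than per token for KV transfer. A worked example with assumed DeepSeek-V3-like dimensions (illustrative values only):

```python
import torch

# Assumed dimensions for illustration; the real values come from the device pool.
kv_lora_rank = 512
qk_rope_head_dim = 64
dtype = torch.bfloat16
layer_num = 61

size_per_token = (kv_lora_rank + qk_rope_head_dim) * 1 * dtype.itemsize * layer_num
print(size_per_token)  # (512 + 64) * 2 bytes * 61 layers = 70272 bytes per token
```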
sglang/srt/metrics/collector.py
CHANGED
@@ -27,6 +27,7 @@ class SchedulerStats:
     num_queue_reqs: int = 0
     cache_hit_rate: float = 0.0
     spec_accept_length: float = 0.0
+    avg_request_queue_latency: float = 0.0


 class SchedulerMetricsCollector:
@@ -87,6 +88,13 @@ class SchedulerMetricsCollector:
             multiprocess_mode="mostrecent",
         )

+        self.avg_request_queue_latency = Gauge(
+            name="sglang:avg_request_queue_latency",
+            documentation="The average request queue latency for the last batch of requests in seconds.",
+            labelnames=labels.keys(),
+            multiprocess_mode="mostrecent",
+        )
+
     def _log_gauge(self, gauge, data: Union[int, float]) -> None:
         # Convenience function for logging to gauge.
         gauge.labels(**self.labels).set(data)
@@ -99,6 +107,7 @@ class SchedulerMetricsCollector:
         self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
         self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
         self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
+        self._log_gauge(self.avg_request_queue_latency, stats.avg_request_queue_latency)
         self.last_log_time = time.time()

sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -34,13 +34,14 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
+from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.utils import get_available_gpu_memory, is_hip

-_is_hip = is_hip()
-
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner

+_is_hip = is_hip()
+

 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
@@ -108,6 +109,8 @@ def set_torch_compile_config():
     if hasattr(torch._dynamo.config, "cache_size_limit"):
         torch._dynamo.config.cache_size_limit = 1024

+    monkey_patch_torch_compile()
+

 def get_batch_sizes_to_capture(model_runner: ModelRunner):
     server_args = model_runner.server_args
@@ -116,7 +119,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     if capture_bs is None:
         if server_args.speculative_algorithm is None:
             if server_args.disable_cuda_graph_padding:
-                capture_bs = list(range(1, 33)) + range(40, 161, 16)
+                capture_bs = list(range(1, 33)) + list(range(40, 161, 16))
             else:
                 capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
         else:
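The `capture_bs` fix above corrects a plain Python type error in the branch taken when no speculative algorithm is set and `--disable-cuda-graph-padding` is passed: a `range` object cannot be concatenated to a `list` with `+`. For example:

```python
# The old expression fails immediately:
try:
    capture_bs = list(range(1, 33)) + range(40, 161, 16)
except TypeError as e:
    print(e)  # can only concatenate list (not "range") to list

# The fixed expression produces the intended batch-size list:
capture_bs = list(range(1, 33)) + list(range(40, 161, 16))
print(capture_bs[:3], capture_bs[-3:])  # [1, 2, 3] [120, 136, 152]
```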
@@ -269,10 +272,10 @@ class CudaGraphRunner:
             raise Exception(
                 f"Capture cuda graph failed: {e}\n"
                 "Possible solutions:\n"
-                "1.
-                "2. set --
+                "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
+                "2. set --cuda-graph-max-bs to a smaller value (e.g., 32)\n"
                 "3. disable torch compile by not using --enable-torch-compile\n"
-                "4.
+                "4. disable cuda graph by --disable-cuda-graph\n"
                 "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
             )
