sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +23 -2
  3. sglang/bench_serving.py +6 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/configs/model_config.py +37 -5
  12. sglang/srt/constrained/base_grammar_backend.py +26 -5
  13. sglang/srt/constrained/llguidance_backend.py +1 -0
  14. sglang/srt/constrained/outlines_backend.py +1 -0
  15. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  16. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  17. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  18. sglang/srt/constrained/xgrammar_backend.py +27 -4
  19. sglang/srt/custom_op.py +0 -62
  20. sglang/srt/disaggregation/base/__init__.py +8 -0
  21. sglang/srt/disaggregation/base/conn.py +113 -0
  22. sglang/srt/disaggregation/decode.py +80 -11
  23. sglang/srt/disaggregation/mini_lb.py +58 -123
  24. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  25. sglang/srt/disaggregation/mooncake/conn.py +585 -0
  26. sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
  27. sglang/srt/disaggregation/prefill.py +82 -22
  28. sglang/srt/disaggregation/utils.py +46 -0
  29. sglang/srt/entrypoints/EngineBase.py +53 -0
  30. sglang/srt/entrypoints/engine.py +36 -8
  31. sglang/srt/entrypoints/http_server.py +37 -8
  32. sglang/srt/entrypoints/http_server_engine.py +142 -0
  33. sglang/srt/entrypoints/verl_engine.py +42 -13
  34. sglang/srt/hf_transformers_utils.py +4 -0
  35. sglang/srt/layers/activation.py +6 -8
  36. sglang/srt/layers/attention/flashattention_backend.py +430 -257
  37. sglang/srt/layers/attention/flashinfer_backend.py +18 -9
  38. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  39. sglang/srt/layers/attention/triton_backend.py +6 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  41. sglang/srt/layers/attention/vision.py +1 -1
  42. sglang/srt/layers/dp_attention.py +2 -4
  43. sglang/srt/layers/elementwise.py +15 -2
  44. sglang/srt/layers/layernorm.py +1 -1
  45. sglang/srt/layers/linear.py +18 -3
  46. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  48. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
  56. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  57. sglang/srt/layers/moe/router.py +7 -1
  58. sglang/srt/layers/moe/topk.py +63 -45
  59. sglang/srt/layers/parameter.py +0 -2
  60. sglang/srt/layers/quantization/__init__.py +13 -5
  61. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  62. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
  64. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  65. sglang/srt/layers/quantization/fp8.py +131 -136
  66. sglang/srt/layers/quantization/fp8_kernel.py +328 -46
  67. sglang/srt/layers/quantization/fp8_utils.py +206 -253
  68. sglang/srt/layers/quantization/kv_cache.py +43 -52
  69. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  70. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  71. sglang/srt/layers/quantization/utils.py +5 -11
  72. sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
  73. sglang/srt/layers/quantization/w8a8_int8.py +8 -7
  74. sglang/srt/layers/radix_attention.py +28 -1
  75. sglang/srt/layers/rotary_embedding.py +15 -3
  76. sglang/srt/layers/sampler.py +5 -10
  77. sglang/srt/lora/backend/base_backend.py +18 -2
  78. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  79. sglang/srt/lora/backend/triton_backend.py +1 -1
  80. sglang/srt/lora/layers.py +1 -1
  81. sglang/srt/lora/lora.py +1 -1
  82. sglang/srt/lora/lora_manager.py +1 -1
  83. sglang/srt/managers/detokenizer_manager.py +0 -1
  84. sglang/srt/managers/io_struct.py +255 -97
  85. sglang/srt/managers/mm_utils.py +7 -5
  86. sglang/srt/managers/multimodal_processor.py +0 -2
  87. sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
  88. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  89. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  90. sglang/srt/managers/schedule_batch.py +64 -25
  91. sglang/srt/managers/scheduler.py +80 -82
  92. sglang/srt/managers/tokenizer_manager.py +18 -3
  93. sglang/srt/managers/tp_worker.py +1 -0
  94. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  95. sglang/srt/mem_cache/memory_pool.py +21 -3
  96. sglang/srt/metrics/collector.py +9 -0
  97. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  98. sglang/srt/model_executor/forward_batch_info.py +234 -15
  99. sglang/srt/model_executor/model_runner.py +67 -35
  100. sglang/srt/model_loader/loader.py +31 -4
  101. sglang/srt/model_loader/weight_utils.py +4 -2
  102. sglang/srt/models/baichuan.py +2 -0
  103. sglang/srt/models/bert.py +398 -0
  104. sglang/srt/models/chatglm.py +1 -0
  105. sglang/srt/models/commandr.py +1 -0
  106. sglang/srt/models/dbrx.py +1 -0
  107. sglang/srt/models/deepseek.py +2 -1
  108. sglang/srt/models/deepseek_nextn.py +74 -70
  109. sglang/srt/models/deepseek_v2.py +494 -366
  110. sglang/srt/models/exaone.py +1 -0
  111. sglang/srt/models/gemma.py +1 -0
  112. sglang/srt/models/gemma2.py +1 -0
  113. sglang/srt/models/gemma3_causal.py +1 -0
  114. sglang/srt/models/gpt2.py +1 -0
  115. sglang/srt/models/gpt_bigcode.py +1 -0
  116. sglang/srt/models/granite.py +1 -0
  117. sglang/srt/models/grok.py +1 -0
  118. sglang/srt/models/internlm2.py +1 -0
  119. sglang/srt/models/llama.py +6 -5
  120. sglang/srt/models/llama4.py +101 -34
  121. sglang/srt/models/minicpm.py +1 -0
  122. sglang/srt/models/minicpm3.py +30 -200
  123. sglang/srt/models/mixtral.py +1 -0
  124. sglang/srt/models/mixtral_quant.py +1 -0
  125. sglang/srt/models/mllama.py +51 -8
  126. sglang/srt/models/mllama4.py +102 -29
  127. sglang/srt/models/olmo.py +1 -0
  128. sglang/srt/models/olmo2.py +1 -0
  129. sglang/srt/models/olmoe.py +1 -0
  130. sglang/srt/models/phi3_small.py +1 -0
  131. sglang/srt/models/qwen.py +1 -0
  132. sglang/srt/models/qwen2.py +5 -1
  133. sglang/srt/models/qwen2_5_vl.py +35 -70
  134. sglang/srt/models/qwen2_moe.py +15 -13
  135. sglang/srt/models/qwen2_vl.py +27 -25
  136. sglang/srt/models/qwen3.py +335 -0
  137. sglang/srt/models/qwen3_moe.py +423 -0
  138. sglang/srt/models/stablelm.py +1 -0
  139. sglang/srt/models/xverse.py +1 -0
  140. sglang/srt/models/xverse_moe.py +1 -0
  141. sglang/srt/openai_api/adapter.py +4 -1
  142. sglang/srt/patch_torch.py +11 -0
  143. sglang/srt/reasoning_parser.py +0 -1
  144. sglang/srt/sampling/sampling_batch_info.py +2 -3
  145. sglang/srt/server_args.py +55 -19
  146. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  147. sglang/srt/speculative/eagle_utils.py +1 -11
  148. sglang/srt/speculative/eagle_worker.py +10 -9
  149. sglang/srt/utils.py +136 -10
  150. sglang/test/attention/test_flashattn_backend.py +259 -221
  151. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  152. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  153. sglang/test/runners.py +5 -1
  154. sglang/test/test_block_fp8.py +224 -0
  155. sglang/test/test_custom_ops.py +1 -1
  156. sglang/test/test_utils.py +19 -8
  157. sglang/version.py +1 -1
  158. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
  159. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
  160. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  161. sglang/lang/__init__.py +0 -0
  162. sglang/srt/disaggregation/conn.py +0 -81
  163. sglang/srt/lora/backend/__init__.py +0 -25
  164. sglang/srt/server.py +0 -18
  165. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  166. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py +80 -82
@@ -49,6 +49,7 @@ from sglang.srt.disaggregation.prefill import (
  from sglang.srt.disaggregation.utils import (
  DisaggregationMode,
  ReqToMetadataIdxAllocator,
+ TransferBackend,
  )
  from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
  from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
@@ -113,6 +114,7 @@ from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
  from sglang.srt.model_executor.forward_batch_info import ForwardMode
+ from sglang.srt.reasoning_parser import ReasoningParser
  from sglang.srt.server_args import PortArgs, ServerArgs
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
  from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -232,6 +234,15 @@ class Scheduler(
  # Init tokenizer
  self.init_tokenizer()

+ # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
+ if self.server_args.reasoning_parser and self.tokenizer:
+ reasoning_parser = ReasoningParser(
+ model_type=self.server_args.reasoning_parser, stream_reasoning=False
+ )
+ self.tokenizer.think_end_id = self.tokenizer.encode(
+ reasoning_parser.detector.think_end_token, add_special_tokens=False
+ )[0]
+
  # Check whether overlap can be enabled
  if not self.is_generation:
  self.enable_overlap = False
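
The scheduler.py hunk above resolves the reasoning parser's end-of-thinking marker to a single token id once at startup, so later code can detect where reasoning output ends without string parsing. A minimal sketch of the same idea outside sglang, assuming the transformers library; the model repo and the "</think>" marker below are illustrative assumptions, not taken from this diff:

    # Hedged sketch: resolve a think-end marker to one token id, mirroring the
    # Scheduler change above. The model repo and marker string are assumptions.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
    ids = tokenizer.encode("</think>", add_special_tokens=False)
    think_end_id = ids[0]  # the diff stores this as tokenizer.think_end_id
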
@@ -380,6 +391,7 @@ class Scheduler(
  self.torch_profiler = None
  self.torch_profiler_output_dir: Optional[str] = None
  self.profiler_activities: Optional[List[str]] = None
+ self.profiler_id: Optional[str] = None
  self.profiler_target_forward_ct: Optional[int] = None

  # Init metrics stats
@@ -427,6 +439,7 @@ class Scheduler(
  context_length=server_args.context_length,
  model_override_args=server_args.json_model_override_args,
  is_embedding=server_args.is_embedding,
+ enable_multimodal=server_args.enable_multimodal,
  dtype=server_args.dtype,
  quantization=server_args.quantization,
  )
@@ -441,6 +454,7 @@ class Scheduler(
  tokenizer_mode=server_args.tokenizer_mode,
  trust_remote_code=server_args.trust_remote_code,
  revision=server_args.revision,
+ use_fast=not server_args.disable_fast_image_processor,
  )
  self.tokenizer = self.processor.tokenizer
  else:
@@ -471,7 +485,7 @@ class Scheduler(
  self.tree_cache = HiRadixCache(
  req_to_token_pool=self.req_to_token_pool,
  token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
- tp_cache_group=self.tp_worker.get_tp_cpu_group(),
+ tp_cache_group=self.tp_cpu_group,
  page_size=self.page_size,
  hicache_ratio=server_args.hicache_ratio,
  )
@@ -518,6 +532,10 @@ class Scheduler(
  )

  def init_disaggregation(self):
+ self.transfer_backend = TransferBackend(
+ self.server_args.disaggregation_transfer_backend
+ )
+
  if (
  self.disaggregation_mode == DisaggregationMode.DECODE
  ): # *2 for the headroom.
@@ -536,7 +554,7 @@ class Scheduler(

  # The decode requests polling kv cache
  self.disagg_decode_transfer_queue = DecodeTransferQueue(
- gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+ gloo_group=self.attn_tp_cpu_group,
  req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
  metadata_buffers=metadata_buffers,
  )
@@ -551,10 +569,11 @@ class Scheduler(
  scheduler=self,
  transfer_queue=self.disagg_decode_transfer_queue,
  tree_cache=self.tree_cache,
- gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+ gloo_group=self.attn_tp_cpu_group,
  tp_rank=self.tp_rank,
  tp_size=self.tp_size,
  bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+ transfer_backend=self.transfer_backend,
  )
  elif self.disaggregation_mode == DisaggregationMode.PREFILL:
  # *2 for the headroom.
@@ -579,10 +598,12 @@ class Scheduler(
  tp_rank=self.tp_rank,
  tp_size=self.tp_size,
  bootstrap_port=self.server_args.disaggregation_bootstrap_port,
- gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+ gloo_group=self.attn_tp_cpu_group,
+ transfer_backend=self.transfer_backend,
+ scheduler=self,
  )
  # The prefill requests that are in the middle of kv sending
- self.disagg_prefill_infight_queue: List[Req] = []
+ self.disagg_prefill_inflight_queue: List[Req] = []

  @DynamicGradMode()
  def event_loop_normal(self):
@@ -644,70 +665,6 @@ class Scheduler(

  self.last_batch = batch

- @torch.no_grad()
- def event_loop_normal_disagg_prefill(self):
- """A normal scheduler loop for prefill worker in disaggregation mode."""
-
- while True:
- recv_reqs = self.recv_requests()
- self.process_input_requests(recv_reqs)
- self.waiting_queue.extend(
- self.disagg_prefill_pending_queue.pop_bootstrapped()
- )
- self.process_prefill_chunk()
- batch = self.get_new_batch_prefill()
- self.cur_batch = batch
-
- if batch:
- result = self.run_batch(batch)
- self.process_batch_result_disagg_prefill(batch, result)
-
- if len(self.disagg_prefill_infight_queue) > 0:
- self.process_disagg_prefill_infight_queue()
-
- if batch is None and len(self.disagg_prefill_infight_queue) == 0:
- self.check_memory()
- self.new_token_ratio = self.init_new_token_ratio
-
- self.last_batch = batch
- # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
- # Otherwise, it hangs under high concurrency
- self.running_batch.batch_is_full = False
-
- @torch.no_grad()
- def event_loop_normal_disagg_decode(self):
- """A normal scheduler loop for decode worker in disaggregation mode."""
-
- while True:
- recv_reqs = self.recv_requests()
- self.process_input_requests(recv_reqs)
- # polling and allocating kv cache
- self.process_decode_queue()
- batch = self.get_next_disagg_decode_batch_to_run()
- self.cur_batch = batch
-
- if batch:
- # Generate fake extend output.
- if batch.forward_mode.is_extend():
- # Note: Logprobs should be handled on the prefill engine.
- self.stream_output(
- batch.reqs, [False for _ in range(len(batch.reqs))]
- )
- else:
- result = self.run_batch(batch)
- self.process_batch_result(batch, result)
-
- if batch is None and (
- len(self.disagg_decode_transfer_queue.queue)
- + len(self.disagg_decode_prealloc_queue.queue)
- == 0
- ):
- # When the server is idle, do self-check and re-init some states
- self.check_memory()
- self.new_token_ratio = self.init_new_token_ratio
-
- self.last_batch = batch
-
  def recv_requests(self) -> List[Req]:
  """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
  if self.attn_tp_rank == 0:
@@ -826,6 +783,8 @@ class Scheduler(
  custom_logit_processor=custom_logit_processor,
  return_hidden_states=recv_req.return_hidden_states,
  eos_token_ids=self.model_config.hf_eos_token_id,
+ bootstrap_host=recv_req.bootstrap_host,
+ bootstrap_room=recv_req.bootstrap_room,
  )
  req.tokenizer = self.tokenizer

@@ -937,12 +896,11 @@ class Scheduler(
  self._add_request_to_queue(req)

  def _add_request_to_queue(self, req: Req):
+ req.queue_time_start = time.time()
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
  self.disagg_prefill_pending_queue.add(req)
-
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
  self.disagg_decode_prealloc_queue.add(req)
-
  else:
  self.waiting_queue.append(req)

@@ -985,6 +943,7 @@ class Scheduler(
  req.finished_reason = FINISH_ABORT(
  error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
  )
+ req.queue_time_start = time.time()
  self.waiting_queue.append(req)
  return

@@ -1021,9 +980,10 @@ class Scheduler(
  self._largest_prefill_len, adder.log_input_tokens
  )

+ num_new_seq = len(can_run_list)
  f = (
  f"Prefill batch. "
- f"#new-seq: {len(can_run_list)}, "
+ f"#new-seq: {num_new_seq}, "
  f"#new-token: {adder.log_input_tokens}, "
  f"#cached-token: {adder.log_hit_tokens}, "
  f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
@@ -1041,6 +1001,12 @@ class Scheduler(
  self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
  self.stats.num_queue_reqs = len(self.waiting_queue)
  self.stats.cache_hit_rate = cache_hit_rate
+
+ total_queue_latency = 0
+ for req in can_run_list:
+ total_queue_latency += req.queue_time_end - req.queue_time_start
+ self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
+
  self.metrics_collector.log_stats(self.stats)

  def log_decode_stats(self):
@@ -1277,6 +1243,12 @@ class Scheduler(
  can_run_list: List[Req] = adder.can_run_list
  if len(can_run_list) == 0:
  return None
+
+ if self.enable_metrics:
+ # only record queue time when enable_metrics is True to avoid overhead
+ for req in can_run_list:
+ req.queue_time_end = time.time()
+
  self.waiting_queue = [
  x for x in self.waiting_queue if x not in set(can_run_list)
  ]
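
Taken together, the queue-timing hunks stamp queue_time_start when a request enters a queue, stamp queue_time_end when it is pulled into a prefill batch (only when metrics are enabled), and report the batch average as avg_request_queue_latency. A minimal self-contained sketch of that bookkeeping; the Req stand-in below is illustrative, not sglang's Req class:

    import time

    class Req:  # illustrative stand-in
        def __init__(self):
            self.queue_time_start = time.time()  # set when the request is queued
            self.queue_time_end = None           # set when it is scheduled into a batch

    def schedule(can_run_list):
        for req in can_run_list:
            req.queue_time_end = time.time()
        total = sum(r.queue_time_end - r.queue_time_start for r in can_run_list)
        return total / len(can_run_list) if can_run_list else 0.0  # avg_request_queue_latency
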
@@ -1456,14 +1428,36 @@ class Scheduler(
  self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

  def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+ return self.prepare_dp_attn_batch_raw(
+ local_batch,
+ dp_size=self.server_args.dp_size,
+ attn_tp_size=self.attn_tp_size,
+ tp_cpu_group=self.tp_cpu_group,
+ get_idle_batch=self.get_idle_batch,
+ disable_cuda_graph=self.server_args.disable_cuda_graph,
+ spec_algorithm=self.spec_algorithm,
+ speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
+ )
+
+ @staticmethod
+ def prepare_dp_attn_batch_raw(
+ local_batch: ScheduleBatch,
+ dp_size,
+ attn_tp_size: int,
+ tp_cpu_group,
+ get_idle_batch,
+ disable_cuda_graph: bool,
+ spec_algorithm,
+ speculative_num_draft_tokens,
+ ):
  # Check if other DP workers have running batches
  if local_batch is None:
  num_tokens = 0
  global_num_tokens_for_logprob = 0
  elif local_batch.forward_mode.is_decode():
  num_tokens = local_batch.batch_size()
- if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
- num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
+ if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
+ num_tokens = num_tokens * speculative_num_draft_tokens
  global_num_tokens_for_logprob = num_tokens
  else:
  num_tokens = local_batch.extend_num_tokens
@@ -1482,7 +1476,7 @@ class Scheduler(
  else:
  can_cuda_graph = 0

- if not self.spec_algorithm.is_none():
+ if not spec_algorithm.is_none():
  # TODO(sang): Support cuda graph when idle batch is there.
  if local_batch is None or local_batch.forward_mode.is_idle():
  can_cuda_graph = 0
@@ -1500,13 +1494,13 @@ class Scheduler(
  dtype=torch.int64,
  )
  global_info = torch.empty(
- (self.server_args.dp_size, self.attn_tp_size, 4),
+ (dp_size, attn_tp_size, 4),
  dtype=torch.int64,
  )
  torch.distributed.all_gather_into_tensor(
  global_info.flatten(),
  local_info,
- group=self.tp_cpu_group,
+ group=tp_cpu_group,
  )
  global_num_tokens = global_info[:, 0, 0].tolist()
  can_cuda_graph = min(global_info[:, 0, 1].tolist())
@@ -1514,14 +1508,14 @@ class Scheduler(
  is_extend_in_batch = global_info[:, 0, 3].tolist()

  if local_batch is None and max(global_num_tokens) > 0:
- local_batch = self.get_idle_batch()
+ local_batch = get_idle_batch()

  if local_batch is not None:
  local_batch.global_num_tokens = global_num_tokens
  local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob

  # Check forward mode for cuda graph
- if not self.server_args.disable_cuda_graph:
+ if not disable_cuda_graph:
  local_batch.can_run_dp_cuda_graph = can_cuda_graph

  return local_batch, any(is_extend_in_batch)
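
The prepare_dp_attn_batch refactor above follows a common pattern: move the body into a @staticmethod that receives every dependency explicitly, and keep a thin instance wrapper that forwards the Scheduler's own state. That makes the logic callable and unit-testable without constructing a full Scheduler. A generic sketch of the pattern, with illustrative names:

    class Worker:
        def __init__(self, dp_size: int):
            self.dp_size = dp_size

        def plan(self, batch):
            # thin wrapper: forwards instance state to the pure helper
            return Worker.plan_raw(batch, dp_size=self.dp_size)

        @staticmethod
        def plan_raw(batch, dp_size: int):
            # pure function of its inputs; no hidden reads from self
            return [batch] * dp_size

    assert Worker(2).plan("b") == ["b", "b"]
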
@@ -1812,6 +1806,7 @@ class Scheduler(
  recv_req.activities,
  recv_req.with_stack,
  recv_req.record_shapes,
+ recv_req.profile_id,
  )
  else:
  return self.stop_profile()
@@ -1823,6 +1818,7 @@ class Scheduler(
  activities: Optional[List[str]],
  with_stack: Optional[bool],
  record_shapes: Optional[bool],
+ profile_id: Optional[str],
  ) -> None:
  if self.profiler_activities:
  return ProfileReqOutput(
@@ -1837,9 +1833,11 @@

  self.torch_profiler_output_dir = output_dir
  self.profiler_activities = activities
+ self.profiler_id = profile_id
  logger.info(
- "Profiling starts. Traces will be saved to: %s",
+ "Profiling starts. Traces will be saved to: %s (with id %s)",
  self.torch_profiler_output_dir,
+ self.profiler_id,
  )

  activity_map = {
@@ -1881,14 +1879,14 @@ class Scheduler(
  self.torch_profiler.export_chrome_trace(
  os.path.join(
  self.torch_profiler_output_dir,
- str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
+ self.profiler_id + f"-TP-{self.tp_rank}" + ".trace.json.gz",
  )
  )

  if "MEM" in self.profiler_activities:
  memory_profile_path = os.path.join(
  self.torch_profiler_output_dir,
- str(time.time()) + f"-TP-{self.tp_rank}-memory" + ".pickle",
+ self.profiler_id + f"-TP-{self.tp_rank}-memory" + ".pickle",
  )
  torch.cuda.memory._dump_snapshot(memory_profile_path)
  torch.cuda.memory._record_memory_history(enabled=None)
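
With these profiler hunks, trace and memory-snapshot filenames are keyed by profile_id (a timestamp string generated in the tokenizer manager, which passes profile_id=str(time.time()) in a later hunk) instead of a fresh time.time() call per rank, so all TP ranks of one profiling run share a common prefix. An illustrative reconstruction of the resulting path:

    import os

    profile_id, tp_rank, output_dir = "1713800000.123", 0, "/tmp/profiles"  # example values
    trace_path = os.path.join(output_dir, profile_id + f"-TP-{tp_rank}" + ".trace.json.gz")
    # -> /tmp/profiles/1713800000.123-TP-0.trace.json.gz
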
sglang/srt/managers/tokenizer_manager.py +18 -3
@@ -48,8 +48,12 @@ from fastapi import BackgroundTasks

  from sglang.srt.aio_rwlock import RWLock
  from sglang.srt.configs.model_config import ModelConfig
- from sglang.srt.disaggregation.conn import KVBootstrapServer
- from sglang.srt.disaggregation.utils import DisaggregationMode
+ from sglang.srt.disaggregation.utils import (
+ DisaggregationMode,
+ KVClassType,
+ TransferBackend,
+ get_kv_class,
+ )
  from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
  from sglang.srt.managers.io_struct import (
  AbortReq,
@@ -163,6 +167,7 @@ class TokenizerManager:
  context_length=server_args.context_length,
  model_override_args=server_args.json_model_override_args,
  is_embedding=server_args.is_embedding,
+ enable_multimodal=server_args.enable_multimodal,
  dtype=server_args.dtype,
  quantization=server_args.quantization,
  )
@@ -179,6 +184,7 @@ class TokenizerManager:
  tokenizer_mode=server_args.tokenizer_mode,
  trust_remote_code=server_args.trust_remote_code,
  revision=server_args.revision,
+ use_fast=not server_args.disable_fast_image_processor,
  )

  # We want to parallelize the image pre-processing so we create an executor for it
@@ -327,10 +333,16 @@ class TokenizerManager:
  self.disaggregation_mode = DisaggregationMode(
  self.server_args.disaggregation_mode
  )
+ self.transfer_backend = TransferBackend(
+ self.server_args.disaggregation_transfer_backend
+ )
  # for disaggregtion, start kv boostrap server on prefill
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
  # only start bootstrap server on prefill tm
- self.bootstrap_server = KVBootstrapServer(
+ kv_bootstrap_server_class = get_kv_class(
+ self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
+ )
+ self.bootstrap_server = kv_bootstrap_server_class(
  self.server_args.disaggregation_bootstrap_port
  )

@@ -452,6 +464,8 @@ class TokenizerManager:
  top_logprobs_num,
  token_ids_logprob,
  obj.stream,
+ bootstrap_host=obj.bootstrap_host,
+ bootstrap_room=obj.bootstrap_room,
  lora_path=obj.lora_path,
  input_embeds=input_embeds,
  session_params=session_params,
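
The earlier tokenizer_manager.py hunk replaces the hard-coded KVBootstrapServer with a lookup keyed by the configured transfer backend (TransferBackend) and the role being requested (KVClassType.BOOTSTRAP_SERVER). A hedged sketch of what such a registry-style get_kv_class lookup can look like; the enum members and the server class below are illustrative, not sglang's actual definitions:

    from enum import Enum

    class TransferBackend(Enum):
        MOONCAKE = "mooncake"
        FAKE = "fake"

    class KVClassType(Enum):
        BOOTSTRAP_SERVER = "bootstrap_server"

    class MooncakeKVBootstrapServer:  # illustrative stand-in
        def __init__(self, port: int):
            self.port = port

    _KV_CLASSES = {
        (TransferBackend.MOONCAKE, KVClassType.BOOTSTRAP_SERVER): MooncakeKVBootstrapServer,
    }

    def get_kv_class(backend: TransferBackend, class_type: KVClassType):
        # map (backend, role) to a concrete class; the real table lives in
        # sglang/srt/disaggregation/utils.py
        return _KV_CLASSES[(backend, class_type)]

    server_cls = get_kv_class(TransferBackend.MOONCAKE, KVClassType.BOOTSTRAP_SERVER)
    bootstrap_server = server_cls(8998)  # port value is illustrative
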
@@ -636,6 +650,7 @@ class TokenizerManager:
  output_dir=output_dir,
  num_steps=num_steps,
  activities=activities,
+ profile_id=str(time.time()),
  )
  result = (await self.start_profile_communicator(req))[0]
  if not result.success:
sglang/srt/managers/tp_worker.py +1 -0
@@ -68,6 +68,7 @@ class TpModelWorker:
  context_length=server_args.context_length,
  model_override_args=server_args.json_model_override_args,
  is_embedding=server_args.is_embedding,
+ enable_multimodal=server_args.enable_multimodal,
  dtype=server_args.dtype,
  quantization=server_args.quantization,
  )
sglang/srt/mem_cache/hiradix_cache.py +5 -1
@@ -92,7 +92,7 @@ class HiRadixCache(RadixCache):
  self.ongoing_write_through[node.id] = node
  self.inc_lock_ref(node)
  else:
- return None
+ return 0

  return len(host_indices)

@@ -153,6 +153,7 @@ class HiRadixCache(RadixCache):
  if x.host_value is None:
  if self.cache_controller.write_policy == "write_back":
  num_evicted += self.write_backup(x)
+ pending_nodes.append(x)
  elif self.cache_controller.write_policy == "write_through_selective":
  num_evicted += self._evict_write_through_selective(x)
  else:
@@ -177,6 +178,9 @@ class HiRadixCache(RadixCache):
  while len(self.ongoing_write_through) > 0:
  self.writing_check()
  time.sleep(0.1)
+ for node in pending_nodes:
+ assert node.host_value is not None
+ self._evict_write_through(node)

  def _evict_write_through(self, node: TreeNode):
  # evict a node already written to host
sglang/srt/mem_cache/memory_pool.py +21 -3
@@ -286,8 +286,12 @@ class MHATokenToKVPool(KVCache):
  self.get_key_buffer(i).nbytes for i in range(self.layer_num)
  ] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
  kv_item_lens = [
- self.get_key_buffer(i)[0].nbytes for i in range(self.layer_num)
- ] + [self.get_value_buffer(i)[0].nbytes for i in range(self.layer_num)]
+ self.get_key_buffer(i)[0].nbytes * self.page_size
+ for i in range(self.layer_num)
+ ] + [
+ self.get_value_buffer(i)[0].nbytes * self.page_size
+ for i in range(self.layer_num)
+ ]
  return kv_data_ptrs, kv_data_lens, kv_item_lens

  # Todo: different memory layout
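
The hunk above changes the per-item length reported for KV transfers from one token's worth of bytes to one page's worth (per-token bytes times page_size), matching paged allocation. A small sketch of the arithmetic, with illustrative tensor shapes:

    import torch

    page_size, num_tokens, head_num, head_dim = 16, 4096, 8, 128  # illustrative sizes
    k_buffer = torch.empty(num_tokens, head_num, head_dim, dtype=torch.float16)
    per_token_bytes = k_buffer[0].nbytes            # bytes of one token's K entries
    per_page_bytes = per_token_bytes * page_size    # the new kv_item_lens unit
    assert per_page_bytes == head_num * head_dim * 2 * page_size
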
@@ -414,6 +418,7 @@ class MLATokenToKVPool(KVCache):
  enable_memory_saver: bool,
  ):
  self.size = size
+ self.page_size = page_size
  self.dtype = dtype
  self.device = device
  if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
@@ -442,6 +447,14 @@

  self.layer_transfer_counter = None

+ # for disagg
+ def get_contiguous_buf_infos(self):
+ # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
+ kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
+ kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
+ kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
+ return kv_data_ptrs, kv_data_lens, kv_item_lens
+
  def get_key_buffer(self, layer_id: int):
  if self.layer_transfer_counter is not None:
  self.layer_transfer_counter.wait_until(layer_id)
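
The get_contiguous_buf_infos added above exposes, per layer, the buffer's start address, its total byte length, and the per-token item length for the single combined MLA KV buffer. A runnable sketch of the same bookkeeping with illustrative sizes:

    import torch

    layer_num, num_tokens, kv_dim = 4, 1024, 576  # kv_dim ~ kv_lora_rank + qk_rope_head_dim
    kv_buffer = [torch.empty(num_tokens, 1, kv_dim, dtype=torch.float16) for _ in range(layer_num)]
    kv_data_ptrs = [buf.data_ptr() for buf in kv_buffer]  # start address of each layer's buffer
    kv_data_lens = [buf.nbytes for buf in kv_buffer]      # total bytes per layer
    kv_item_lens = [buf[0].nbytes for buf in kv_buffer]   # bytes per token entry (1152 here)
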
@@ -866,7 +879,12 @@ class MLATokenToKVPoolHost(HostKVCache):
  self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
  self.layer_num = self.device_pool.layer_num

- return (self.kv_lora_rank + self.qk_rope_head_dim) * 1 * self.dtype.itemsize
+ return (
+ (self.kv_lora_rank + self.qk_rope_head_dim)
+ * 1
+ * self.dtype.itemsize
+ * self.layer_num
+ )

  def init_kv_buffer(self):
  return torch.empty(
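
The fix above multiplies the per-token size of the host-side MLA cache by layer_num, since the host buffer stores every layer's entry for a token rather than a single layer's. A back-of-the-envelope check with illustrative numbers (not taken from a real config):

    kv_lora_rank, qk_rope_head_dim, layer_num, itemsize = 512, 64, 61, 2  # fp16 itemsize
    bytes_per_token = (kv_lora_rank + qk_rope_head_dim) * 1 * itemsize * layer_num
    assert bytes_per_token == 70272  # vs. 1152 bytes when layer_num was omitted
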
sglang/srt/metrics/collector.py +9 -0
@@ -27,6 +27,7 @@ class SchedulerStats:
  num_queue_reqs: int = 0
  cache_hit_rate: float = 0.0
  spec_accept_length: float = 0.0
+ avg_request_queue_latency: float = 0.0


  class SchedulerMetricsCollector:
@@ -87,6 +88,13 @@ class SchedulerMetricsCollector:
  multiprocess_mode="mostrecent",
  )

+ self.avg_request_queue_latency = Gauge(
+ name="sglang:avg_request_queue_latency",
+ documentation="The average request queue latency for the last batch of requests in seconds.",
+ labelnames=labels.keys(),
+ multiprocess_mode="mostrecent",
+ )
+
  def _log_gauge(self, gauge, data: Union[int, float]) -> None:
  # Convenience function for logging to gauge.
  gauge.labels(**self.labels).set(data)
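
The new gauge follows the same prometheus_client pattern as the existing scheduler metrics: one labeled Gauge, updated on every stats flush via _log_gauge. A minimal standalone sketch (label names and values are illustrative; the real collector also passes multiprocess_mode="mostrecent"):

    from prometheus_client import Gauge

    avg_request_queue_latency = Gauge(
        name="sglang:avg_request_queue_latency",
        documentation="The average request queue latency for the last batch of requests in seconds.",
        labelnames=["model_name"],
    )
    avg_request_queue_latency.labels(model_name="demo-model").set(0.042)
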
@@ -99,6 +107,7 @@ class SchedulerMetricsCollector:
  self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
  self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
  self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
+ self._log_gauge(self.avg_request_queue_latency, stats.avg_request_queue_latency)
  self.last_log_time = time.time()

sglang/srt/model_executor/cuda_graph_runner.py +9 -6
@@ -34,13 +34,14 @@ from sglang.srt.model_executor.forward_batch_info import (
  ForwardBatch,
  ForwardMode,
  )
+ from sglang.srt.patch_torch import monkey_patch_torch_compile
  from sglang.srt.utils import get_available_gpu_memory, is_hip

- _is_hip = is_hip()
-
  if TYPE_CHECKING:
  from sglang.srt.model_executor.model_runner import ModelRunner

+ _is_hip = is_hip()
+

  def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
  for sub in model._modules.values():
@@ -108,6 +109,8 @@ def set_torch_compile_config():
  if hasattr(torch._dynamo.config, "cache_size_limit"):
  torch._dynamo.config.cache_size_limit = 1024

+ monkey_patch_torch_compile()
+

  def get_batch_sizes_to_capture(model_runner: ModelRunner):
  server_args = model_runner.server_args
@@ -116,7 +119,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
  if capture_bs is None:
  if server_args.speculative_algorithm is None:
  if server_args.disable_cuda_graph_padding:
- capture_bs = list(range(1, 33)) + range(40, 161, 16)
+ capture_bs = list(range(1, 33)) + list(range(40, 161, 16))
  else:
  capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
  else:
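
The one-line fix above matters because in Python 3 range() is a lazy sequence, not a list, so concatenating it to a list raises TypeError; wrapping it in list() restores the intended batch-size schedule when disable_cuda_graph_padding is set:

    # The old code raised: TypeError: can only concatenate list (not "range") to list
    capture_bs = list(range(1, 33)) + list(range(40, 161, 16))
    assert capture_bs[:3] == [1, 2, 3] and capture_bs[-1] == 152
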
@@ -269,10 +272,10 @@ class CudaGraphRunner:
  raise Exception(
  f"Capture cuda graph failed: {e}\n"
  "Possible solutions:\n"
- "1. disable cuda graph by --disable-cuda-graph\n"
- "2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
+ "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
+ "2. set --cuda-graph-max-bs to a smaller value (e.g., 32)\n"
  "3. disable torch compile by not using --enable-torch-compile\n"
- "4. set --cuda-graph-max-bs to a smaller value (e.g., 32)\n"
+ "4. disable cuda graph by --disable-cuda-graph\n"
  "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
  )