sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
@@ -32,16 +32,33 @@ import psutil
 import setproctitle
 import torch
 import zmq
+from torch.distributed import barrier
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
+from sglang.srt.disaggregation.decode import (
+    DecodePreallocQueue,
+    DecodeTransferQueue,
+    SchedulerDisaggregationDecodeMixin,
+)
+from sglang.srt.disaggregation.prefill import (
+    PrefillBootstrapQueue,
+    SchedulerDisaggregationPrefillMixin,
+)
+from sglang.srt.disaggregation.utils import (
+    DisaggregationMode,
+    ReqToMetadataIdxAllocator,
+)
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
+    ExpertDistributionReq,
+    ExpertDistributionReqOutput,
     FlushCacheReq,
     GetInternalStateReq,
     GetInternalStateReqOutput,
@@ -59,6 +76,8 @@ from sglang.srt.managers.io_struct import (
     ReleaseMemoryOccupationReqOutput,
     ResumeMemoryOccupationReqInput,
     ResumeMemoryOccupationReqOutput,
+    RpcReqInput,
+    RpcReqOutput,
     SetInternalStateReq,
     SetInternalStateReqOutput,
     TokenizedEmbeddingReqInput,
@@ -72,7 +91,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
-    ImageInputs,
+    MultimodalInputs,
     Req,
     ScheduleBatch,
     global_server_args_dict,
@@ -98,6 +117,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
+    DynamicGradMode,
     broadcast_pyobj,
     configure_logger,
     crash_on_warnings,
@@ -111,6 +131,8 @@ from sglang.srt.utils import (
 )
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback
 
+expert_distribution_recorder = ExpertDistributionRecorder()
+
 logger = logging.getLogger(__name__)
 
 # Test retract decode for debugging purposes
@@ -133,7 +155,11 @@ class EmbeddingBatchResult:
     bid: int
 
 
-class Scheduler(SchedulerOutputProcessorMixin):
+class Scheduler(
+    SchedulerOutputProcessorMixin,
+    SchedulerDisaggregationDecodeMixin,
+    SchedulerDisaggregationPrefillMixin,
+):
     """A scheduler that manages a tensor parallel GPU worker."""
 
     def __init__(
@@ -193,8 +219,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
             self.send_to_detokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.detokenizer_ipc_name, False
             )
+
+            self.recv_from_rpc = get_zmq_socket(
+                context, zmq.DEALER, port_args.rpc_ipc_name, False
+            )
         else:
             self.recv_from_tokenizer = None
+            self.recv_from_rpc = None
             self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
 
@@ -376,9 +407,16 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 (ProfileReq, self.profile),
                 (GetInternalStateReq, self.get_internal_state),
                 (SetInternalStateReq, self.set_internal_state),
+                (RpcReqInput, self.handle_rpc_request),
+                (ExpertDistributionReq, self.expert_distribution_handle),
             ]
         )
 
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.init_disaggregation()
+
     def init_tokenizer(self):
         server_args = self.server_args
 
@@ -434,6 +472,8 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 tp_cache_group=self.tp_worker.get_tp_cpu_group(),
+                page_size=self.page_size,
+                hicache_ratio=server_args.hicache_ratio,
             )
         else:
             self.tree_cache = RadixCache(
@@ -477,7 +517,74 @@ class Scheduler(SchedulerOutputProcessorMixin):
             },
         )
 
-    @torch.no_grad()
+    def init_disaggregation(self):
+        if (
+            self.disaggregation_mode == DisaggregationMode.DECODE
+        ):  # *2 for the headroom.
+            buffer_size = (self.req_to_token_pool.size) * 2
+            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
+                buffer_size
+            )
+            aux_dtype = torch.int32
+            # A list of metadata buffers. The shape is (b, metadata_size) where
+            # b corresponds to a max running requests. The last shape * dtype.itemsize
+            # should be larger than 64 bytes to work with RDMA, so we pad it.
+            output_id_buffer = torch.zeros(
+                (buffer_size, 16), dtype=aux_dtype, device="cpu"
+            )
+            metadata_buffers = [output_id_buffer]
+
+            # The decode requests polling kv cache
+            self.disagg_decode_transfer_queue = DecodeTransferQueue(
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+            )
+
+            # The decode requests pending for pre-allocation
+            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+                aux_dtype=aux_dtype,
+                scheduler=self,
+                transfer_queue=self.disagg_decode_transfer_queue,
+                tree_cache=self.tree_cache,
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
+                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+            )
+        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # *2 for the headroom.
+            buffer_size = self.max_running_requests * 2
+            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
+                buffer_size
+            )
+            aux_dtype = torch.int32
+            # A list of metadata buffers. The shape is (b, metadata_size) where
+            # b corresponds to a max running requests. The last shape * dtype.itemsize
+            # should be larger than 64 bytes to work with RDMA, so we pad it.
+            output_id_buffer = torch.zeros(
+                (buffer_size, 16), dtype=aux_dtype, device="cpu"
+            )
+            metadata_buffers = [output_id_buffer]
+
+            self.disagg_prefill_pending_queue = PrefillBootstrapQueue(
+                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+                aux_dtype=aux_dtype,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
+                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+            )
+            # The prefill requests that are in the middle of kv sending
+            self.disagg_prefill_infight_queue: List[Req] = []
+
+    @DynamicGradMode()
     def event_loop_normal(self):
         """A normal scheduler loop."""
         while True:
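Note: the comments in init_disaggregation above explain that each metadata row is padded so that row_width * dtype.itemsize clears the 64-byte minimum needed by the RDMA transfer path. A quick illustrative computation for the (buffer_size, 16) int32 layout chosen above; this snippet is not part of the package:

import torch

# Row layout used above: 16 int32 elements per metadata row.
aux_dtype = torch.int32
row_elems = 16
row_bytes = row_elems * torch.empty((), dtype=aux_dtype).element_size()
print(row_bytes)  # 16 * 4 = 64 bytes per row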
@@ -497,7 +604,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
 
         self.last_batch = batch
 
-    @torch.no_grad()
+    @DynamicGradMode()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
         self.result_queue = deque()
@@ -537,6 +644,70 @@ class Scheduler(SchedulerOutputProcessorMixin):
 
         self.last_batch = batch
 
+    @torch.no_grad()
+    def event_loop_normal_disagg_prefill(self):
+        """A normal scheduler loop for prefill worker in disaggregation mode."""
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            self.waiting_queue.extend(
+                self.disagg_prefill_pending_queue.pop_bootstrapped()
+            )
+            self.process_prefill_chunk()
+            batch = self.get_new_batch_prefill()
+            self.cur_batch = batch
+
+            if batch:
+                result = self.run_batch(batch)
+                self.process_batch_result_disagg_prefill(batch, result)
+
+            if len(self.disagg_prefill_infight_queue) > 0:
+                self.process_disagg_prefill_infight_queue()
+
+            if batch is None and len(self.disagg_prefill_infight_queue) == 0:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
+            # Otherwise, it hangs under high concurrency
+            self.running_batch.batch_is_full = False
+
+    @torch.no_grad()
+    def event_loop_normal_disagg_decode(self):
+        """A normal scheduler loop for decode worker in disaggregation mode."""
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            # polling and allocating kv cache
+            self.process_decode_queue()
+            batch = self.get_next_disagg_decode_batch_to_run()
+            self.cur_batch = batch
+
+            if batch:
+                # Generate fake extend output.
+                if batch.forward_mode.is_extend():
+                    # Note: Logprobs should be handled on the prefill engine.
+                    self.stream_output(
+                        batch.reqs, [False for _ in range(len(batch.reqs))]
+                    )
+                else:
+                    result = self.run_batch(batch)
+                    self.process_batch_result(batch, result)
+
+            if batch is None and (
+                len(self.disagg_decode_transfer_queue.queue)
+                + len(self.disagg_decode_prealloc_queue.queue)
+                == 0
+            ):
+                # When the server is idle, do self-check and re-init some states
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
         if self.attn_tp_rank == 0:
@@ -548,6 +719,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
             except zmq.ZMQError:
                 break
             recv_reqs.append(recv_req)
+
+            while True:
+                try:
+                    recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
+                except zmq.ZMQError:
+                    break
+                recv_reqs.append(recv_rpc)
         else:
             recv_reqs = None
 
@@ -599,7 +777,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
 
             output = self._request_dispatcher(recv_req)
             if output is not None:
-                self.send_to_tokenizer.send_pyobj(output)
+                if isinstance(output, RpcReqOutput):
+                    if self.recv_from_rpc is not None:
+                        self.recv_from_rpc.send_pyobj(output)
+                else:
+                    self.send_to_tokenizer.send_pyobj(output)
 
     def handle_generate_request(
         self,
@@ -665,8 +847,8 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 return
 
         # Handle multimodal inputs
-        if recv_req.image_inputs is not None:
-            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
+        if recv_req.mm_inputs is not None:
+            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
@@ -680,7 +862,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
             )
             logger.error(error_msg)
             req.origin_input_ids = [0]
-            req.image_inputs = None
+            req.multimodal_inputs = None
             req.sampling_params.max_new_tokens = 0
             req.finished_reason = FINISH_ABORT(
                 error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
@@ -755,10 +937,20 @@ class Scheduler(SchedulerOutputProcessorMixin):
         self._add_request_to_queue(req)
 
     def _add_request_to_queue(self, req: Req):
-        self.waiting_queue.append(req)
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.disagg_prefill_pending_queue.add(req)
+
+        elif self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.disagg_decode_prealloc_queue.add(req)
 
-    def _extend_requests_to_queue(self, reqs: List[Req]):
-        self.waiting_queue.extend(reqs)
+        else:
+            self.waiting_queue.append(req)
+
+    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.disagg_decode_prealloc_queue.extend(reqs)
+        else:
+            self.waiting_queue.extend(reqs)
 
     def handle_embedding_request(
         self,
@@ -774,7 +966,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
 
         # Handle multimodal inputs
         if recv_req.image_inputs is not None:
-            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
+            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
@@ -788,7 +980,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
             )
             logger.error(error_msg)
             req.origin_input_ids = [0]
-            req.image_inputs = None
+            req.multimodal_inputs = None
             req.sampling_params.max_new_tokens = 0
             req.finished_reason = FINISH_ABORT(
                 error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
@@ -874,7 +1066,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 f"#token: {num_used}, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-                f"largest-len: {self._largest_prefill_decode_len}, "
                 f"#queue-req: {len(self.waiting_queue)}, "
             )
             spec_accept_length = 0
@@ -892,7 +1083,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"accept len: {spec_accept_length:.2f}, "
                 f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-                f"largest-len: {self._largest_prefill_decode_len}, "
                 f"#queue-req: {len(self.waiting_queue)}, "
             )
 
@@ -997,7 +1187,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
 
         # Handle DP attention
        if self.server_args.enable_dp_attention:
-            ret = self.prepare_dp_attn_batch(ret)
+            ret, _ = self.prepare_dp_attn_batch(ret)
 
         return ret
 
@@ -1269,39 +1459,72 @@ class Scheduler(SchedulerOutputProcessorMixin):
         # Check if other DP workers have running batches
         if local_batch is None:
             num_tokens = 0
+            global_num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
+            if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
+                num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
+            global_num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
+            global_num_tokens_for_logprob = sum(
+                [
+                    # We should have at least 1 token for sample in every case.
+                    max(extend_len - logprob_start_len, 1)
+                    for logprob_start_len, extend_len in zip(
+                        local_batch.extend_logprob_start_lens, local_batch.extend_lens
+                    )
+                ]
+            )
+
+        if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
+            can_cuda_graph = 1
+        else:
+            can_cuda_graph = 0
+
+        if not self.spec_algorithm.is_none():
+            # TODO(sang): Support cuda graph when idle batch is there.
+            if local_batch is None or local_batch.forward_mode.is_idle():
+                can_cuda_graph = 0
 
-        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
-        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
+        is_extend_in_batch = (
+            local_batch.forward_mode.is_extend() if local_batch else False
+        )
+        local_info = torch.tensor(
+            [
+                num_tokens,
+                can_cuda_graph,
+                global_num_tokens_for_logprob,
+                is_extend_in_batch,
+            ],
+            dtype=torch.int64,
+        )
+        global_info = torch.empty(
+            (self.server_args.dp_size, self.attn_tp_size, 4),
+            dtype=torch.int64,
+        )
         torch.distributed.all_gather_into_tensor(
-            global_num_tokens,
-            local_num_tokens,
+            global_info.flatten(),
+            local_info,
             group=self.tp_cpu_group,
         )
+        global_num_tokens = global_info[:, 0, 0].tolist()
+        can_cuda_graph = min(global_info[:, 0, 1].tolist())
+        global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
+        is_extend_in_batch = global_info[:, 0, 3].tolist()
 
-        if local_batch is None and global_num_tokens.max().item() > 0:
+        if local_batch is None and max(global_num_tokens) > 0:
             local_batch = self.get_idle_batch()
 
         if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens.tolist()
+            local_batch.global_num_tokens = global_num_tokens
+            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
 
             # Check forward mode for cuda graph
             if not self.server_args.disable_cuda_graph:
-                forward_mode_state = torch.tensor(
-                    (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
-                    dtype=torch.int32,
-                )
-                torch.distributed.all_reduce(
-                    forward_mode_state,
-                    op=torch.distributed.ReduceOp.MIN,
-                    group=self.tp_cpu_group,
-                )
-                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+                local_batch.can_run_dp_cuda_graph = can_cuda_graph
 
-        return local_batch
+        return local_batch, any(is_extend_in_batch)
 
     def get_idle_batch(self):
         idle_batch = ScheduleBatch.init_new(
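Note: the rewritten prepare_dp_attn_batch replaces the extra all_reduce on the forward mode with a single all_gather of a packed int64 vector per rank, then reads only the attention-TP rank-0 column of each DP group. A minimal sketch of that packing pattern, assuming an already-initialized torch.distributed process group (standalone illustration, not the scheduler code):

import torch
import torch.distributed as dist

def gather_packed_info(num_tokens, can_cuda_graph, num_tokens_for_logprob,
                       is_extend_in_batch, dp_size, attn_tp_size, group=None):
    # Pack four per-rank scalars so one collective replaces several.
    local_info = torch.tensor(
        [num_tokens, int(can_cuda_graph), num_tokens_for_logprob, int(is_extend_in_batch)],
        dtype=torch.int64,
    )
    global_info = torch.empty((dp_size, attn_tp_size, 4), dtype=torch.int64)
    dist.all_gather_into_tensor(global_info.flatten(), local_info, group=group)
    # Consult only attention-TP rank 0 of each DP group, like global_info[:, 0, *] above.
    return (
        global_info[:, 0, 0].tolist(),       # per-DP-group token counts
        min(global_info[:, 0, 1].tolist()),  # cuda graph allowed only if every group agrees
    )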
@@ -1458,6 +1681,47 @@ class Scheduler(SchedulerOutputProcessorMixin):
             server_args=global_server_args_dict,
         )
 
+    def handle_rpc_request(self, recv_req: RpcReqInput):
+        # Handle RPC requests
+        logger.info(
+            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
+        )
+
+        success = True
+        exec = None
+        try:
+            func = getattr(self, recv_req.method)
+            func(recv_req.parameters)
+        except Exception as e:
+            success = False
+            exec = e
+            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")
+
+        barrier()
+        return RpcReqOutput(success, "" if not exec else str(exec))
+
+    def save_remote_model(self, params):
+        url = params["url"]
+
+        if isinstance(self.tp_worker, TpModelWorkerClient):
+            worker = self.tp_worker.worker
+        else:
+            worker = self.tp_worker
+
+        worker.model_runner.save_remote_model(url)
+
+    def save_sharded_model(self, params):
+        if isinstance(self.tp_worker, TpModelWorkerClient):
+            worker = self.tp_worker.worker
+        else:
+            worker = self.tp_worker
+
+        worker.model_runner.save_sharded_model(
+            path=params["path"],
+            pattern=params["pattern"],
+            max_size=params["max_size"],
+        )
+
     def abort_request(self, recv_req: AbortReq):
         # Delete requests in the waiting queue
         to_del = []
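Note: handle_rpc_request above resolves the requested method with getattr on the scheduler itself and turns exceptions into a failure reply instead of crashing the event loop. A stripped-down sketch of the same dispatch-and-report pattern, independent of the RpcReqInput/RpcReqOutput types and of the barrier() synchronization:

import logging

logger = logging.getLogger(__name__)

def dispatch_rpc(target: object, method: str, parameters: dict):
    # Look up the handler by name and report success/failure to the caller.
    try:
        getattr(target, method)(parameters)
        return True, ""
    except Exception as e:
        logger.error("Failed to call rpc %s: %s", method, e)
        return False, str(e)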
@@ -1527,6 +1791,9 @@ class Scheduler(SchedulerOutputProcessorMixin):
         return GetWeightsByNameReqOutput(parameter)
 
     def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
+        self.memory_saver_adapter.check_validity(
+            caller_name="release_memory_occupation"
+        )
         self.stashed_model_static_state = _export_static_state(
             self.tp_worker.worker.model_runner.model
         )
@@ -1535,6 +1802,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
         return ReleaseMemoryOccupationReqOutput()
 
     def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
+        self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
         self.memory_saver_adapter.resume()
         _import_static_state(
             self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
@@ -1634,6 +1902,17 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 ProfileReqOutput(success=True, message="Succeeded.")
             )
 
+    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
+        if recv_req == ExpertDistributionReq.START_RECORD:
+            expert_distribution_recorder.start_record()
+        elif recv_req == ExpertDistributionReq.STOP_RECORD:
+            expert_distribution_recorder.stop_record()
+        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
+            expert_distribution_recorder.dump_record()
+        else:
+            raise ValueError("Unrecognized ExpertDistributionReq value")
+        return ExpertDistributionReqOutput()
+
     def open_session(self, recv_req: OpenSessionReqInput):
         # handle error
         session_id = recv_req.session_id
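Note: expert_distribution_handle only forwards the three request values to the module-level recorder. A toy recorder exposing the same start/stop/dump entry points, purely for illustration; the real ExpertDistributionRecorder lives in sglang/srt/managers/expert_distribution.py and is not reproduced here:

class ToyRecorder:
    # Illustrative stand-in with the same three entry points used above.
    def __init__(self):
        self._recording = False
        self._events = []

    def start_record(self):
        self._recording = True

    def stop_record(self):
        self._recording = False

    def dump_record(self):
        events, self._events = self._events, []
        return events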
@@ -1692,7 +1971,7 @@ def run_scheduler_process(
         prefix = f" DP{dp_rank} TP{tp_rank}"
 
     # Config the process
-    # kill_itself_when_parent_died()  # This is disabled because it does not work for `--dp 2`
+    kill_itself_when_parent_died()
     setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
     faulthandler.enable()
     parent_process = psutil.Process().parent()
@@ -1719,10 +1998,18 @@ def run_scheduler_process(
                 "max_req_input_len": scheduler.max_req_input_len,
             }
         )
-        if scheduler.enable_overlap:
-            scheduler.event_loop_overlap()
-        else:
-            scheduler.event_loop_normal()
+        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
+
+        if disaggregation_mode == DisaggregationMode.NULL:
+            if scheduler.enable_overlap:
+                scheduler.event_loop_overlap()
+            else:
+                scheduler.event_loop_normal()
+        elif disaggregation_mode == DisaggregationMode.PREFILL:
+            scheduler.event_loop_normal_disagg_prefill()
+        elif disaggregation_mode == DisaggregationMode.DECODE:
+            scheduler.event_loop_normal_disagg_decode()
+
     except Exception:
         traceback = get_exception_traceback()
         logger.error(f"Scheduler hit an exception: {traceback}")
sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -111,6 +111,7 @@ class SchedulerOutputProcessorMixin:
                         ]
                         .cpu()
                         .clone()
+                        .tolist()
                     )
 
                 if req.grammar is not None:
@@ -245,7 +246,9 @@
                     )
 
                 if req.return_hidden_states and logits_output.hidden_states is not None:
-                    req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
+                    req.hidden_states.append(
+                        logits_output.hidden_states[i].cpu().clone().tolist()
+                    )
 
                 if req.grammar is not None and batch.spec_algorithm.is_none():
                     req.grammar.accept_token(next_token_id)
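Note: both hidden-states paths above now call .tolist() after .cpu().clone(), so the value stored on the request is plain Python floats rather than a tensor; plain lists pickle cheaply over the ZMQ sockets and serialize directly to JSON in the API layer. A one-line illustration of the conversion:

import torch

hidden = torch.randn(4)                  # stand-in for logits_output.hidden_states[i]
payload = hidden.cpu().clone().tolist()  # plain floats: picklable and JSON-friendly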
sglang/srt/managers/session_controller.py
@@ -138,7 +138,7 @@ class Session:
             token_ids_logprob=req.token_ids_logprob,
         )
         if last_req is not None:
-            new_req.image_inputs = last_req.image_inputs
+            new_req.multimodal_inputs = last_req.mm_inputs
             new_req.tokenizer = tokenizer
         if abort:
             new_req.to_abort = True