sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +133 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +32 -21
  49. sglang/srt/layers/layernorm.py +24 -2
  50. sglang/srt/layers/linear.py +17 -5
  51. sglang/srt/layers/logits_processor.py +25 -7
  52. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  53. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  54. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  55. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  62. sglang/srt/layers/moe/topk.py +31 -18
  63. sglang/srt/layers/parameter.py +1 -1
  64. sglang/srt/layers/quantization/__init__.py +184 -126
  65. sglang/srt/layers/quantization/base_config.py +5 -0
  66. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  67. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  69. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  70. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  71. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  73. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  74. sglang/srt/layers/quantization/fp8.py +76 -34
  75. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  76. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  77. sglang/srt/layers/quantization/gptq.py +36 -9
  78. sglang/srt/layers/quantization/kv_cache.py +98 -0
  79. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  80. sglang/srt/layers/quantization/utils.py +153 -0
  81. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  82. sglang/srt/layers/rotary_embedding.py +66 -87
  83. sglang/srt/layers/sampler.py +1 -1
  84. sglang/srt/lora/layers.py +68 -0
  85. sglang/srt/lora/lora.py +2 -22
  86. sglang/srt/lora/lora_manager.py +47 -23
  87. sglang/srt/lora/mem_pool.py +110 -51
  88. sglang/srt/lora/utils.py +12 -1
  89. sglang/srt/managers/cache_controller.py +2 -5
  90. sglang/srt/managers/data_parallel_controller.py +30 -8
  91. sglang/srt/managers/expert_distribution.py +81 -0
  92. sglang/srt/managers/io_struct.py +39 -3
  93. sglang/srt/managers/mm_utils.py +373 -0
  94. sglang/srt/managers/multimodal_processor.py +68 -0
  95. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  96. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  97. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  98. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  99. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  100. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  101. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  102. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  103. sglang/srt/managers/schedule_batch.py +133 -30
  104. sglang/srt/managers/scheduler.py +273 -20
  105. sglang/srt/managers/session_controller.py +1 -1
  106. sglang/srt/managers/tokenizer_manager.py +59 -23
  107. sglang/srt/managers/tp_worker.py +1 -1
  108. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  109. sglang/srt/managers/utils.py +6 -1
  110. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  111. sglang/srt/mem_cache/memory_pool.py +255 -98
  112. sglang/srt/mem_cache/paged_allocator.py +2 -2
  113. sglang/srt/mem_cache/radix_cache.py +4 -4
  114. sglang/srt/model_executor/cuda_graph_runner.py +27 -13
  115. sglang/srt/model_executor/forward_batch_info.py +68 -11
  116. sglang/srt/model_executor/model_runner.py +70 -6
  117. sglang/srt/model_loader/loader.py +160 -2
  118. sglang/srt/model_loader/weight_utils.py +45 -0
  119. sglang/srt/models/deepseek_janus_pro.py +29 -86
  120. sglang/srt/models/deepseek_nextn.py +22 -10
  121. sglang/srt/models/deepseek_v2.py +208 -77
  122. sglang/srt/models/deepseek_vl2.py +358 -0
  123. sglang/srt/models/gemma3_causal.py +684 -0
  124. sglang/srt/models/gemma3_mm.py +462 -0
  125. sglang/srt/models/llama.py +47 -7
  126. sglang/srt/models/llama_eagle.py +1 -0
  127. sglang/srt/models/llama_eagle3.py +196 -0
  128. sglang/srt/models/llava.py +3 -3
  129. sglang/srt/models/llavavid.py +3 -3
  130. sglang/srt/models/minicpmo.py +1995 -0
  131. sglang/srt/models/minicpmv.py +62 -137
  132. sglang/srt/models/mllama.py +4 -4
  133. sglang/srt/models/phi3_small.py +1 -1
  134. sglang/srt/models/qwen2.py +3 -0
  135. sglang/srt/models/qwen2_5_vl.py +68 -146
  136. sglang/srt/models/qwen2_classification.py +75 -0
  137. sglang/srt/models/qwen2_moe.py +9 -1
  138. sglang/srt/models/qwen2_vl.py +25 -63
  139. sglang/srt/openai_api/adapter.py +124 -28
  140. sglang/srt/openai_api/protocol.py +23 -2
  141. sglang/srt/sampling/sampling_batch_info.py +1 -1
  142. sglang/srt/sampling/sampling_params.py +6 -6
  143. sglang/srt/server_args.py +99 -9
  144. sglang/srt/speculative/build_eagle_tree.py +7 -347
  145. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  146. sglang/srt/speculative/eagle_utils.py +208 -252
  147. sglang/srt/speculative/eagle_worker.py +139 -53
  148. sglang/srt/speculative/spec_info.py +6 -1
  149. sglang/srt/torch_memory_saver_adapter.py +22 -0
  150. sglang/srt/utils.py +182 -21
  151. sglang/test/__init__.py +0 -0
  152. sglang/test/attention/__init__.py +0 -0
  153. sglang/test/attention/test_flashattn_backend.py +312 -0
  154. sglang/test/runners.py +2 -0
  155. sglang/test/test_activation.py +2 -1
  156. sglang/test/test_block_fp8.py +5 -4
  157. sglang/test/test_block_fp8_ep.py +2 -1
  158. sglang/test/test_dynamic_grad_mode.py +58 -0
  159. sglang/test/test_layernorm.py +3 -2
  160. sglang/test/test_utils.py +55 -4
  161. sglang/utils.py +31 -0
  162. sglang/version.py +1 -1
  163. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  164. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
  165. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  166. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  167. sglang/srt/managers/image_processor.py +0 -55
  168. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  169. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  170. sglang/srt/managers/multi_modality_padding.py +0 -134
  171. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  172. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
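
Highlights visible in the list above: a new prefill/decode disaggregation subsystem (sglang/srt/disaggregation/), FlashAttention and FlashMLA attention backends, compressed-tensors quantization, Redis/S3 weight connectors, new model support (Gemma 3, DeepSeek-VL2, MiniCPM-o, EAGLE3 via llama_eagle3, Qwen2 classification), and the rename of the image-processor stack to multimodal processors. The hunks below show the manager-side changes: scheduler.py, session_controller.py, tokenizer_manager.py, tp_worker.py, and tp_worker_overlap_thread.py.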
sglang/srt/managers/scheduler.py

@@ -32,16 +32,33 @@ import psutil
 import setproctitle
 import torch
 import zmq
+from torch.distributed import barrier

 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
+from sglang.srt.disaggregation.decode import (
+    DecodePreallocQueue,
+    DecodeTransferQueue,
+    SchedulerDisaggregationDecodeMixin,
+)
+from sglang.srt.disaggregation.prefill import (
+    PrefillBootstrapQueue,
+    SchedulerDisaggregationPrefillMixin,
+)
+from sglang.srt.disaggregation.utils import (
+    DisaggregationMode,
+    ReqToMetadataIdxAllocator,
+)
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
+    ExpertDistributionReq,
+    ExpertDistributionReqOutput,
     FlushCacheReq,
     GetInternalStateReq,
     GetInternalStateReqOutput,
@@ -59,6 +76,8 @@ from sglang.srt.managers.io_struct import (
     ReleaseMemoryOccupationReqOutput,
     ResumeMemoryOccupationReqInput,
     ResumeMemoryOccupationReqOutput,
+    RpcReqInput,
+    RpcReqOutput,
     SetInternalStateReq,
     SetInternalStateReqOutput,
     TokenizedEmbeddingReqInput,
@@ -72,7 +91,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
-    ImageInputs,
+    MultimodalInputs,
     Req,
     ScheduleBatch,
     global_server_args_dict,
@@ -98,6 +117,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
+    DynamicGradMode,
     broadcast_pyobj,
     configure_logger,
     crash_on_warnings,
@@ -111,6 +131,8 @@ from sglang.srt.utils import (
 )
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback

+expert_distribution_recorder = ExpertDistributionRecorder()
+
 logger = logging.getLogger(__name__)

 # Test retract decode for debugging purposes
@@ -133,7 +155,11 @@ class EmbeddingBatchResult:
     bid: int


-class Scheduler(SchedulerOutputProcessorMixin):
+class Scheduler(
+    SchedulerOutputProcessorMixin,
+    SchedulerDisaggregationDecodeMixin,
+    SchedulerDisaggregationPrefillMixin,
+):
     """A scheduler that manages a tensor parallel GPU worker."""

     def __init__(
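
Scheduler now inherits the two disaggregation mixins alongside SchedulerOutputProcessorMixin. Methods invoked by the new event loops further below but not defined in scheduler.py (e.g. process_batch_result_disagg_prefill, get_next_disagg_decode_batch_to_run, process_decode_queue) are expected to be contributed by sglang/srt/disaggregation/prefill.py and decode.py. A minimal sketch of the pattern, with placeholder bodies rather than the real logic:

    class SchedulerDisaggregationPrefillMixin:
        # Real logic lives in sglang/srt/disaggregation/prefill.py (+249 lines).
        def process_batch_result_disagg_prefill(self, batch, result):
            raise NotImplementedError

    class SchedulerDisaggregationDecodeMixin:
        # Real logic lives in sglang/srt/disaggregation/decode.py (+495 lines).
        def get_next_disagg_decode_batch_to_run(self):
            raise NotImplementedError

    class Scheduler(
        SchedulerDisaggregationDecodeMixin,
        SchedulerDisaggregationPrefillMixin,
    ):
        """Disaggregation hooks arrive via the MRO without touching call sites."""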
@@ -193,8 +219,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
             self.send_to_detokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.detokenizer_ipc_name, False
             )
+
+            self.recv_from_rpc = get_zmq_socket(
+                context, zmq.DEALER, port_args.rpc_ipc_name, False
+            )
         else:
             self.recv_from_tokenizer = None
+            self.recv_from_rpc = None
             self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

@@ -376,9 +407,16 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 (ProfileReq, self.profile),
                 (GetInternalStateReq, self.get_internal_state),
                 (SetInternalStateReq, self.set_internal_state),
+                (RpcReqInput, self.handle_rpc_request),
+                (ExpertDistributionReq, self.expert_distribution_handle),
             ]
         )

+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.init_disaggregation()
+
     def init_tokenizer(self):
         server_args = self.server_args

@@ -435,6 +473,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 tp_cache_group=self.tp_worker.get_tp_cpu_group(),
                 page_size=self.page_size,
+                hicache_ratio=server_args.hicache_ratio,
             )
         else:
             self.tree_cache = RadixCache(
@@ -478,7 +517,74 @@ class Scheduler(SchedulerOutputProcessorMixin):
             },
         )

-    @torch.no_grad()
+    def init_disaggregation(self):
+        if (
+            self.disaggregation_mode == DisaggregationMode.DECODE
+        ):  # *2 for the headroom.
+            buffer_size = (self.req_to_token_pool.size) * 2
+            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
+                buffer_size
+            )
+            aux_dtype = torch.int32
+            # A list of metadata buffers. The shape is (b, metadata_size) where
+            # b corresponds to a max running requests. The last shape * dtype.itemsize
+            # should be larger than 64 bytes to work with RDMA, so we pad it.
+            output_id_buffer = torch.zeros(
+                (buffer_size, 16), dtype=aux_dtype, device="cpu"
+            )
+            metadata_buffers = [output_id_buffer]
+
+            # The decode requests polling kv cache
+            self.disagg_decode_transfer_queue = DecodeTransferQueue(
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+            )
+
+            # The decode requests pending for pre-allocation
+            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+                aux_dtype=aux_dtype,
+                scheduler=self,
+                transfer_queue=self.disagg_decode_transfer_queue,
+                tree_cache=self.tree_cache,
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
+                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+            )
+        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # *2 for the headroom.
+            buffer_size = self.max_running_requests * 2
+            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
+                buffer_size
+            )
+            aux_dtype = torch.int32
+            # A list of metadata buffers. The shape is (b, metadata_size) where
+            # b corresponds to a max running requests. The last shape * dtype.itemsize
+            # should be larger than 64 bytes to work with RDMA, so we pad it.
+            output_id_buffer = torch.zeros(
+                (buffer_size, 16), dtype=aux_dtype, device="cpu"
+            )
+            metadata_buffers = [output_id_buffer]
+
+            self.disagg_prefill_pending_queue = PrefillBootstrapQueue(
+                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+                aux_dtype=aux_dtype,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
+                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+            )
+            # The prefill requests that are in the middle of kv sending
+            self.disagg_prefill_infight_queue: List[Req] = []
+
+    @DynamicGradMode()
     def event_loop_normal(self):
         """A normal scheduler loop."""
         while True:
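
The padding comment in init_disaggregation can be verified with quick arithmetic: each metadata row holds 16 int32 values, so its size is 16 × 4 = 64 bytes, exactly meeting the stated RDMA floor. A one-line check:

    import torch

    row = torch.zeros(16, dtype=torch.int32)  # one row of output_id_buffer above
    assert row.numel() * row.element_size() == 64  # 16 elements * 4 bytes = 64-byte RDMA floor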
@@ -498,7 +604,7 @@ class Scheduler(SchedulerOutputProcessorMixin):

         self.last_batch = batch

-    @torch.no_grad()
+    @DynamicGradMode()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
         self.result_queue = deque()
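
Both event loops (and forward_thread_func_ in tp_worker_overlap_thread.py below) trade @torch.no_grad() for @DynamicGradMode(), newly exported from sglang/srt/utils.py (+182 -21; the implementation is not part of these hunks). A minimal sketch of the idea, assuming it simply selects between torch.no_grad and torch.inference_mode through a process-wide switch:

    import torch

    class DynamicGradMode:
        """Sketch only: decorate a function with inference_mode when allowed,
        falling back to no_grad (e.g. when weights may still be mutated)."""

        use_inference_mode = True  # hypothetical global toggle

        def __call__(self, func):
            def wrapper(*args, **kwargs):
                ctx = (
                    torch.inference_mode()
                    if DynamicGradMode.use_inference_mode
                    else torch.no_grad()
                )
                with ctx:
                    return func(*args, **kwargs)
            return wrapper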
@@ -538,6 +644,70 @@ class Scheduler(SchedulerOutputProcessorMixin):

         self.last_batch = batch

+    @torch.no_grad()
+    def event_loop_normal_disagg_prefill(self):
+        """A normal scheduler loop for prefill worker in disaggregation mode."""
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            self.waiting_queue.extend(
+                self.disagg_prefill_pending_queue.pop_bootstrapped()
+            )
+            self.process_prefill_chunk()
+            batch = self.get_new_batch_prefill()
+            self.cur_batch = batch
+
+            if batch:
+                result = self.run_batch(batch)
+                self.process_batch_result_disagg_prefill(batch, result)
+
+            if len(self.disagg_prefill_infight_queue) > 0:
+                self.process_disagg_prefill_infight_queue()
+
+            if batch is None and len(self.disagg_prefill_infight_queue) == 0:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
+            # Otherwise, it hangs under high concurrency
+            self.running_batch.batch_is_full = False
+
+    @torch.no_grad()
+    def event_loop_normal_disagg_decode(self):
+        """A normal scheduler loop for decode worker in disaggregation mode."""
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            # polling and allocating kv cache
+            self.process_decode_queue()
+            batch = self.get_next_disagg_decode_batch_to_run()
+            self.cur_batch = batch
+
+            if batch:
+                # Generate fake extend output.
+                if batch.forward_mode.is_extend():
+                    # Note: Logprobs should be handled on the prefill engine.
+                    self.stream_output(
+                        batch.reqs, [False for _ in range(len(batch.reqs))]
+                    )
+                else:
+                    result = self.run_batch(batch)
+                    self.process_batch_result(batch, result)
+
+            if batch is None and (
+                len(self.disagg_decode_transfer_queue.queue)
+                + len(self.disagg_decode_prealloc_queue.queue)
+                == 0
+            ):
+                # When the server is idle, do self-check and re-init some states
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
         if self.attn_tp_rank == 0:
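
The loops above lean on DisaggregationMode and ReqToMetadataIdxAllocator from sglang/srt/disaggregation/utils.py (+44 lines, not shown in this diff). The enum members are pinned down by the NULL/PREFILL/DECODE comparisons and by DisaggregationMode(self.server_args.disaggregation_mode); the allocator below is a sketch inferred from how it is constructed and named, with hypothetical alloc/free methods:

    from enum import Enum
    from typing import List, Optional

    class DisaggregationMode(Enum):
        NULL = "null"        # string values are assumptions; they must match
        PREFILL = "prefill"  # the server_args.disaggregation_mode setting
        DECODE = "decode"

    class ReqToMetadataIdxAllocator:
        """Sketch: hands out free rows of the metadata buffers to requests."""

        def __init__(self, size: int):
            self.free_slots: List[int] = list(range(size))

        def alloc(self) -> Optional[int]:
            return self.free_slots.pop(0) if self.free_slots else None

        def free(self, idx: int) -> None:
            self.free_slots.append(idx)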
@@ -549,6 +719,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 except zmq.ZMQError:
                     break
                 recv_reqs.append(recv_req)
+
+            while True:
+                try:
+                    recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
+                except zmq.ZMQError:
+                    break
+                recv_reqs.append(recv_rpc)
         else:
             recv_reqs = None

@@ -600,7 +777,11 @@ class Scheduler(SchedulerOutputProcessorMixin):

             output = self._request_dispatcher(recv_req)
             if output is not None:
-                self.send_to_tokenizer.send_pyobj(output)
+                if isinstance(output, RpcReqOutput):
+                    if self.recv_from_rpc is not None:
+                        self.recv_from_rpc.send_pyobj(output)
+                else:
+                    self.send_to_tokenizer.send_pyobj(output)

     def handle_generate_request(
         self,
@@ -666,8 +847,8 @@ class Scheduler(SchedulerOutputProcessorMixin):
             return

         # Handle multimodal inputs
-        if recv_req.image_inputs is not None:
-            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
+        if recv_req.mm_inputs is not None:
+            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
@@ -681,7 +862,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 )
                 logger.error(error_msg)
                 req.origin_input_ids = [0]
-                req.image_inputs = None
+                req.multimodal_inputs = None
                 req.sampling_params.max_new_tokens = 0
                 req.finished_reason = FINISH_ABORT(
                     error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
@@ -756,10 +937,20 @@ class Scheduler(SchedulerOutputProcessorMixin):
         self._add_request_to_queue(req)

     def _add_request_to_queue(self, req: Req):
-        self.waiting_queue.append(req)
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.disagg_prefill_pending_queue.add(req)

-    def _extend_requests_to_queue(self, reqs: List[Req]):
-        self.waiting_queue.extend(reqs)
+        elif self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.disagg_decode_prealloc_queue.add(req)
+
+        else:
+            self.waiting_queue.append(req)
+
+    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.disagg_decode_prealloc_queue.extend(reqs)
+        else:
+            self.waiting_queue.extend(reqs)

     def handle_embedding_request(
         self,
@@ -775,7 +966,7 @@ class Scheduler(SchedulerOutputProcessorMixin):

         # Handle multimodal inputs
         if recv_req.image_inputs is not None:
-            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
+            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
@@ -789,7 +980,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 )
                 logger.error(error_msg)
                 req.origin_input_ids = [0]
-                req.image_inputs = None
+                req.multimodal_inputs = None
                 req.sampling_params.max_new_tokens = 0
                 req.finished_reason = FINISH_ABORT(
                     error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
@@ -875,7 +1066,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
             f"#token: {num_used}, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-            f"largest-len: {self._largest_prefill_decode_len}, "
             f"#queue-req: {len(self.waiting_queue)}, "
         )
         spec_accept_length = 0
@@ -893,7 +1083,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"accept len: {spec_accept_length:.2f}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-            f"largest-len: {self._largest_prefill_decode_len}, "
             f"#queue-req: {len(self.waiting_queue)}, "
         )

@@ -1492,6 +1681,47 @@ class Scheduler(SchedulerOutputProcessorMixin):
             server_args=global_server_args_dict,
         )

+    def handle_rpc_request(self, recv_req: RpcReqInput):
+        # Handle RPC requests
+        logger.info(
+            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
+        )
+
+        success = True
+        exec = None
+        try:
+            func = getattr(self, recv_req.method)
+            func(recv_req.parameters)
+        except Exception as e:
+            success = False
+            exec = e
+            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")
+
+        barrier()
+        return RpcReqOutput(success, "" if not exec else str(exec))
+
+    def save_remote_model(self, params):
+        url = params["url"]
+
+        if isinstance(self.tp_worker, TpModelWorkerClient):
+            worker = self.tp_worker.worker
+        else:
+            worker = self.tp_worker
+
+        worker.model_runner.save_remote_model(url)
+
+    def save_sharded_model(self, params):
+        if isinstance(self.tp_worker, TpModelWorkerClient):
+            worker = self.tp_worker.worker
+        else:
+            worker = self.tp_worker
+
+        worker.model_runner.save_sharded_model(
+            path=params["path"],
+            pattern=params["pattern"],
+            max_size=params["max_size"],
+        )
+
     def abort_request(self, recv_req: AbortReq):
         # Delete requests in the waiting queue
         to_del = []
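
handle_rpc_request resolves recv_req.method by name on the Scheduler itself, so save_remote_model and save_sharded_model become callable over the new DEALER socket, with a torch.distributed barrier keeping TP ranks in step. The server side is fully shown above; a hedged sketch of the caller's side of the round trip (the actual wiring of port_args.rpc_ipc_name lives in server_args.py and engine.py, which are not part of these hunks; the RpcReqInput field names mirror the reads above):

    import zmq

    from sglang.srt.managers.io_struct import RpcReqInput

    def call_scheduler_rpc(rpc_ipc_name: str, method: str, parameters: dict):
        """Hypothetical client for the scheduler's RPC socket."""
        context = zmq.Context()
        socket = context.socket(zmq.DEALER)  # pairs with the scheduler's DEALER socket
        socket.connect(rpc_ipc_name)
        socket.send_pyobj(RpcReqInput(method=method, parameters=parameters))
        return socket.recv_pyobj()  # an RpcReqOutput(success, message)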
@@ -1561,6 +1791,9 @@ class Scheduler(SchedulerOutputProcessorMixin):
         return GetWeightsByNameReqOutput(parameter)

     def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
+        self.memory_saver_adapter.check_validity(
+            caller_name="release_memory_occupation"
+        )
         self.stashed_model_static_state = _export_static_state(
             self.tp_worker.worker.model_runner.model
         )
@@ -1569,6 +1802,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
         return ReleaseMemoryOccupationReqOutput()

     def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
+        self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
         self.memory_saver_adapter.resume()
         _import_static_state(
             self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
@@ -1668,6 +1902,17 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 ProfileReqOutput(success=True, message="Succeeded.")
             )

+    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
+        if recv_req == ExpertDistributionReq.START_RECORD:
+            expert_distribution_recorder.start_record()
+        elif recv_req == ExpertDistributionReq.STOP_RECORD:
+            expert_distribution_recorder.stop_record()
+        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
+            expert_distribution_recorder.dump_record()
+        else:
+            raise ValueError("Unrecognized ExpertDistributionReq value")
+        return ExpertDistributionReqOutput()
+
     def open_session(self, recv_req: OpenSessionReqInput):
         # handle error
         session_id = recv_req.session_id
@@ -1726,7 +1971,7 @@ def run_scheduler_process(
         prefix = f" DP{dp_rank} TP{tp_rank}"

     # Config the process
-    # kill_itself_when_parent_died()  # This is disabled because it does not work for `--dp 2`
+    kill_itself_when_parent_died()
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
     faulthandler.enable()
     parent_process = psutil.Process().parent()
@@ -1753,10 +1998,18 @@ def run_scheduler_process(
                 "max_req_input_len": scheduler.max_req_input_len,
             }
         )
-        if scheduler.enable_overlap:
-            scheduler.event_loop_overlap()
-        else:
-            scheduler.event_loop_normal()
+        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
+
+        if disaggregation_mode == DisaggregationMode.NULL:
+            if scheduler.enable_overlap:
+                scheduler.event_loop_overlap()
+            else:
+                scheduler.event_loop_normal()
+        elif disaggregation_mode == DisaggregationMode.PREFILL:
+            scheduler.event_loop_normal_disagg_prefill()
+        elif disaggregation_mode == DisaggregationMode.DECODE:
+            scheduler.event_loop_normal_disagg_decode()
+
     except Exception:
         traceback = get_exception_traceback()
         logger.error(f"Scheduler hit an exception: {traceback}")
sglang/srt/managers/session_controller.py

@@ -138,7 +138,7 @@ class Session:
             token_ids_logprob=req.token_ids_logprob,
         )
         if last_req is not None:
-            new_req.image_inputs = last_req.image_inputs
+            new_req.multimodal_inputs = last_req.mm_inputs
         new_req.tokenizer = tokenizer
         if abort:
             new_req.to_abort = True
sglang/srt/managers/tokenizer_manager.py

@@ -16,7 +16,6 @@
 import asyncio
 import copy
 import dataclasses
-import json
 import logging
 import os
 import pickle
@@ -49,11 +48,9 @@ from fastapi import BackgroundTasks

 from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.disaggregation.conn import KVBootstrapServer
+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.image_processor import (
-    get_dummy_image_processor,
-    get_image_processor,
-)
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -63,6 +60,8 @@ from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
+    ExpertDistributionReq,
+    ExpertDistributionReqOutput,
     FlushCacheReq,
     GenerateReqInput,
     GetInternalStateReq,
@@ -91,6 +90,11 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
     UpdateWeightsFromTensorReqOutput,
 )
+from sglang.srt.managers.multimodal_processor import (
+    get_dummy_processor,
+    get_mm_processor,
+    import_processors,
+)
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -168,27 +172,33 @@ class TokenizerManager:
         self.context_len = self.model_config.context_len
         self.image_token_id = self.model_config.image_token_id

-        # Create image processor placeholder
-        self.image_processor = get_dummy_image_processor()
+        if self.model_config.is_multimodal:
+            import_processors()
+            _processor = get_processor(
+                server_args.tokenizer_path,
+                tokenizer_mode=server_args.tokenizer_mode,
+                trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
+            )

-        # Create tokenizer
-        if server_args.skip_tokenizer_init:
-            self.tokenizer = self.processor = None
-        else:
-            if self.model_config.is_multimodal:
-                self.processor = get_processor(
-                    server_args.tokenizer_path,
-                    tokenizer_mode=server_args.tokenizer_mode,
-                    trust_remote_code=server_args.trust_remote_code,
-                    revision=server_args.revision,
-                )
+            # We want to parallelize the image pre-processing so we create an executor for it
+            # We create mm_processor for any skip_tokenizer_init to make sure we still encode
+            # images even with skip_tokenizer_init=False.
+            self.mm_processor = get_mm_processor(
+                self.model_config.hf_config, server_args, _processor
+            )
+
+            if server_args.skip_tokenizer_init:
+                self.tokenizer = self.processor = None
+            else:
+                self.processor = _processor
                 self.tokenizer = self.processor.tokenizer
                 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+        else:
+            self.mm_processor = get_dummy_processor()

-            # We want to parallelize the image pre-processing so we create an executor for it
-            self.image_processor = get_image_processor(
-                self.model_config.hf_config, server_args, self.processor
-            )
+            if server_args.skip_tokenizer_init:
+                self.tokenizer = self.processor = None
             else:
                 self.tokenizer = get_tokenizer(
                     server_args.tokenizer_path,
@@ -255,6 +265,9 @@ class TokenizerManager:
         self.get_internal_state_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.expert_distribution_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )

         self._result_dispatcher = TypeBasedDispatcher(
             [
@@ -304,10 +317,24 @@ class TokenizerManager:
                     GetInternalStateReqOutput,
                     self.get_internal_state_communicator.handle_recv,
                 ),
+                (
+                    ExpertDistributionReqOutput,
+                    self.expert_distribution_communicator.handle_recv,
+                ),
                 (HealthCheckOutput, lambda x: None),
             ]
         )

+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        # for disaggregtion, start kv boostrap server on prefill
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # only start bootstrap server on prefill tm
+            self.bootstrap_server = KVBootstrapServer(
+                self.server_args.disaggregation_bootstrap_port
+            )
+
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
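
Both the Scheduler's _request_dispatcher and the TokenizerManager's _result_dispatcher are TypeBasedDispatcher instances (sglang/utils.py grew +31 lines this release). Its contract is clear from the call sites: construct with (type, handler) pairs, call with an object, and the handler registered for that object's type runs. A minimal sketch consistent with that usage:

    class TypeBasedDispatcher:
        """Sketch: route an object to the handler registered for its type."""

        def __init__(self, mapping):
            self._mapping = mapping  # list of (type, handler) pairs

        def __call__(self, obj):
            for ty, handler in self._mapping:
                if isinstance(obj, ty):
                    return handler(obj)
            raise ValueError(f"Invalid object: {obj}")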
@@ -372,7 +399,7 @@ class TokenizerManager:
             )
             input_ids = self.tokenizer.encode(input_text)

-        image_inputs: Dict = await self.image_processor.process_images_async(
+        image_inputs: Dict = await self.mm_processor.process_mm_data_async(
             obj.image_data, input_text or input_ids, obj, self.max_req_input_len
         )
         if image_inputs and "input_ids" in image_inputs:
@@ -620,6 +647,15 @@ class TokenizerManager:
         req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
         self.send_to_scheduler.send_pyobj(req)

+    async def start_expert_distribution_record(self):
+        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
+
+    async def stop_expert_distribution_record(self):
+        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
+
+    async def dump_expert_distribution_record(self):
+        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
+
     async def update_weights_from_disk(
         self,
         obj: UpdateWeightFromDiskReqInput,
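
With the _Communicator registered earlier, each of these calls fans out to every scheduler (dp_size replies are gathered per request) and resolves once the matching ExpertDistributionReqOutput comes back. A usage sketch, assuming tm is a constructed TokenizerManager and run_workload is your own coroutine:

    async def profile_expert_distribution(tm, run_workload):
        await tm.start_expert_distribution_record()
        await run_workload()  # e.g. send some MoE generate requests
        await tm.stop_expert_distribution_record()
        # the recorder dumps on the scheduler side, per expert_distribution_handle
        await tm.dump_expert_distribution_record()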
sglang/srt/managers/tp_worker.py

@@ -214,7 +214,7 @@ class TpModelWorker:
     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
         success, message = self.model_runner.update_weights_from_tensor(
             named_tensors=MultiprocessingSerializer.deserialize(
-                recv_req.serialized_named_tensors
+                recv_req.serialized_named_tensors[self.tp_rank]
             ),
             load_format=recv_req.load_format,
         )
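
serialized_named_tensors is now indexed by self.tp_rank, i.e. the request carries one serialized payload per tensor-parallel rank rather than a single shared blob. A hedged sketch of the producing side (named_tensors_for_rank, tp_size, and load_format are hypothetical placeholders; the real producer is in the engine/model-runner paths, not in this hunk):

    # Serialize one payload per TP rank so each worker deserializes only its
    # own shard, as the new [self.tp_rank] indexing above implies.
    serialized_named_tensors = [
        MultiprocessingSerializer.serialize(named_tensors_for_rank(tp_rank))
        for tp_rank in range(tp_size)
    ]
    req = UpdateWeightsFromTensorReqInput(
        serialized_named_tensors=serialized_named_tensors,
        load_format=load_format,
    )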
sglang/srt/managers/tp_worker_overlap_thread.py

@@ -33,7 +33,7 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import DynamicGradMode, get_compiler_backend
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -69,7 +69,7 @@ class TpModelWorkerClient:
         self.future_token_ids_ct = 0
         self.future_token_ids_limit = self.max_running_requests * 3
         self.future_token_ids_map = torch.empty(
-            (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
+            (self.max_running_requests * 5,), dtype=torch.int64, device=self.device
        )

         # Launch threads
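
future_token_ids_map widens from int32 to int64. The map holds real token ids that later batches read back by position; keeping it int64 matches the dtype of token-id tensors elsewhere and avoids a cast when its contents are copied into the next batch's input ids. A simplified sketch of the indirection (the function names and exact write/read sites are assumptions, not shown in this hunk):

    import torch

    future_token_ids_map = torch.empty((1024,), dtype=torch.int64)

    def publish(start: int, sampled_ids: torch.Tensor) -> None:
        # sampled ids come out of the sampler as int64; an int32 map would force a cast
        future_token_ids_map[start : start + len(sampled_ids)] = sampled_ids

    def resolve(positions: torch.Tensor) -> torch.Tensor:
        # a later batch swaps placeholders for the real ids once the GPU catches up
        return future_token_ids_map[positions]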
@@ -115,7 +115,7 @@ class TpModelWorkerClient:
             logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
             self.parent_process.send_signal(signal.SIGQUIT)

-    @torch.no_grad()
+    @DynamicGradMode()
     def forward_thread_func_(self):
         batch_pt = 0
         batch_lists = [None] * 2