sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +26 -4
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +676 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +49 -8
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  34. sglang/srt/distributed/parallel_state.py +42 -8
  35. sglang/srt/entrypoints/engine.py +55 -5
  36. sglang/srt/entrypoints/http_server.py +78 -13
  37. sglang/srt/entrypoints/verl_engine.py +2 -0
  38. sglang/srt/function_call_parser.py +133 -55
  39. sglang/srt/hf_transformers_utils.py +28 -3
  40. sglang/srt/layers/activation.py +4 -2
  41. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  42. sglang/srt/layers/attention/flashattention_backend.py +434 -0
  43. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  44. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  45. sglang/srt/layers/attention/triton_backend.py +171 -38
  46. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  47. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  48. sglang/srt/layers/attention/utils.py +53 -0
  49. sglang/srt/layers/attention/vision.py +9 -28
  50. sglang/srt/layers/dp_attention.py +41 -19
  51. sglang/srt/layers/layernorm.py +24 -2
  52. sglang/srt/layers/linear.py +17 -5
  53. sglang/srt/layers/logits_processor.py +25 -7
  54. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  55. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  56. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  57. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  64. sglang/srt/layers/moe/topk.py +60 -20
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +80 -53
  67. sglang/srt/layers/quantization/awq.py +200 -0
  68. sglang/srt/layers/quantization/base_config.py +5 -0
  69. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  70. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  72. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  75. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  76. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  77. sglang/srt/layers/quantization/fp8.py +76 -34
  78. sglang/srt/layers/quantization/fp8_kernel.py +25 -8
  79. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  80. sglang/srt/layers/quantization/gptq.py +36 -19
  81. sglang/srt/layers/quantization/kv_cache.py +98 -0
  82. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  83. sglang/srt/layers/quantization/utils.py +153 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  85. sglang/srt/layers/rotary_embedding.py +78 -87
  86. sglang/srt/layers/sampler.py +1 -1
  87. sglang/srt/lora/backend/base_backend.py +4 -4
  88. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  89. sglang/srt/lora/backend/triton_backend.py +5 -8
  90. sglang/srt/lora/layers.py +87 -33
  91. sglang/srt/lora/lora.py +2 -22
  92. sglang/srt/lora/lora_manager.py +67 -30
  93. sglang/srt/lora/mem_pool.py +117 -52
  94. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  95. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  96. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  97. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  98. sglang/srt/lora/utils.py +18 -1
  99. sglang/srt/managers/cache_controller.py +2 -5
  100. sglang/srt/managers/data_parallel_controller.py +30 -8
  101. sglang/srt/managers/expert_distribution.py +81 -0
  102. sglang/srt/managers/io_struct.py +43 -5
  103. sglang/srt/managers/mm_utils.py +373 -0
  104. sglang/srt/managers/multimodal_processor.py +68 -0
  105. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  106. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  107. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  108. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  109. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  110. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  111. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  112. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  113. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  114. sglang/srt/managers/schedule_batch.py +134 -30
  115. sglang/srt/managers/scheduler.py +290 -31
  116. sglang/srt/managers/session_controller.py +1 -1
  117. sglang/srt/managers/tokenizer_manager.py +59 -24
  118. sglang/srt/managers/tp_worker.py +4 -1
  119. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  120. sglang/srt/managers/utils.py +6 -1
  121. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  122. sglang/srt/mem_cache/memory_pool.py +255 -98
  123. sglang/srt/mem_cache/paged_allocator.py +2 -2
  124. sglang/srt/mem_cache/radix_cache.py +4 -4
  125. sglang/srt/model_executor/cuda_graph_runner.py +36 -21
  126. sglang/srt/model_executor/forward_batch_info.py +68 -11
  127. sglang/srt/model_executor/model_runner.py +75 -8
  128. sglang/srt/model_loader/loader.py +171 -3
  129. sglang/srt/model_loader/weight_utils.py +51 -3
  130. sglang/srt/models/clip.py +563 -0
  131. sglang/srt/models/deepseek_janus_pro.py +31 -88
  132. sglang/srt/models/deepseek_nextn.py +22 -10
  133. sglang/srt/models/deepseek_v2.py +329 -73
  134. sglang/srt/models/deepseek_vl2.py +358 -0
  135. sglang/srt/models/gemma3_causal.py +694 -0
  136. sglang/srt/models/gemma3_mm.py +468 -0
  137. sglang/srt/models/llama.py +47 -7
  138. sglang/srt/models/llama_eagle.py +1 -0
  139. sglang/srt/models/llama_eagle3.py +196 -0
  140. sglang/srt/models/llava.py +3 -3
  141. sglang/srt/models/llavavid.py +3 -3
  142. sglang/srt/models/minicpmo.py +1995 -0
  143. sglang/srt/models/minicpmv.py +62 -137
  144. sglang/srt/models/mllama.py +4 -4
  145. sglang/srt/models/phi3_small.py +1 -1
  146. sglang/srt/models/qwen2.py +3 -0
  147. sglang/srt/models/qwen2_5_vl.py +68 -146
  148. sglang/srt/models/qwen2_classification.py +75 -0
  149. sglang/srt/models/qwen2_moe.py +9 -1
  150. sglang/srt/models/qwen2_vl.py +25 -63
  151. sglang/srt/openai_api/adapter.py +201 -104
  152. sglang/srt/openai_api/protocol.py +33 -7
  153. sglang/srt/patch_torch.py +71 -0
  154. sglang/srt/sampling/sampling_batch_info.py +1 -1
  155. sglang/srt/sampling/sampling_params.py +6 -6
  156. sglang/srt/server_args.py +114 -14
  157. sglang/srt/speculative/build_eagle_tree.py +7 -347
  158. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  159. sglang/srt/speculative/eagle_utils.py +208 -252
  160. sglang/srt/speculative/eagle_worker.py +140 -54
  161. sglang/srt/speculative/spec_info.py +6 -1
  162. sglang/srt/torch_memory_saver_adapter.py +22 -0
  163. sglang/srt/utils.py +215 -21
  164. sglang/test/__init__.py +0 -0
  165. sglang/test/attention/__init__.py +0 -0
  166. sglang/test/attention/test_flashattn_backend.py +312 -0
  167. sglang/test/runners.py +29 -2
  168. sglang/test/test_activation.py +2 -1
  169. sglang/test/test_block_fp8.py +5 -4
  170. sglang/test/test_block_fp8_ep.py +2 -1
  171. sglang/test/test_dynamic_grad_mode.py +58 -0
  172. sglang/test/test_layernorm.py +3 -2
  173. sglang/test/test_utils.py +56 -5
  174. sglang/utils.py +31 -0
  175. sglang/version.py +1 -1
  176. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
  177. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
  178. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
  179. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  180. sglang/srt/managers/image_processor.py +0 -55
  181. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  182. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  183. sglang/srt/managers/multi_modality_padding.py +0 -134
  184. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
  185. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py

@@ -32,16 +32,33 @@ import psutil
 import setproctitle
 import torch
 import zmq
+from torch.distributed import barrier
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
+from sglang.srt.disaggregation.decode import (
+    DecodePreallocQueue,
+    DecodeTransferQueue,
+    SchedulerDisaggregationDecodeMixin,
+)
+from sglang.srt.disaggregation.prefill import (
+    PrefillBootstrapQueue,
+    SchedulerDisaggregationPrefillMixin,
+)
+from sglang.srt.disaggregation.utils import (
+    DisaggregationMode,
+    ReqToMetadataIdxAllocator,
+)
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
+    ExpertDistributionReq,
+    ExpertDistributionReqOutput,
     FlushCacheReq,
     GetInternalStateReq,
     GetInternalStateReqOutput,
@@ -59,6 +76,8 @@ from sglang.srt.managers.io_struct import (
     ReleaseMemoryOccupationReqOutput,
     ResumeMemoryOccupationReqInput,
     ResumeMemoryOccupationReqOutput,
+    RpcReqInput,
+    RpcReqOutput,
     SetInternalStateReq,
     SetInternalStateReqOutput,
     TokenizedEmbeddingReqInput,
@@ -72,7 +91,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
-    ImageInputs,
+    MultimodalInputs,
     Req,
     ScheduleBatch,
     global_server_args_dict,
@@ -98,6 +117,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
+    DynamicGradMode,
     broadcast_pyobj,
     configure_logger,
     crash_on_warnings,
@@ -111,6 +131,8 @@ from sglang.srt.utils import (
 )
 from sglang.utils import TypeBasedDispatcher, get_exception_traceback
 
+expert_distribution_recorder = ExpertDistributionRecorder()
+
 logger = logging.getLogger(__name__)
 
 # Test retract decode for debugging purposes
@@ -133,7 +155,11 @@ class EmbeddingBatchResult:
     bid: int
 
 
-class Scheduler(SchedulerOutputProcessorMixin):
+class Scheduler(
+    SchedulerOutputProcessorMixin,
+    SchedulerDisaggregationDecodeMixin,
+    SchedulerDisaggregationPrefillMixin,
+):
     """A scheduler that manages a tensor parallel GPU worker."""
 
     def __init__(
@@ -193,8 +219,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
             self.send_to_detokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.detokenizer_ipc_name, False
             )
+
+            self.recv_from_rpc = get_zmq_socket(
+                context, zmq.DEALER, port_args.rpc_ipc_name, False
+            )
         else:
             self.recv_from_tokenizer = None
+            self.recv_from_rpc = None
             self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
 
@@ -348,7 +379,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
         # Init profiler
         self.torch_profiler = None
         self.torch_profiler_output_dir: Optional[str] = None
-        self.torch_profiler_activities: Optional[List[str]] = None
+        self.profiler_activities: Optional[List[str]] = None
         self.profiler_target_forward_ct: Optional[int] = None
 
         # Init metrics stats
@@ -376,9 +407,16 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 (ProfileReq, self.profile),
                 (GetInternalStateReq, self.get_internal_state),
                 (SetInternalStateReq, self.set_internal_state),
+                (RpcReqInput, self.handle_rpc_request),
+                (ExpertDistributionReq, self.expert_distribution_handle),
             ]
         )
 
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.init_disaggregation()
+
     def init_tokenizer(self):
         server_args = self.server_args
 
@@ -435,6 +473,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 tp_cache_group=self.tp_worker.get_tp_cpu_group(),
                 page_size=self.page_size,
+                hicache_ratio=server_args.hicache_ratio,
             )
         else:
             self.tree_cache = RadixCache(
@@ -478,7 +517,74 @@ class Scheduler(SchedulerOutputProcessorMixin):
             },
         )
 
-    @torch.no_grad()
+    def init_disaggregation(self):
+        if (
+            self.disaggregation_mode == DisaggregationMode.DECODE
+        ):  # *2 for the headroom.
+            buffer_size = (self.req_to_token_pool.size) * 2
+            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
+                buffer_size
+            )
+            aux_dtype = torch.int32
+            # A list of metadata buffers. The shape is (b, metadata_size) where
+            # b corresponds to a max running requests. The last shape * dtype.itemsize
+            # should be larger than 64 bytes to work with RDMA, so we pad it.
+            output_id_buffer = torch.zeros(
+                (buffer_size, 16), dtype=aux_dtype, device="cpu"
+            )
+            metadata_buffers = [output_id_buffer]
+
+            # The decode requests polling kv cache
+            self.disagg_decode_transfer_queue = DecodeTransferQueue(
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+            )
+
+            # The decode requests pending for pre-allocation
+            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+                aux_dtype=aux_dtype,
+                scheduler=self,
+                transfer_queue=self.disagg_decode_transfer_queue,
+                tree_cache=self.tree_cache,
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
+                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+            )
+        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # *2 for the headroom.
+            buffer_size = self.max_running_requests * 2
+            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
+                buffer_size
+            )
+            aux_dtype = torch.int32
+            # A list of metadata buffers. The shape is (b, metadata_size) where
+            # b corresponds to a max running requests. The last shape * dtype.itemsize
+            # should be larger than 64 bytes to work with RDMA, so we pad it.
+            output_id_buffer = torch.zeros(
+                (buffer_size, 16), dtype=aux_dtype, device="cpu"
+            )
+            metadata_buffers = [output_id_buffer]
+
+            self.disagg_prefill_pending_queue = PrefillBootstrapQueue(
+                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
+                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                metadata_buffers=metadata_buffers,
+                aux_dtype=aux_dtype,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
+                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+            )
+            # The prefill requests that are in the middle of kv sending
+            self.disagg_prefill_infight_queue: List[Req] = []
+
+    @DynamicGradMode()
     def event_loop_normal(self):
         """A normal scheduler loop."""
         while True:
@@ -498,7 +604,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
 
         self.last_batch = batch
 
-    @torch.no_grad()
+    @DynamicGradMode()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
         self.result_queue = deque()
@@ -538,6 +644,70 @@ class Scheduler(SchedulerOutputProcessorMixin):
 
         self.last_batch = batch
 
+    @torch.no_grad()
+    def event_loop_normal_disagg_prefill(self):
+        """A normal scheduler loop for prefill worker in disaggregation mode."""
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            self.waiting_queue.extend(
+                self.disagg_prefill_pending_queue.pop_bootstrapped()
+            )
+            self.process_prefill_chunk()
+            batch = self.get_new_batch_prefill()
+            self.cur_batch = batch
+
+            if batch:
+                result = self.run_batch(batch)
+                self.process_batch_result_disagg_prefill(batch, result)
+
+            if len(self.disagg_prefill_infight_queue) > 0:
+                self.process_disagg_prefill_infight_queue()
+
+            if batch is None and len(self.disagg_prefill_infight_queue) == 0:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
+            # Otherwise, it hangs under high concurrency
+            self.running_batch.batch_is_full = False
+
+    @torch.no_grad()
+    def event_loop_normal_disagg_decode(self):
+        """A normal scheduler loop for decode worker in disaggregation mode."""
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            # polling and allocating kv cache
+            self.process_decode_queue()
+            batch = self.get_next_disagg_decode_batch_to_run()
+            self.cur_batch = batch
+
+            if batch:
+                # Generate fake extend output.
+                if batch.forward_mode.is_extend():
+                    # Note: Logprobs should be handled on the prefill engine.
+                    self.stream_output(
+                        batch.reqs, [False for _ in range(len(batch.reqs))]
+                    )
+                else:
+                    result = self.run_batch(batch)
+                    self.process_batch_result(batch, result)
+
+            if batch is None and (
+                len(self.disagg_decode_transfer_queue.queue)
+                + len(self.disagg_decode_prealloc_queue.queue)
+                == 0
+            ):
+                # When the server is idle, do self-check and re-init some states
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
         if self.attn_tp_rank == 0:
@@ -549,6 +719,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 except zmq.ZMQError:
                     break
                 recv_reqs.append(recv_req)
+
+            while True:
+                try:
+                    recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
+                except zmq.ZMQError:
+                    break
+                recv_reqs.append(recv_rpc)
         else:
             recv_reqs = None
 
@@ -600,7 +777,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
 
             output = self._request_dispatcher(recv_req)
             if output is not None:
-                self.send_to_tokenizer.send_pyobj(output)
+                if isinstance(output, RpcReqOutput):
+                    if self.recv_from_rpc is not None:
+                        self.recv_from_rpc.send_pyobj(output)
+                else:
+                    self.send_to_tokenizer.send_pyobj(output)
 
     def handle_generate_request(
         self,
@@ -666,8 +847,8 @@ class Scheduler(SchedulerOutputProcessorMixin):
             return
 
         # Handle multimodal inputs
-        if recv_req.image_inputs is not None:
-            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
+        if recv_req.mm_inputs is not None:
+            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
@@ -681,7 +862,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 )
                 logger.error(error_msg)
                 req.origin_input_ids = [0]
-                req.image_inputs = None
+                req.multimodal_inputs = None
                 req.sampling_params.max_new_tokens = 0
                 req.finished_reason = FINISH_ABORT(
                     error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
@@ -756,10 +937,20 @@ class Scheduler(SchedulerOutputProcessorMixin):
             self._add_request_to_queue(req)
 
     def _add_request_to_queue(self, req: Req):
-        self.waiting_queue.append(req)
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.disagg_prefill_pending_queue.add(req)
+
+        elif self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.disagg_decode_prealloc_queue.add(req)
+
+        else:
+            self.waiting_queue.append(req)
 
-    def _extend_requests_to_queue(self, reqs: List[Req]):
-        self.waiting_queue.extend(reqs)
+    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.disagg_decode_prealloc_queue.extend(reqs)
+        else:
+            self.waiting_queue.extend(reqs)
 
     def handle_embedding_request(
         self,
@@ -775,7 +966,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
 
         # Handle multimodal inputs
         if recv_req.image_inputs is not None:
-            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
+            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
                 req.origin_input_ids, image_inputs
@@ -789,7 +980,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
             )
             logger.error(error_msg)
            req.origin_input_ids = [0]
-            req.image_inputs = None
+            req.multimodal_inputs = None
             req.sampling_params.max_new_tokens = 0
             req.finished_reason = FINISH_ABORT(
                 error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
@@ -875,7 +1066,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 f"#token: {num_used}, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-                f"largest-len: {self._largest_prefill_decode_len}, "
                 f"#queue-req: {len(self.waiting_queue)}, "
             )
             spec_accept_length = 0
@@ -893,7 +1083,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"accept len: {spec_accept_length:.2f}, "
                 f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-                f"largest-len: {self._largest_prefill_decode_len}, "
                 f"#queue-req: {len(self.waiting_queue)}, "
             )
 
@@ -997,7 +1186,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
             ret = None
 
         # Handle DP attention
-        if self.server_args.enable_dp_attention:
+        if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
            ret, _ = self.prepare_dp_attn_batch(ret)
 
         return ret
@@ -1492,6 +1681,41 @@ class Scheduler(SchedulerOutputProcessorMixin):
            server_args=global_server_args_dict,
        )
 
+    def handle_rpc_request(self, recv_req: RpcReqInput):
+        # Handle RPC requests
+        logger.info(
+            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
+        )
+
+        success = True
+        exec = None
+        try:
+            func = getattr(self, recv_req.method)
+            func(recv_req.parameters)
+        except Exception as e:
+            success = False
+            exec = e
+            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")
+
+        barrier()
+        return RpcReqOutput(success, "" if not exec else str(exec))
+
+    def save_remote_model(self, params):
+        url = params["url"]
+
+        worker = self.tp_worker.worker
+
+        worker.model_runner.save_remote_model(url)
+
+    def save_sharded_model(self, params):
+        worker = self.tp_worker.worker
+
+        worker.model_runner.save_sharded_model(
+            path=params["path"],
+            pattern=params["pattern"],
+            max_size=params["max_size"],
+        )
+
     def abort_request(self, recv_req: AbortReq):
         # Delete requests in the waiting queue
         to_del = []
@@ -1561,6 +1785,9 @@ class Scheduler(SchedulerOutputProcessorMixin):
         return GetWeightsByNameReqOutput(parameter)
 
     def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
+        self.memory_saver_adapter.check_validity(
+            caller_name="release_memory_occupation"
+        )
         self.stashed_model_static_state = _export_static_state(
             self.tp_worker.worker.model_runner.model
         )
@@ -1569,6 +1796,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
         return ReleaseMemoryOccupationReqOutput()
 
     def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
+        self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
         self.memory_saver_adapter.resume()
         _import_static_state(
             self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
@@ -1579,7 +1807,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
     def profile(self, recv_req: ProfileReq):
         if recv_req.type == ProfileReqType.START_PROFILE:
             return self.start_profile(
-                recv_req.output_dir, recv_req.num_steps, recv_req.activities
+                recv_req.output_dir,
+                recv_req.num_steps,
+                recv_req.activities,
+                recv_req.with_stack,
+                recv_req.record_shapes,
             )
         else:
             return self.stop_profile()
@@ -1589,8 +1821,10 @@ class Scheduler(SchedulerOutputProcessorMixin):
         output_dir: Optional[str],
         num_steps: Optional[int],
         activities: Optional[List[str]],
+        with_stack: Optional[bool],
+        record_shapes: Optional[bool],
     ) -> None:
-        if self.torch_profiler_activities:
+        if self.profiler_activities:
             return ProfileReqOutput(
                 success=False,
                 message="Profiling is already in progress. Call /stop_profile first.",
@@ -1602,7 +1836,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
             activities = ["CPU", "GPU"]
 
         self.torch_profiler_output_dir = output_dir
-        self.torch_profiler_activities = activities
+        self.profiler_activities = activities
         logger.info(
             "Profiling starts. Traces will be saved to: %s",
             self.torch_profiler_output_dir,
@@ -1619,13 +1853,17 @@ class Scheduler(SchedulerOutputProcessorMixin):
         if torchprof_activities:
             self.torch_profiler = torch.profiler.profile(
                 activities=torchprof_activities,
-                with_stack=True,
+                with_stack=with_stack if with_stack is not None else True,
+                record_shapes=record_shapes if record_shapes is not None else False,
            )
            self.torch_profiler.start()
 
         if "MEM" in activities:
             torch.cuda.memory._record_memory_history(max_entries=100000)
 
+        if "CUDA_PROFILER" in activities:
+            torch.cuda.cudart().cudaProfilerStart()
+
         if num_steps:
             self.profiler_target_forward_ct = self.forward_ct + num_steps
             # The caller will be notified when reaching profiler_target_forward_ct
@@ -1634,7 +1872,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
         return ProfileReqOutput(success=True, message="Succeeded")
 
     def stop_profile(self) -> None:
-        if self.torch_profiler_activities is None:
+        if self.profiler_activities is None:
             return
 
         logger.info("Stop profiling...")
@@ -1647,27 +1885,41 @@ class Scheduler(SchedulerOutputProcessorMixin):
             )
         )
 
-        if "MEM" in self.torch_profiler_activities:
+        if "MEM" in self.profiler_activities:
             memory_profile_path = os.path.join(
-                self.torch_profiler_trace_dir,
+                self.torch_profiler_output_dir,
                 str(time.time()) + f"-TP-{self.tp_rank}-memory" + ".pickle",
             )
             torch.cuda.memory._dump_snapshot(memory_profile_path)
             torch.cuda.memory._record_memory_history(enabled=None)
 
+        if "CUDA_PROFILER" in self.profiler_activities:
+            torch.cuda.cudart().cudaProfilerStop()
+
         logger.info(
             "Profiling done. Traces are saved to: %s",
             self.torch_profiler_output_dir,
         )
         self.torch_profiler = None
         self.torch_profiler_output_dir = None
-        self.torch_profiler_activities = None
+        self.profiler_activities = None
 
         if self.profiler_target_forward_ct:
             self.send_to_tokenizer.send_pyobj(
                 ProfileReqOutput(success=True, message="Succeeded.")
             )
 
+    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
+        if recv_req == ExpertDistributionReq.START_RECORD:
+            expert_distribution_recorder.start_record()
+        elif recv_req == ExpertDistributionReq.STOP_RECORD:
+            expert_distribution_recorder.stop_record()
+        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
+            expert_distribution_recorder.dump_record()
+        else:
+            raise ValueError("Unrecognized ExpertDistributionReq value")
+        return ExpertDistributionReqOutput()
+
     def open_session(self, recv_req: OpenSessionReqInput):
         # handle error
         session_id = recv_req.session_id
@@ -1718,7 +1970,6 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
-
     # Generate the prefix
     if dp_rank is None:
         prefix = f" TP{tp_rank}"
@@ -1726,7 +1977,7 @@ def run_scheduler_process(
         prefix = f" DP{dp_rank} TP{tp_rank}"
 
     # Config the process
-    # kill_itself_when_parent_died()  # This is disabled because it does not work for `--dp 2`
+    kill_itself_when_parent_died()
     setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
     faulthandler.enable()
    parent_process = psutil.Process().parent()
@@ -1753,10 +2004,18 @@ def run_scheduler_process(
                 "max_req_input_len": scheduler.max_req_input_len,
             }
         )
-        if scheduler.enable_overlap:
-            scheduler.event_loop_overlap()
-        else:
-            scheduler.event_loop_normal()
+        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
+
+        if disaggregation_mode == DisaggregationMode.NULL:
+            if scheduler.enable_overlap:
+                scheduler.event_loop_overlap()
+            else:
+                scheduler.event_loop_normal()
+        elif disaggregation_mode == DisaggregationMode.PREFILL:
+            scheduler.event_loop_normal_disagg_prefill()
+        elif disaggregation_mode == DisaggregationMode.DECODE:
+            scheduler.event_loop_normal_disagg_decode()
+
     except Exception:
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")

sglang/srt/managers/session_controller.py

@@ -138,7 +138,7 @@ class Session:
                 token_ids_logprob=req.token_ids_logprob,
             )
             if last_req is not None:
-                new_req.image_inputs = last_req.image_inputs
+                new_req.multimodal_inputs = last_req.mm_inputs
                 new_req.tokenizer = tokenizer
             if abort:
                 new_req.to_abort = True