sglang-0.4.7-py3-none-any.whl → sglang-0.4.8-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as published in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
@@ -27,10 +27,10 @@ from typing import TYPE_CHECKING, List, Optional
27
27
 
28
28
  import torch
29
29
 
30
- from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll
30
+ from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
31
31
  from sglang.srt.disaggregation.utils import (
32
+ FAKE_BOOTSTRAP_HOST,
32
33
  DisaggregationMode,
33
- FakeBootstrapHost,
34
34
  KVClassType,
35
35
  MetadataBuffers,
36
36
  ReqToMetadataIdxAllocator,
@@ -44,6 +44,7 @@ from sglang.srt.disaggregation.utils import (
44
44
  )
45
45
  from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
46
46
  from sglang.srt.model_executor.forward_batch_info import ForwardMode
47
+ from sglang.srt.utils import require_mlp_sync
47
48
 
48
49
  if TYPE_CHECKING:
49
50
  from torch.distributed import ProcessGroup
@@ -51,7 +52,6 @@ if TYPE_CHECKING:
51
52
  from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
52
53
  from sglang.srt.mem_cache.memory_pool import KVCache
53
54
 
54
-
55
55
  logger = logging.getLogger(__name__)
56
56
 
57
57
 
@@ -68,35 +68,45 @@ class PrefillBootstrapQueue:
68
68
  metadata_buffers: MetadataBuffers,
69
69
  tp_rank: int,
70
70
  tp_size: int,
71
+ gpu_id: int,
71
72
  bootstrap_port: int,
72
73
  gloo_group: ProcessGroup,
73
- transfer_backend: TransferBackend,
74
+ max_total_num_tokens: int,
75
+ decode_tp_size: int,
76
+ decode_dp_size: int,
74
77
  scheduler: Scheduler,
78
+ pp_rank: int,
79
+ pp_size: int,
80
+ transfer_backend: TransferBackend,
75
81
  ):
76
82
  self.token_to_kv_pool = token_to_kv_pool
77
83
  self.draft_token_to_kv_pool = draft_token_to_kv_pool
78
-
79
84
  self.is_mla_backend = is_mla_backend(token_to_kv_pool)
80
-
81
85
  self.metadata_buffers = metadata_buffers
82
86
  self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
83
87
  self.tp_rank = tp_rank
84
88
  self.tp_size = tp_size
85
- self.transfer_backend = transfer_backend
86
- self.scheduler = scheduler
87
- self.kv_manager = self._init_kv_manager()
89
+ self.decode_tp_size = decode_tp_size
90
+ self.decode_dp_size = decode_dp_size
91
+ self.pp_rank = pp_rank
92
+ self.pp_size = pp_size
93
+ self.gpu_id = gpu_id
94
+ self.bootstrap_port = bootstrap_port
88
95
  self.queue: List[Req] = []
96
+ self.pp_rank = pp_rank
97
+ self.pp_size = pp_size
89
98
  self.gloo_group = gloo_group
90
- self.bootstrap_port = bootstrap_port
91
-
92
- def store_prefill_results(self, idx: int, token_id: int):
93
- assert token_id >= 0, f"token_id: {token_id} is negative"
94
- output_id_buffer = self.metadata_buffers[0]
95
- output_id_buffer[idx] = token_id
99
+ self.max_total_num_tokens = max_total_num_tokens
100
+ self.scheduler = scheduler
101
+ self.transfer_backend = transfer_backend
102
+ self.kv_manager = self._init_kv_manager()
96
103
 
97
104
  def _init_kv_manager(self) -> BaseKVManager:
98
- kv_args = KVArgs()
105
+ kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
106
+ kv_args = kv_args_class()
99
107
  kv_args.engine_rank = self.tp_rank
108
+ kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
109
+ kv_args.prefill_pp_size = self.pp_size
100
110
  kv_data_ptrs, kv_data_lens, kv_item_lens = (
101
111
  self.token_to_kv_pool.get_contiguous_buf_infos()
102
112
  )
@@ -115,12 +125,12 @@ class PrefillBootstrapQueue:
115
125
  kv_args.kv_data_lens = kv_data_lens
116
126
  kv_args.kv_item_lens = kv_item_lens
117
127
 
118
- # Define req -> input ids buffer
119
128
  kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
120
129
  self.metadata_buffers.get_buf_infos()
121
130
  )
122
131
  kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
123
132
  kv_args.gpu_id = self.scheduler.gpu_id
133
+
124
134
  kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
125
135
  kv_manager = kv_manager_class(
126
136
  kv_args,
@@ -130,23 +140,39 @@ class PrefillBootstrapQueue:
130
140
  )
131
141
  return kv_manager
132
142
 
133
- def add(self, req: Req) -> None:
134
- if req.bootstrap_host == FakeBootstrapHost:
135
- # Fake transfer for warmup reqs
143
+ def add(self, req: Req, num_kv_heads: int) -> None:
144
+ if self._check_if_req_exceed_kv_capacity(req):
145
+ return
146
+
147
+ if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
136
148
  kv_sender_class = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER)
137
149
  else:
138
150
  kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
151
+
152
+ dest_tp_ranks = [self.tp_rank]
153
+
139
154
  req.disagg_kv_sender = kv_sender_class(
140
155
  mgr=self.kv_manager,
141
156
  bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
142
157
  bootstrap_room=req.bootstrap_room,
158
+ dest_tp_ranks=dest_tp_ranks,
159
+ pp_rank=self.pp_rank,
143
160
  )
144
161
  self._process_req(req)
145
162
  self.queue.append(req)
146
163
 
147
- def extend(self, reqs: List[Req]) -> None:
164
+ def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
148
165
  for req in reqs:
149
- self.add(req)
166
+ self.add(req, num_kv_heads)
167
+
168
+ def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
169
+ if len(req.origin_input_ids) > self.max_total_num_tokens:
170
+ message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
171
+ logger.error(message)
172
+ prepare_abort(req, message)
173
+ self.scheduler.stream_output([req], req.return_logprob)
174
+ return True
175
+ return False
150
176
 
151
177
  def _process_req(self, req: Req) -> None:
152
178
  """
@@ -154,19 +180,40 @@ class PrefillBootstrapQueue:
154
180
  """
155
181
  req.sampling_params.max_new_tokens = 1
156
182
 
157
- def pop_bootstrapped(self) -> List[Req]:
158
- """pop the reqs which has finished bootstrapping"""
183
+ def pop_bootstrapped(
184
+ self,
185
+ return_failed_reqs: bool = False,
186
+ rids_to_check: Optional[List[str]] = None,
187
+ ) -> List[Req]:
188
+ """
189
+ pop the reqs which has finished bootstrapping
190
+
191
+ return_failed_reqs: For PP, on rank 0, also return the failed reqs to notify the next rank
192
+ rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank.
193
+ """
194
+
159
195
  bootstrapped_reqs = []
196
+ failed_reqs = []
160
197
  indices_to_remove = set()
161
198
 
162
199
  if len(self.queue) == 0:
163
- return []
200
+ if return_failed_reqs is False:
201
+ return []
202
+ else:
203
+ return [], []
164
204
 
165
205
  polls = poll_and_all_reduce(
166
206
  [req.disagg_kv_sender for req in self.queue], self.gloo_group
167
207
  )
168
-
169
208
  for i, (req, poll) in enumerate(zip(self.queue, polls)):
209
+
210
+ if rids_to_check is not None:
211
+ # if req not in reqs_info_to_check, skip
212
+ if req.rid not in rids_to_check:
213
+ continue
214
+ # Either waiting for input or failed
215
+ assert poll == KVPoll.WaitingForInput or poll == KVPoll.Failed
216
+
170
217
  if poll == KVPoll.Bootstrapping:
171
218
  continue
172
219
  elif poll == KVPoll.Failed:
@@ -181,9 +228,10 @@ class PrefillBootstrapQueue:
181
228
  )
182
229
  self.scheduler.stream_output([req], req.return_logprob)
183
230
  indices_to_remove.add(i)
231
+ failed_reqs.append(req)
184
232
  continue
185
233
 
186
- # KV.WaitingForInput
234
+ # KV.WaitingForInput - init here
187
235
  num_kv_indices = len(req.origin_input_ids)
188
236
  if self.req_to_metadata_buffer_idx_allocator.available_size() == 0:
189
237
  break
@@ -192,9 +240,9 @@ class PrefillBootstrapQueue:
192
240
  self.req_to_metadata_buffer_idx_allocator.alloc()
193
241
  )
194
242
  assert req.metadata_buffer_index is not None
243
+
195
244
  num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
196
245
  req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
197
-
198
246
  bootstrapped_reqs.append(req)
199
247
  indices_to_remove.add(i)
200
248
 
@@ -202,7 +250,10 @@ class PrefillBootstrapQueue:
202
250
  entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
203
251
  ]
204
252
 
205
- return bootstrapped_reqs
253
+ if return_failed_reqs is False:
254
+ return bootstrapped_reqs
255
+ else:
256
+ return bootstrapped_reqs, failed_reqs
206
257
 
207
258
 
208
259
  class SchedulerDisaggregationPrefillMixin:
@@ -211,7 +262,7 @@ class SchedulerDisaggregationPrefillMixin:
211
262
  """
212
263
 
213
264
  @torch.no_grad()
214
- def event_loop_normal_disagg_prefill(self: Scheduler):
265
+ def event_loop_normal_disagg_prefill(self: Scheduler) -> None:
215
266
  """A normal scheduler loop for prefill worker in disaggregation mode."""
216
267
 
217
268
  while True:
@@ -223,13 +274,8 @@ class SchedulerDisaggregationPrefillMixin:
223
274
  self.process_prefill_chunk()
224
275
  batch = self.get_new_batch_prefill()
225
276
 
226
- # Handle DP attention
227
- if (
228
- self.server_args.enable_dp_attention
229
- or self.server_args.enable_sp_layernorm
230
- ):
231
- batch, _ = self.prepare_dp_attn_batch(batch)
232
-
277
+ if require_mlp_sync(self.server_args):
278
+ batch, _ = self.prepare_mlp_sync_batch(batch)
233
279
  self.cur_batch = batch
234
280
 
235
281
  if batch:
@@ -242,6 +288,7 @@ class SchedulerDisaggregationPrefillMixin:
242
288
  if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
243
289
  self.check_memory()
244
290
  self.new_token_ratio = self.init_new_token_ratio
291
+ self.maybe_sleep_on_idle()
245
292
 
246
293
  self.last_batch = batch
247
294
  # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
@@ -249,7 +296,7 @@ class SchedulerDisaggregationPrefillMixin:
249
296
  self.running_batch.batch_is_full = False
250
297
 
251
298
  @torch.no_grad()
252
- def event_loop_overlap_disagg_prefill(self: Scheduler):
299
+ def event_loop_overlap_disagg_prefill(self: Scheduler) -> None:
253
300
  self.result_queue = deque()
254
301
 
255
302
  while True:
@@ -261,15 +308,9 @@ class SchedulerDisaggregationPrefillMixin:
261
308
  self.process_prefill_chunk()
262
309
  batch = self.get_new_batch_prefill()
263
310
 
264
- # Handle DP attention
265
- if (
266
- self.server_args.enable_dp_attention
267
- or self.server_args.enable_sp_layernorm
268
- ):
269
- batch, _ = self.prepare_dp_attn_batch(batch)
270
-
311
+ if require_mlp_sync(self.server_args):
312
+ batch, _ = self.prepare_mlp_sync_batch(batch)
271
313
  self.cur_batch = batch
272
-
273
314
  if batch:
274
315
  result = self.run_batch(batch)
275
316
  self.result_queue.append((batch.copy(), result))
@@ -286,6 +327,9 @@ class SchedulerDisaggregationPrefillMixin:
286
327
 
287
328
  if self.last_batch:
288
329
  tmp_batch, tmp_result = self.result_queue.popleft()
330
+ tmp_batch.next_batch_sampling_info = (
331
+ self.tp_worker.cur_sampling_info if batch else None
332
+ )
289
333
  self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
290
334
 
291
335
  if len(self.disagg_prefill_inflight_queue) > 0:
@@ -294,6 +338,7 @@ class SchedulerDisaggregationPrefillMixin:
294
338
  if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
295
339
  self.check_memory()
296
340
  self.new_token_ratio = self.init_new_token_ratio
341
+ self.maybe_sleep_on_idle()
297
342
 
298
343
  self.last_batch = batch
299
344
  # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
@@ -307,7 +352,7 @@ class SchedulerDisaggregationPrefillMixin:
307
352
  launch_done: Optional[threading.Event] = None,
308
353
  ) -> None:
309
354
  """
310
- Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
355
+ Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
311
356
  Adapted from process_batch_result_prefill
312
357
  """
313
358
  (
@@ -323,7 +368,7 @@ class SchedulerDisaggregationPrefillMixin:
323
368
  )
324
369
 
325
370
  logprob_pt = 0
326
- # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
371
+ # Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
327
372
  if self.enable_overlap:
328
373
  # wait
329
374
  logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
@@ -340,6 +385,8 @@ class SchedulerDisaggregationPrefillMixin:
340
385
  logits_output.input_token_logprobs = tuple(
341
386
  logits_output.input_token_logprobs.tolist()
342
387
  )
388
+
389
+ hidden_state_offset = 0
343
390
  for i, (req, next_token_id) in enumerate(
344
391
  zip(batch.reqs, next_token_ids, strict=True)
345
392
  ):
@@ -349,6 +396,16 @@ class SchedulerDisaggregationPrefillMixin:
349
396
  req.output_ids.append(next_token_id)
350
397
  self.tree_cache.cache_unfinished_req(req) # update the tree and lock
351
398
  self.disagg_prefill_inflight_queue.append(req)
399
+ if logits_output.hidden_states is not None:
400
+ last_hidden_index = (
401
+ hidden_state_offset + extend_input_len_per_req[i] - 1
402
+ )
403
+ req.hidden_states_tensor = (
404
+ logits_output.hidden_states[last_hidden_index].cpu().clone()
405
+ )
406
+ hidden_state_offset += extend_input_len_per_req[i]
407
+ else:
408
+ req.hidden_states_tensor = None
352
409
  if req.return_logprob:
353
410
  assert extend_logprob_start_len_per_req is not None
354
411
  assert extend_input_len_per_req is not None
@@ -395,11 +452,15 @@ class SchedulerDisaggregationPrefillMixin:
395
452
  # We need to remove the sync in the following function for overlap schedule.
396
453
  self.set_next_batch_sampling_info_done(batch)
397
454
 
398
- def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
455
+ def process_disagg_prefill_inflight_queue(
456
+ self: Scheduler, rids_to_check: Optional[List[str]] = None
457
+ ) -> List[Req]:
399
458
  """
400
459
  Poll the requests in the middle of transfer. If done, return the request.
460
+ rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank.
401
461
  """
402
- assert len(self.disagg_prefill_inflight_queue) > 0
462
+ if len(self.disagg_prefill_inflight_queue) == 0:
463
+ return []
403
464
 
404
465
  done_reqs = []
405
466
 
@@ -411,6 +472,14 @@ class SchedulerDisaggregationPrefillMixin:
411
472
  undone_reqs: List[Req] = []
412
473
  # Check .poll() for the reqs in disagg_prefill_inflight_queue. If Success, respond to the client and remove it from the queue
413
474
  for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
475
+
476
+ if rids_to_check is not None:
477
+ if req.rid not in rids_to_check:
478
+ undone_reqs.append(req)
479
+ continue
480
+
481
+ assert poll == KVPoll.Success or poll == KVPoll.Failed
482
+
414
483
  if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]:
415
484
  undone_reqs.append(req)
416
485
  elif poll == KVPoll.Success: # transfer done
@@ -432,11 +501,8 @@ class SchedulerDisaggregationPrefillMixin:
432
501
  req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
433
502
  )
434
503
  done_reqs.append(req)
435
-
436
- for req in done_reqs:
437
- self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(
438
- req.metadata_buffer_index
439
- )
504
+ else:
505
+ assert False, f"Unexpected polling state {poll=}"
440
506
 
441
507
  # Stream requests which have finished transfer
442
508
  self.stream_output(
@@ -444,9 +510,32 @@ class SchedulerDisaggregationPrefillMixin:
444
510
  any(req.return_logprob for req in done_reqs),
445
511
  None,
446
512
  )
513
+ for req in done_reqs:
514
+ req: Req
515
+ self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
516
+ req.metadata_buffer_index = -1
447
517
 
448
518
  self.disagg_prefill_inflight_queue = undone_reqs
449
519
 
520
+ return done_reqs
521
+
522
+ def get_transferred_rids(self: Scheduler) -> List[str]:
523
+ """
524
+ Used by PP, get the transferred rids but **do not pop**
525
+ """
526
+ polls = poll_and_all_reduce(
527
+ [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
528
+ self.tp_worker.get_tp_group().cpu_group,
529
+ )
530
+
531
+ transferred_rids: List[str] = []
532
+
533
+ for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
534
+ if poll == KVPoll.Success or poll == KVPoll.Failed:
535
+ transferred_rids.append(req.rid)
536
+
537
+ return transferred_rids
538
+
450
539
  def process_prefill_chunk(self: Scheduler) -> None:
451
540
  if self.last_batch and self.last_batch.forward_mode.is_extend():
452
541
  if self.chunked_req: