sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_serving.py +1 -1
  4. sglang/lang/interpreter.py +40 -1
  5. sglang/lang/ir.py +27 -0
  6. sglang/math_utils.py +8 -0
  7. sglang/srt/configs/model_config.py +6 -0
  8. sglang/srt/conversation.py +6 -0
  9. sglang/srt/disaggregation/base/__init__.py +1 -1
  10. sglang/srt/disaggregation/base/conn.py +25 -11
  11. sglang/srt/disaggregation/common/__init__.py +5 -1
  12. sglang/srt/disaggregation/common/utils.py +42 -0
  13. sglang/srt/disaggregation/decode.py +196 -51
  14. sglang/srt/disaggregation/fake/__init__.py +1 -1
  15. sglang/srt/disaggregation/fake/conn.py +15 -9
  16. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +18 -13
  18. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  19. sglang/srt/disaggregation/nixl/conn.py +17 -12
  20. sglang/srt/disaggregation/prefill.py +128 -43
  21. sglang/srt/disaggregation/utils.py +127 -123
  22. sglang/srt/entrypoints/engine.py +15 -1
  23. sglang/srt/entrypoints/http_server.py +13 -2
  24. sglang/srt/eplb_simulator/__init__.py +1 -0
  25. sglang/srt/eplb_simulator/reader.py +51 -0
  26. sglang/srt/layers/activation.py +19 -0
  27. sglang/srt/layers/attention/aiter_backend.py +15 -2
  28. sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
  29. sglang/srt/layers/attention/flashattention_backend.py +53 -64
  30. sglang/srt/layers/attention/flashinfer_backend.py +1 -2
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
  32. sglang/srt/layers/attention/flashmla_backend.py +2 -10
  33. sglang/srt/layers/attention/triton_backend.py +119 -119
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  35. sglang/srt/layers/attention/vision.py +51 -24
  36. sglang/srt/layers/communicator.py +23 -5
  37. sglang/srt/layers/linear.py +0 -4
  38. sglang/srt/layers/logits_processor.py +0 -12
  39. sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
  40. sglang/srt/layers/moe/ep_moe/layer.py +42 -32
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  42. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
  43. sglang/srt/layers/moe/topk.py +16 -8
  44. sglang/srt/layers/pooler.py +56 -0
  45. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  46. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  47. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  48. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  49. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  50. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  51. sglang/srt/layers/radix_attention.py +2 -3
  52. sglang/srt/lora/lora_manager.py +79 -34
  53. sglang/srt/lora/mem_pool.py +4 -5
  54. sglang/srt/managers/cache_controller.py +2 -1
  55. sglang/srt/managers/io_struct.py +28 -4
  56. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  57. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  58. sglang/srt/managers/schedule_batch.py +39 -6
  59. sglang/srt/managers/scheduler.py +73 -17
  60. sglang/srt/managers/tokenizer_manager.py +29 -2
  61. sglang/srt/mem_cache/chunk_cache.py +1 -0
  62. sglang/srt/mem_cache/hiradix_cache.py +4 -2
  63. sglang/srt/mem_cache/memory_pool.py +111 -407
  64. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  65. sglang/srt/mem_cache/radix_cache.py +36 -12
  66. sglang/srt/model_executor/cuda_graph_runner.py +122 -55
  67. sglang/srt/model_executor/forward_batch_info.py +14 -5
  68. sglang/srt/model_executor/model_runner.py +6 -6
  69. sglang/srt/model_loader/loader.py +8 -1
  70. sglang/srt/models/bert.py +113 -13
  71. sglang/srt/models/deepseek_v2.py +113 -155
  72. sglang/srt/models/internvl.py +46 -102
  73. sglang/srt/models/roberta.py +117 -9
  74. sglang/srt/models/vila.py +305 -0
  75. sglang/srt/openai_api/adapter.py +162 -4
  76. sglang/srt/openai_api/protocol.py +37 -1
  77. sglang/srt/sampling/sampling_batch_info.py +24 -0
  78. sglang/srt/sampling/sampling_params.py +2 -0
  79. sglang/srt/server_args.py +318 -233
  80. sglang/srt/speculative/build_eagle_tree.py +1 -1
  81. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
  82. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
  83. sglang/srt/speculative/eagle_utils.py +389 -109
  84. sglang/srt/speculative/eagle_worker.py +134 -43
  85. sglang/srt/two_batch_overlap.py +4 -2
  86. sglang/srt/utils.py +58 -0
  87. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  88. sglang/test/runners.py +38 -3
  89. sglang/test/test_block_fp8.py +1 -0
  90. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  91. sglang/test/test_block_fp8_ep.py +1 -0
  92. sglang/test/test_utils.py +3 -1
  93. sglang/utils.py +9 -0
  94. sglang/version.py +1 -1
  95. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
  96. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
  97. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
@@ -24,10 +24,8 @@ from sglang.srt.disaggregation.common.conn import (
24
24
  CommonKVManager,
25
25
  CommonKVReceiver,
26
26
  )
27
- from sglang.srt.disaggregation.utils import (
28
- DisaggregationMode,
29
- group_concurrent_contiguous,
30
- )
27
+ from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
28
+ from sglang.srt.disaggregation.utils import DisaggregationMode
31
29
  from sglang.srt.server_args import ServerArgs
32
30
  from sglang.srt.utils import get_local_ip_by_remote
33
31
 
@@ -46,7 +44,7 @@ class TransferInfo:
46
44
  agent_metadata: bytes
47
45
  agent_name: str
48
46
  dst_kv_ptrs: list[int]
49
- dst_kv_indices: npt.NDArray[np.int64]
47
+ dst_kv_indices: npt.NDArray[np.int32]
50
48
  dst_aux_ptrs: list[int]
51
49
  dst_aux_index: int
52
50
  dst_gpu_id: int
@@ -64,7 +62,7 @@ class TransferInfo:
64
62
  agent_metadata=msg[3],
65
63
  agent_name=msg[4].decode("ascii"),
66
64
  dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
67
- dst_kv_indices=np.frombuffer(msg[6], dtype=np.int64),
65
+ dst_kv_indices=np.frombuffer(msg[6], dtype=np.int32),
68
66
  dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
69
67
  dst_aux_index=int(msg[8].decode("ascii")),
70
68
  dst_gpu_id=int(msg[9].decode("ascii")),
@@ -164,9 +162,9 @@ class NixlKVManager(CommonKVManager):
164
162
  def send_kvcache(
165
163
  self,
166
164
  peer_name: str,
167
- prefill_kv_indices: npt.NDArray[np.int64],
165
+ prefill_kv_indices: npt.NDArray[np.int32],
168
166
  dst_kv_ptrs: list[int],
169
- dst_kv_indices: npt.NDArray[np.int64],
167
+ dst_kv_indices: npt.NDArray[np.int32],
170
168
  dst_gpu_id: int,
171
169
  notif: str,
172
170
  ):
@@ -248,7 +246,7 @@ class NixlKVManager(CommonKVManager):
248
246
  def add_transfer_request(
249
247
  self,
250
248
  bootstrap_room: int,
251
- kv_indices: npt.NDArray[np.int64],
249
+ kv_indices: npt.NDArray[np.int32],
252
250
  index_slice: slice,
253
251
  is_last: bool,
254
252
  chunk_id: int,
@@ -350,7 +348,14 @@ class NixlKVManager(CommonKVManager):
350
348
 
351
349
  class NixlKVSender(BaseKVSender):
352
350
 
353
- def __init__(self, mgr: NixlKVManager, bootstrap_addr: str, bootstrap_room: int):
351
+ def __init__(
352
+ self,
353
+ mgr: NixlKVManager,
354
+ bootstrap_addr: str,
355
+ bootstrap_room: int,
356
+ dest_tp_ranks: List[int],
357
+ pp_rank: int,
358
+ ):
354
359
  self.kv_mgr = mgr
355
360
  self.bootstrap_room = bootstrap_room
356
361
  self.aux_index = None
@@ -368,7 +373,7 @@ class NixlKVSender(BaseKVSender):
368
373
 
369
374
  def send(
370
375
  self,
371
- kv_indices: npt.NDArray[np.int64],
376
+ kv_indices: npt.NDArray[np.int32],
372
377
  ):
373
378
  index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
374
379
  self.curr_idx += len(kv_indices)
@@ -412,7 +417,7 @@ class NixlKVReceiver(CommonKVReceiver):
412
417
  self.started_transfer = False
413
418
  super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
414
419
 
415
- def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
420
+ def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
416
421
  for bootstrap_info in self.bootstrap_infos:
417
422
  self.prefill_server_url = (
418
423
  f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
@@ -25,12 +25,13 @@ from collections import deque
25
25
  from http import HTTPStatus
26
26
  from typing import TYPE_CHECKING, List, Optional
27
27
 
28
+ import numpy as np
28
29
  import torch
29
30
 
30
- from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll
31
+ from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
31
32
  from sglang.srt.disaggregation.utils import (
33
+ FAKE_BOOTSTRAP_HOST,
32
34
  DisaggregationMode,
33
- FakeBootstrapHost,
34
35
  KVClassType,
35
36
  MetadataBuffers,
36
37
  ReqToMetadataIdxAllocator,
@@ -51,7 +52,6 @@ if TYPE_CHECKING:
51
52
  from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
52
53
  from sglang.srt.mem_cache.memory_pool import KVCache
53
54
 
54
-
55
55
  logger = logging.getLogger(__name__)
56
56
 
57
57
 
@@ -68,35 +68,45 @@ class PrefillBootstrapQueue:
68
68
  metadata_buffers: MetadataBuffers,
69
69
  tp_rank: int,
70
70
  tp_size: int,
71
+ gpu_id: int,
71
72
  bootstrap_port: int,
72
73
  gloo_group: ProcessGroup,
73
- transfer_backend: TransferBackend,
74
+ max_total_num_tokens: int,
75
+ decode_tp_size: int,
76
+ decode_dp_size: int,
74
77
  scheduler: Scheduler,
78
+ pp_rank: int,
79
+ pp_size: int,
80
+ transfer_backend: TransferBackend,
75
81
  ):
76
82
  self.token_to_kv_pool = token_to_kv_pool
77
83
  self.draft_token_to_kv_pool = draft_token_to_kv_pool
78
-
79
84
  self.is_mla_backend = is_mla_backend(token_to_kv_pool)
80
-
81
85
  self.metadata_buffers = metadata_buffers
82
86
  self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
83
87
  self.tp_rank = tp_rank
84
88
  self.tp_size = tp_size
85
- self.transfer_backend = transfer_backend
86
- self.scheduler = scheduler
87
- self.kv_manager = self._init_kv_manager()
89
+ self.decode_tp_size = decode_tp_size
90
+ self.decode_dp_size = decode_dp_size
91
+ self.pp_rank = pp_rank
92
+ self.pp_size = pp_size
93
+ self.gpu_id = gpu_id
94
+ self.bootstrap_port = bootstrap_port
88
95
  self.queue: List[Req] = []
96
+ self.pp_rank = pp_rank
97
+ self.pp_size = pp_size
89
98
  self.gloo_group = gloo_group
90
- self.bootstrap_port = bootstrap_port
91
-
92
- def store_prefill_results(self, idx: int, token_id: int):
93
- assert token_id >= 0, f"token_id: {token_id} is negative"
94
- output_id_buffer = self.metadata_buffers[0]
95
- output_id_buffer[idx] = token_id
99
+ self.max_total_num_tokens = max_total_num_tokens
100
+ self.scheduler = scheduler
101
+ self.transfer_backend = transfer_backend
102
+ self.kv_manager = self._init_kv_manager()
96
103
 
97
104
  def _init_kv_manager(self) -> BaseKVManager:
98
- kv_args = KVArgs()
105
+ kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
106
+ kv_args = kv_args_class()
99
107
  kv_args.engine_rank = self.tp_rank
108
+ kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
109
+ kv_args.prefill_pp_size = self.pp_size
100
110
  kv_data_ptrs, kv_data_lens, kv_item_lens = (
101
111
  self.token_to_kv_pool.get_contiguous_buf_infos()
102
112
  )
@@ -115,12 +125,12 @@ class PrefillBootstrapQueue:
115
125
  kv_args.kv_data_lens = kv_data_lens
116
126
  kv_args.kv_item_lens = kv_item_lens
117
127
 
118
- # Define req -> input ids buffer
119
128
  kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
120
129
  self.metadata_buffers.get_buf_infos()
121
130
  )
122
131
  kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
123
132
  kv_args.gpu_id = self.scheduler.gpu_id
133
+
124
134
  kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
125
135
  kv_manager = kv_manager_class(
126
136
  kv_args,
@@ -130,23 +140,39 @@ class PrefillBootstrapQueue:
130
140
  )
131
141
  return kv_manager
132
142
 
133
- def add(self, req: Req) -> None:
134
- if req.bootstrap_host == FakeBootstrapHost:
135
- # Fake transfer for warmup reqs
143
+ def add(self, req: Req, num_kv_heads: int) -> None:
144
+ if self._check_if_req_exceed_kv_capacity(req):
145
+ return
146
+
147
+ if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
136
148
  kv_sender_class = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER)
137
149
  else:
138
150
  kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
151
+
152
+ dest_tp_ranks = [self.tp_rank]
153
+
139
154
  req.disagg_kv_sender = kv_sender_class(
140
155
  mgr=self.kv_manager,
141
156
  bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
142
157
  bootstrap_room=req.bootstrap_room,
158
+ dest_tp_ranks=dest_tp_ranks,
159
+ pp_rank=self.pp_rank,
143
160
  )
144
161
  self._process_req(req)
145
162
  self.queue.append(req)
146
163
 
147
- def extend(self, reqs: List[Req]) -> None:
164
+ def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
148
165
  for req in reqs:
149
- self.add(req)
166
+ self.add(req, num_kv_heads)
167
+
168
+ def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
169
+ if len(req.origin_input_ids) > self.max_total_num_tokens:
170
+ message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
171
+ logger.error(message)
172
+ prepare_abort(req, message)
173
+ self.scheduler.stream_output([req], req.return_logprob)
174
+ return True
175
+ return False
150
176
 
151
177
  def _process_req(self, req: Req) -> None:
152
178
  """
@@ -154,19 +180,40 @@ class PrefillBootstrapQueue:
154
180
  """
155
181
  req.sampling_params.max_new_tokens = 1
156
182
 
157
- def pop_bootstrapped(self) -> List[Req]:
158
- """pop the reqs which has finished bootstrapping"""
183
+ def pop_bootstrapped(
184
+ self,
185
+ return_failed_reqs: bool = False,
186
+ rids_to_check: Optional[List[str]] = None,
187
+ ) -> List[Req]:
188
+ """
189
+ pop the reqs which has finished bootstrapping
190
+
191
+ return_failed_reqs: For PP, on rank 0, also return the failed reqs to notify the next rank
192
+ rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank.
193
+ """
194
+
159
195
  bootstrapped_reqs = []
196
+ failed_reqs = []
160
197
  indices_to_remove = set()
161
198
 
162
199
  if len(self.queue) == 0:
163
- return []
200
+ if return_failed_reqs is False:
201
+ return []
202
+ else:
203
+ return [], []
164
204
 
165
205
  polls = poll_and_all_reduce(
166
206
  [req.disagg_kv_sender for req in self.queue], self.gloo_group
167
207
  )
168
-
169
208
  for i, (req, poll) in enumerate(zip(self.queue, polls)):
209
+
210
+ if rids_to_check is not None:
211
+ # if req not in reqs_info_to_check, skip
212
+ if req.rid not in rids_to_check:
213
+ continue
214
+ # Either waiting for input or failed
215
+ assert poll == KVPoll.WaitingForInput or poll == KVPoll.Failed
216
+
170
217
  if poll == KVPoll.Bootstrapping:
171
218
  continue
172
219
  elif poll == KVPoll.Failed:
@@ -181,9 +228,10 @@ class PrefillBootstrapQueue:
181
228
  )
182
229
  self.scheduler.stream_output([req], req.return_logprob)
183
230
  indices_to_remove.add(i)
231
+ failed_reqs.append(req)
184
232
  continue
185
233
 
186
- # KV.WaitingForInput
234
+ # KV.WaitingForInput - init here
187
235
  num_kv_indices = len(req.origin_input_ids)
188
236
  if self.req_to_metadata_buffer_idx_allocator.available_size() == 0:
189
237
  break
@@ -192,9 +240,9 @@ class PrefillBootstrapQueue:
192
240
  self.req_to_metadata_buffer_idx_allocator.alloc()
193
241
  )
194
242
  assert req.metadata_buffer_index is not None
243
+
195
244
  num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
196
245
  req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
197
-
198
246
  bootstrapped_reqs.append(req)
199
247
  indices_to_remove.add(i)
200
248
 
@@ -202,7 +250,10 @@ class PrefillBootstrapQueue:
202
250
  entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
203
251
  ]
204
252
 
205
- return bootstrapped_reqs
253
+ if return_failed_reqs is False:
254
+ return bootstrapped_reqs
255
+ else:
256
+ return bootstrapped_reqs, failed_reqs
206
257
 
207
258
 
208
259
  class SchedulerDisaggregationPrefillMixin:
@@ -211,7 +262,7 @@ class SchedulerDisaggregationPrefillMixin:
211
262
  """
212
263
 
213
264
  @torch.no_grad()
214
- def event_loop_normal_disagg_prefill(self: Scheduler):
265
+ def event_loop_normal_disagg_prefill(self: Scheduler) -> None:
215
266
  """A normal scheduler loop for prefill worker in disaggregation mode."""
216
267
 
217
268
  while True:
@@ -229,7 +280,6 @@ class SchedulerDisaggregationPrefillMixin:
229
280
  or self.server_args.enable_sp_layernorm
230
281
  ):
231
282
  batch, _ = self.prepare_dp_attn_batch(batch)
232
-
233
283
  self.cur_batch = batch
234
284
 
235
285
  if batch:
@@ -242,6 +292,7 @@ class SchedulerDisaggregationPrefillMixin:
242
292
  if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
243
293
  self.check_memory()
244
294
  self.new_token_ratio = self.init_new_token_ratio
295
+ self.maybe_sleep_on_idle()
245
296
 
246
297
  self.last_batch = batch
247
298
  # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
@@ -249,7 +300,7 @@ class SchedulerDisaggregationPrefillMixin:
249
300
  self.running_batch.batch_is_full = False
250
301
 
251
302
  @torch.no_grad()
252
- def event_loop_overlap_disagg_prefill(self: Scheduler):
303
+ def event_loop_overlap_disagg_prefill(self: Scheduler) -> None:
253
304
  self.result_queue = deque()
254
305
 
255
306
  while True:
@@ -267,9 +318,7 @@ class SchedulerDisaggregationPrefillMixin:
267
318
  or self.server_args.enable_sp_layernorm
268
319
  ):
269
320
  batch, _ = self.prepare_dp_attn_batch(batch)
270
-
271
321
  self.cur_batch = batch
272
-
273
322
  if batch:
274
323
  result = self.run_batch(batch)
275
324
  self.result_queue.append((batch.copy(), result))
@@ -286,6 +335,9 @@ class SchedulerDisaggregationPrefillMixin:
286
335
 
287
336
  if self.last_batch:
288
337
  tmp_batch, tmp_result = self.result_queue.popleft()
338
+ tmp_batch.next_batch_sampling_info = (
339
+ self.tp_worker.cur_sampling_info if batch else None
340
+ )
289
341
  self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
290
342
 
291
343
  if len(self.disagg_prefill_inflight_queue) > 0:
@@ -294,6 +346,7 @@ class SchedulerDisaggregationPrefillMixin:
294
346
  if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
295
347
  self.check_memory()
296
348
  self.new_token_ratio = self.init_new_token_ratio
349
+ self.maybe_sleep_on_idle()
297
350
 
298
351
  self.last_batch = batch
299
352
  # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
@@ -307,7 +360,7 @@ class SchedulerDisaggregationPrefillMixin:
307
360
  launch_done: Optional[threading.Event] = None,
308
361
  ) -> None:
309
362
  """
310
- Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
363
+ Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
311
364
  Adapted from process_batch_result_prefill
312
365
  """
313
366
  (
@@ -323,7 +376,7 @@ class SchedulerDisaggregationPrefillMixin:
323
376
  )
324
377
 
325
378
  logprob_pt = 0
326
- # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
379
+ # Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
327
380
  if self.enable_overlap:
328
381
  # wait
329
382
  logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
@@ -395,11 +448,15 @@ class SchedulerDisaggregationPrefillMixin:
395
448
  # We need to remove the sync in the following function for overlap schedule.
396
449
  self.set_next_batch_sampling_info_done(batch)
397
450
 
398
- def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
451
+ def process_disagg_prefill_inflight_queue(
452
+ self: Scheduler, rids_to_check: Optional[List[str]] = None
453
+ ) -> List[Req]:
399
454
  """
400
455
  Poll the requests in the middle of transfer. If done, return the request.
456
+ rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank.
401
457
  """
402
- assert len(self.disagg_prefill_inflight_queue) > 0
458
+ if len(self.disagg_prefill_inflight_queue) == 0:
459
+ return []
403
460
 
404
461
  done_reqs = []
405
462
 
@@ -411,6 +468,14 @@ class SchedulerDisaggregationPrefillMixin:
411
468
  undone_reqs: List[Req] = []
412
469
  # Check .poll() for the reqs in disagg_prefill_inflight_queue. If Success, respond to the client and remove it from the queue
413
470
  for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
471
+
472
+ if rids_to_check is not None:
473
+ if req.rid not in rids_to_check:
474
+ undone_reqs.append(req)
475
+ continue
476
+
477
+ assert poll == KVPoll.Success or poll == KVPoll.Failed
478
+
414
479
  if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]:
415
480
  undone_reqs.append(req)
416
481
  elif poll == KVPoll.Success: # transfer done
@@ -432,11 +497,8 @@ class SchedulerDisaggregationPrefillMixin:
432
497
  req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
433
498
  )
434
499
  done_reqs.append(req)
435
-
436
- for req in done_reqs:
437
- self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(
438
- req.metadata_buffer_index
439
- )
500
+ else:
501
+ assert False, f"Unexpected polling state {poll=}"
440
502
 
441
503
  # Stream requests which have finished transfer
442
504
  self.stream_output(
@@ -444,9 +506,32 @@ class SchedulerDisaggregationPrefillMixin:
444
506
  any(req.return_logprob for req in done_reqs),
445
507
  None,
446
508
  )
509
+ for req in done_reqs:
510
+ req: Req
511
+ self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
512
+ req.metadata_buffer_index = -1
447
513
 
448
514
  self.disagg_prefill_inflight_queue = undone_reqs
449
515
 
516
+ return done_reqs
517
+
518
+ def get_transferred_rids(self: Scheduler) -> List[str]:
519
+ """
520
+ Used by PP, get the transferred rids but **do not pop**
521
+ """
522
+ polls = poll_and_all_reduce(
523
+ [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
524
+ self.tp_worker.get_tp_group().cpu_group,
525
+ )
526
+
527
+ transferred_rids: List[str] = []
528
+
529
+ for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
530
+ if poll == KVPoll.Success or poll == KVPoll.Failed:
531
+ transferred_rids.append(req.rid)
532
+
533
+ return transferred_rids
534
+
450
535
  def process_prefill_chunk(self: Scheduler) -> None:
451
536
  if self.last_batch and self.last_batch.forward_mode.is_extend():
452
537
  if self.chunked_req: