sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/conversation.py +6 -0
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +196 -51
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +18 -13
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +128 -43
- sglang/srt/disaggregation/utils.py +127 -123
- sglang/srt/entrypoints/engine.py +15 -1
- sglang/srt/entrypoints/http_server.py +13 -2
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/layers/activation.py +19 -0
- sglang/srt/layers/attention/aiter_backend.py +15 -2
- sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
- sglang/srt/layers/attention/flashattention_backend.py +53 -64
- sglang/srt/layers/attention/flashinfer_backend.py +1 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
- sglang/srt/layers/attention/flashmla_backend.py +2 -10
- sglang/srt/layers/attention/triton_backend.py +119 -119
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +23 -5
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +0 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
- sglang/srt/layers/moe/ep_moe/layer.py +42 -32
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
- sglang/srt/layers/moe/topk.py +16 -8
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/lora/lora_manager.py +79 -34
- sglang/srt/lora/mem_pool.py +4 -5
- sglang/srt/managers/cache_controller.py +2 -1
- sglang/srt/managers/io_struct.py +28 -4
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +39 -6
- sglang/srt/managers/scheduler.py +73 -17
- sglang/srt/managers/tokenizer_manager.py +29 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +4 -2
- sglang/srt/mem_cache/memory_pool.py +111 -407
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +36 -12
- sglang/srt/model_executor/cuda_graph_runner.py +122 -55
- sglang/srt/model_executor/forward_batch_info.py +14 -5
- sglang/srt/model_executor/model_runner.py +6 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_v2.py +113 -155
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/openai_api/adapter.py +162 -4
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +318 -233
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
- sglang/srt/speculative/eagle_utils.py +389 -109
- sglang/srt/speculative/eagle_worker.py +134 -43
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +58 -0
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/nixl/conn.py

```diff
@@ -24,10 +24,8 @@ from sglang.srt.disaggregation.common.conn import (
     CommonKVManager,
     CommonKVReceiver,
 )
-from sglang.srt.disaggregation.utils import (
-    DisaggregationMode,
-    group_concurrent_contiguous,
-)
+from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import get_local_ip_by_remote

@@ -46,7 +44,7 @@ class TransferInfo:
     agent_metadata: bytes
     agent_name: str
     dst_kv_ptrs: list[int]
-    dst_kv_indices: npt.NDArray[np.
+    dst_kv_indices: npt.NDArray[np.int32]
     dst_aux_ptrs: list[int]
     dst_aux_index: int
     dst_gpu_id: int
@@ -64,7 +62,7 @@ class TransferInfo:
             agent_metadata=msg[3],
             agent_name=msg[4].decode("ascii"),
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
-            dst_kv_indices=np.frombuffer(msg[6], dtype=np.
+            dst_kv_indices=np.frombuffer(msg[6], dtype=np.int32),
             dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
             dst_aux_index=int(msg[8].decode("ascii")),
             dst_gpu_id=int(msg[9].decode("ascii")),
@@ -164,9 +162,9 @@ class NixlKVManager(CommonKVManager):
     def send_kvcache(
         self,
         peer_name: str,
-        prefill_kv_indices: npt.NDArray[np.
+        prefill_kv_indices: npt.NDArray[np.int32],
         dst_kv_ptrs: list[int],
-        dst_kv_indices: npt.NDArray[np.
+        dst_kv_indices: npt.NDArray[np.int32],
         dst_gpu_id: int,
         notif: str,
     ):
@@ -248,7 +246,7 @@ class NixlKVManager(CommonKVManager):
     def add_transfer_request(
         self,
         bootstrap_room: int,
-        kv_indices: npt.NDArray[np.
+        kv_indices: npt.NDArray[np.int32],
         index_slice: slice,
         is_last: bool,
         chunk_id: int,
@@ -350,7 +348,14 @@ class NixlKVManager(CommonKVManager):

 class NixlKVSender(BaseKVSender):

-    def __init__(
+    def __init__(
+        self,
+        mgr: NixlKVManager,
+        bootstrap_addr: str,
+        bootstrap_room: int,
+        dest_tp_ranks: List[int],
+        pp_rank: int,
+    ):
         self.kv_mgr = mgr
         self.bootstrap_room = bootstrap_room
         self.aux_index = None
@@ -368,7 +373,7 @@ class NixlKVSender(BaseKVSender):

     def send(
         self,
-        kv_indices: npt.NDArray[np.
+        kv_indices: npt.NDArray[np.int32],
     ):
         index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
         self.curr_idx += len(kv_indices)
@@ -412,7 +417,7 @@ class NixlKVReceiver(CommonKVReceiver):
         self.started_transfer = False
        super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)

-    def init(self, kv_indices: npt.NDArray[np.
+    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
         for bootstrap_info in self.bootstrap_infos:
             self.prefill_server_url = (
                 f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
```
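The recurring change in this file narrows the KV-index annotations to `npt.NDArray[np.int32]` (the removed side is truncated in this view, so the previous dtype is not shown), and `NixlKVSender.__init__` gains explicit `dest_tp_ranks` and `pp_rank` parameters. A minimal, self-contained sketch of what the int32 wire format implies for the `msg[6]` round-trip seen in `TransferInfo.from_zmq`; the sample values and the use of `tobytes()` on the sending side are illustrative assumptions, not taken from the package:

```python
import numpy as np

# The post1 annotations make dst_kv_indices an int32 array, and the receiver
# reconstructs it with np.frombuffer(msg[6], dtype=np.int32) as shown above.
dst_kv_indices = np.array([40960, 40961, 40962, 40963], dtype=np.int32)

payload = dst_kv_indices.tobytes()                 # 4 bytes per index on the wire
restored = np.frombuffer(payload, dtype=np.int32)  # mirrors the msg[6] decoding path

assert restored.tolist() == dst_kv_indices.tolist()
assert len(payload) == dst_kv_indices.size * 4
```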
sglang/srt/disaggregation/prefill.py

```diff
@@ -25,12 +25,13 @@ from collections import deque
 from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional

+import numpy as np
 import torch

-from sglang.srt.disaggregation.base import BaseKVManager,
+from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
 from sglang.srt.disaggregation.utils import (
+    FAKE_BOOTSTRAP_HOST,
     DisaggregationMode,
-    FakeBootstrapHost,
     KVClassType,
     MetadataBuffers,
     ReqToMetadataIdxAllocator,
@@ -51,7 +52,6 @@ if TYPE_CHECKING:
     from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
     from sglang.srt.mem_cache.memory_pool import KVCache

-
 logger = logging.getLogger(__name__)


@@ -68,35 +68,45 @@ class PrefillBootstrapQueue:
         metadata_buffers: MetadataBuffers,
         tp_rank: int,
         tp_size: int,
+        gpu_id: int,
         bootstrap_port: int,
         gloo_group: ProcessGroup,
-
+        max_total_num_tokens: int,
+        decode_tp_size: int,
+        decode_dp_size: int,
         scheduler: Scheduler,
+        pp_rank: int,
+        pp_size: int,
+        transfer_backend: TransferBackend,
     ):
         self.token_to_kv_pool = token_to_kv_pool
         self.draft_token_to_kv_pool = draft_token_to_kv_pool
-
         self.is_mla_backend = is_mla_backend(token_to_kv_pool)
-
         self.metadata_buffers = metadata_buffers
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
         self.tp_rank = tp_rank
         self.tp_size = tp_size
-        self.
-        self.
-        self.
+        self.decode_tp_size = decode_tp_size
+        self.decode_dp_size = decode_dp_size
+        self.pp_rank = pp_rank
+        self.pp_size = pp_size
+        self.gpu_id = gpu_id
+        self.bootstrap_port = bootstrap_port
         self.queue: List[Req] = []
+        self.pp_rank = pp_rank
+        self.pp_size = pp_size
         self.gloo_group = gloo_group
-        self.
-
-
-
-        output_id_buffer = self.metadata_buffers[0]
-        output_id_buffer[idx] = token_id
+        self.max_total_num_tokens = max_total_num_tokens
+        self.scheduler = scheduler
+        self.transfer_backend = transfer_backend
+        self.kv_manager = self._init_kv_manager()

     def _init_kv_manager(self) -> BaseKVManager:
-
+        kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
+        kv_args = kv_args_class()
         kv_args.engine_rank = self.tp_rank
+        kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
+        kv_args.prefill_pp_size = self.pp_size
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
             self.token_to_kv_pool.get_contiguous_buf_infos()
         )
```
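`_init_kv_manager` now advertises the decode side's per-data-parallel-group tensor-parallel size as `kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size`, alongside `prefill_pp_size`. A tiny worked example of that division, with made-up cluster sizes:

```python
# Made-up cluster shape: 16 decode TP ranks organized as 2 data-parallel groups.
decode_tp_size = 16  # total decode tensor-parallel ranks
decode_dp_size = 2   # decode data-parallel groups

# What the prefill side would report to the KV manager under the new code path:
decode_tp_size_per_dp_group = decode_tp_size // decode_dp_size
assert decode_tp_size_per_dp_group == 8
```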
```diff
@@ -115,12 +125,12 @@ class PrefillBootstrapQueue:
         kv_args.kv_data_lens = kv_data_lens
         kv_args.kv_item_lens = kv_item_lens

-        # Define req -> input ids buffer
         kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
             self.metadata_buffers.get_buf_infos()
         )
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
+
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
         kv_manager = kv_manager_class(
             kv_args,
@@ -130,23 +140,39 @@ class PrefillBootstrapQueue:
         )
         return kv_manager

-    def add(self, req: Req) -> None:
-        if req
-
+    def add(self, req: Req, num_kv_heads: int) -> None:
+        if self._check_if_req_exceed_kv_capacity(req):
+            return
+
+        if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
             kv_sender_class = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER)
         else:
             kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
+
+        dest_tp_ranks = [self.tp_rank]
+
         req.disagg_kv_sender = kv_sender_class(
             mgr=self.kv_manager,
             bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
             bootstrap_room=req.bootstrap_room,
+            dest_tp_ranks=dest_tp_ranks,
+            pp_rank=self.pp_rank,
         )
         self._process_req(req)
         self.queue.append(req)

-    def extend(self, reqs: List[Req]) -> None:
+    def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
         for req in reqs:
-            self.add(req)
+            self.add(req, num_kv_heads)
+
+    def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
+        if len(req.origin_input_ids) > self.max_total_num_tokens:
+            message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
+            logger.error(message)
+            prepare_abort(req, message)
+            self.scheduler.stream_output([req], req.return_logprob)
+            return True
+        return False

     def _process_req(self, req: Req) -> None:
         """
```
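`add()` now takes a `num_kv_heads` argument and refuses requests whose prompt alone cannot fit in the KV pool, aborting them before a sender is created. A standalone sketch of that guard; the function name and the stubbed-out abort path are illustrative, not the package's API:

```python
import logging

logger = logging.getLogger(__name__)

def check_if_req_exceeds_kv_capacity(origin_input_ids: list,
                                     max_total_num_tokens: int,
                                     rid: str) -> bool:
    """Loosely mirrors the new _check_if_req_exceed_kv_capacity: reject instead of enqueueing."""
    if len(origin_input_ids) > max_total_num_tokens:
        message = (
            f"Request {rid} exceeds the maximum number of tokens: "
            f"{len(origin_input_ids)} > {max_total_num_tokens}"
        )
        logger.error(message)
        # In the real code: prepare_abort(req, message); scheduler.stream_output([req], ...)
        return True
    return False

# A prompt longer than the KV pool is refused before any KV sender is constructed.
assert check_if_req_exceeds_kv_capacity(list(range(5000)), 4096, "req-1") is True
assert check_if_req_exceeds_kv_capacity(list(range(100)), 4096, "req-2") is False
```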
```diff
@@ -154,19 +180,40 @@ class PrefillBootstrapQueue:
         """
         req.sampling_params.max_new_tokens = 1

-    def pop_bootstrapped(
-
+    def pop_bootstrapped(
+        self,
+        return_failed_reqs: bool = False,
+        rids_to_check: Optional[List[str]] = None,
+    ) -> List[Req]:
+        """
+        pop the reqs which has finished bootstrapping
+
+        return_failed_reqs: For PP, on rank 0, also return the failed reqs to notify the next rank
+        rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank.
+        """
+
         bootstrapped_reqs = []
+        failed_reqs = []
         indices_to_remove = set()

         if len(self.queue) == 0:
-
+            if return_failed_reqs is False:
+                return []
+            else:
+                return [], []

         polls = poll_and_all_reduce(
             [req.disagg_kv_sender for req in self.queue], self.gloo_group
         )
-
         for i, (req, poll) in enumerate(zip(self.queue, polls)):
+
+            if rids_to_check is not None:
+                # if req not in reqs_info_to_check, skip
+                if req.rid not in rids_to_check:
+                    continue
+                # Either waiting for input or failed
+                assert poll == KVPoll.WaitingForInput or poll == KVPoll.Failed
+
             if poll == KVPoll.Bootstrapping:
                 continue
             elif poll == KVPoll.Failed:
@@ -181,9 +228,10 @@ class PrefillBootstrapQueue:
                 )
                 self.scheduler.stream_output([req], req.return_logprob)
                 indices_to_remove.add(i)
+                failed_reqs.append(req)
                 continue

-            # KV.WaitingForInput
+            # KV.WaitingForInput - init here
             num_kv_indices = len(req.origin_input_ids)
             if self.req_to_metadata_buffer_idx_allocator.available_size() == 0:
                 break
@@ -192,9 +240,9 @@ class PrefillBootstrapQueue:
                 self.req_to_metadata_buffer_idx_allocator.alloc()
             )
             assert req.metadata_buffer_index is not None
+
             num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
             req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
-
             bootstrapped_reqs.append(req)
             indices_to_remove.add(i)

@@ -202,7 +250,10 @@ class PrefillBootstrapQueue:
             entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
         ]

-
+        if return_failed_reqs is False:
+            return bootstrapped_reqs
+        else:
+            return bootstrapped_reqs, failed_reqs


 class SchedulerDisaggregationPrefillMixin:
```
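`pop_bootstrapped` keeps its old single-list return for ordinary callers but returns a `(bootstrapped, failed)` pair when `return_failed_reqs=True`, and can be restricted to a set of request ids via `rids_to_check`. A toy model of the return-shape contract only, with plain strings standing in for `Req` objects:

```python
from typing import List, Tuple, Union

def pop_bootstrapped_shape(bootstrapped: List[str],
                           failed: List[str],
                           return_failed_reqs: bool = False,
                           ) -> Union[List[str], Tuple[List[str], List[str]]]:
    """Toy model of the new return contract (request ids stand in for Req objects)."""
    if return_failed_reqs is False:
        return bootstrapped
    else:
        return bootstrapped, failed

# Non-PP callers keep the old single-list shape ...
reqs = pop_bootstrapped_shape(["a", "b"], ["c"])
assert reqs == ["a", "b"]

# ... while PP rank 0 also collects the failed requests to forward to the next rank.
reqs, failed = pop_bootstrapped_shape(["a", "b"], ["c"], return_failed_reqs=True)
assert failed == ["c"]
```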
```diff
@@ -211,7 +262,7 @@ class SchedulerDisaggregationPrefillMixin:
     """

     @torch.no_grad()
-    def event_loop_normal_disagg_prefill(self: Scheduler):
+    def event_loop_normal_disagg_prefill(self: Scheduler) -> None:
         """A normal scheduler loop for prefill worker in disaggregation mode."""

         while True:
@@ -229,7 +280,6 @@ class SchedulerDisaggregationPrefillMixin:
                 or self.server_args.enable_sp_layernorm
             ):
                 batch, _ = self.prepare_dp_attn_batch(batch)
-
             self.cur_batch = batch

             if batch:
@@ -242,6 +292,7 @@ class SchedulerDisaggregationPrefillMixin:
             if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
+                self.maybe_sleep_on_idle()

             self.last_batch = batch
             # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
@@ -249,7 +300,7 @@ class SchedulerDisaggregationPrefillMixin:
             self.running_batch.batch_is_full = False

     @torch.no_grad()
-    def event_loop_overlap_disagg_prefill(self: Scheduler):
+    def event_loop_overlap_disagg_prefill(self: Scheduler) -> None:
         self.result_queue = deque()

         while True:
@@ -267,9 +318,7 @@ class SchedulerDisaggregationPrefillMixin:
                 or self.server_args.enable_sp_layernorm
             ):
                 batch, _ = self.prepare_dp_attn_batch(batch)
-
             self.cur_batch = batch
-
             if batch:
                 result = self.run_batch(batch)
                 self.result_queue.append((batch.copy(), result))
@@ -286,6 +335,9 @@ class SchedulerDisaggregationPrefillMixin:

             if self.last_batch:
                 tmp_batch, tmp_result = self.result_queue.popleft()
+                tmp_batch.next_batch_sampling_info = (
+                    self.tp_worker.cur_sampling_info if batch else None
+                )
                 self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)

             if len(self.disagg_prefill_inflight_queue) > 0:
@@ -294,6 +346,7 @@ class SchedulerDisaggregationPrefillMixin:
             if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
+                self.maybe_sleep_on_idle()

             self.last_batch = batch
             # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
@@ -307,7 +360,7 @@ class SchedulerDisaggregationPrefillMixin:
         launch_done: Optional[threading.Event] = None,
     ) -> None:
         """
-        Transfer kv for prefill completed requests and add it into
+        Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
         Adapted from process_batch_result_prefill
         """
         (
@@ -323,7 +376,7 @@ class SchedulerDisaggregationPrefillMixin:
         )

         logprob_pt = 0
-        # Transfer kv for prefill completed requests and add it into
+        # Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
         if self.enable_overlap:
             # wait
             logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
@@ -395,11 +448,15 @@ class SchedulerDisaggregationPrefillMixin:
         # We need to remove the sync in the following function for overlap schedule.
         self.set_next_batch_sampling_info_done(batch)

-    def process_disagg_prefill_inflight_queue(
+    def process_disagg_prefill_inflight_queue(
+        self: Scheduler, rids_to_check: Optional[List[str]] = None
+    ) -> List[Req]:
         """
         Poll the requests in the middle of transfer. If done, return the request.
+        rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank.
         """
-
+        if len(self.disagg_prefill_inflight_queue) == 0:
+            return []

         done_reqs = []

@@ -411,6 +468,14 @@ class SchedulerDisaggregationPrefillMixin:
         undone_reqs: List[Req] = []
         # Check .poll() for the reqs in disagg_prefill_inflight_queue. If Success, respond to the client and remove it from the queue
         for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
+
+            if rids_to_check is not None:
+                if req.rid not in rids_to_check:
+                    undone_reqs.append(req)
+                    continue
+
+                assert poll == KVPoll.Success or poll == KVPoll.Failed
+
             if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]:
                 undone_reqs.append(req)
             elif poll == KVPoll.Success:  # transfer done
@@ -432,11 +497,8 @@ class SchedulerDisaggregationPrefillMixin:
                     req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
                 )
                 done_reqs.append(req)
-
-
-                self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(
-                    req.metadata_buffer_index
-                )
+            else:
+                assert False, f"Unexpected polling state {poll=}"

         # Stream requests which have finished transfer
         self.stream_output(
@@ -444,9 +506,32 @@ class SchedulerDisaggregationPrefillMixin:
             any(req.return_logprob for req in done_reqs),
             None,
         )
+        for req in done_reqs:
+            req: Req
+            self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
+            req.metadata_buffer_index = -1

         self.disagg_prefill_inflight_queue = undone_reqs

+        return done_reqs
+
+    def get_transferred_rids(self: Scheduler) -> List[str]:
+        """
+        Used by PP, get the transferred rids but **do not pop**
+        """
+        polls = poll_and_all_reduce(
+            [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
+            self.tp_worker.get_tp_group().cpu_group,
+        )
+
+        transferred_rids: List[str] = []
+
+        for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
+            if poll == KVPoll.Success or poll == KVPoll.Failed:
+                transferred_rids.append(req.rid)
+
+        return transferred_rids
+
     def process_prefill_chunk(self: Scheduler) -> None:
         if self.last_batch and self.last_batch.forward_mode.is_extend():
             if self.chunked_req:
```