sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
@@ -27,10 +27,10 @@ from typing import TYPE_CHECKING, List, Optional
|
|
27
27
|
|
28
28
|
import torch
|
29
29
|
|
30
|
-
from sglang.srt.disaggregation.base import BaseKVManager,
|
30
|
+
from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
|
31
31
|
from sglang.srt.disaggregation.utils import (
|
32
|
+
FAKE_BOOTSTRAP_HOST,
|
32
33
|
DisaggregationMode,
|
33
|
-
FakeBootstrapHost,
|
34
34
|
KVClassType,
|
35
35
|
MetadataBuffers,
|
36
36
|
ReqToMetadataIdxAllocator,
|
@@ -44,6 +44,7 @@ from sglang.srt.disaggregation.utils import (
|
|
44
44
|
)
|
45
45
|
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
46
46
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
47
|
+
from sglang.srt.utils import require_mlp_sync
|
47
48
|
|
48
49
|
if TYPE_CHECKING:
|
49
50
|
from torch.distributed import ProcessGroup
|
@@ -51,7 +52,6 @@ if TYPE_CHECKING:
|
|
51
52
|
from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
|
52
53
|
from sglang.srt.mem_cache.memory_pool import KVCache
|
53
54
|
|
54
|
-
|
55
55
|
logger = logging.getLogger(__name__)
|
56
56
|
|
57
57
|
|
@@ -68,35 +68,45 @@ class PrefillBootstrapQueue:
|
|
68
68
|
metadata_buffers: MetadataBuffers,
|
69
69
|
tp_rank: int,
|
70
70
|
tp_size: int,
|
71
|
+
gpu_id: int,
|
71
72
|
bootstrap_port: int,
|
72
73
|
gloo_group: ProcessGroup,
|
73
|
-
|
74
|
+
max_total_num_tokens: int,
|
75
|
+
decode_tp_size: int,
|
76
|
+
decode_dp_size: int,
|
74
77
|
scheduler: Scheduler,
|
78
|
+
pp_rank: int,
|
79
|
+
pp_size: int,
|
80
|
+
transfer_backend: TransferBackend,
|
75
81
|
):
|
76
82
|
self.token_to_kv_pool = token_to_kv_pool
|
77
83
|
self.draft_token_to_kv_pool = draft_token_to_kv_pool
|
78
|
-
|
79
84
|
self.is_mla_backend = is_mla_backend(token_to_kv_pool)
|
80
|
-
|
81
85
|
self.metadata_buffers = metadata_buffers
|
82
86
|
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
|
83
87
|
self.tp_rank = tp_rank
|
84
88
|
self.tp_size = tp_size
|
85
|
-
self.
|
86
|
-
self.
|
87
|
-
self.
|
89
|
+
self.decode_tp_size = decode_tp_size
|
90
|
+
self.decode_dp_size = decode_dp_size
|
91
|
+
self.pp_rank = pp_rank
|
92
|
+
self.pp_size = pp_size
|
93
|
+
self.gpu_id = gpu_id
|
94
|
+
self.bootstrap_port = bootstrap_port
|
88
95
|
self.queue: List[Req] = []
|
96
|
+
self.pp_rank = pp_rank
|
97
|
+
self.pp_size = pp_size
|
89
98
|
self.gloo_group = gloo_group
|
90
|
-
self.
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
output_id_buffer = self.metadata_buffers[0]
|
95
|
-
output_id_buffer[idx] = token_id
|
99
|
+
self.max_total_num_tokens = max_total_num_tokens
|
100
|
+
self.scheduler = scheduler
|
101
|
+
self.transfer_backend = transfer_backend
|
102
|
+
self.kv_manager = self._init_kv_manager()
|
96
103
|
|
97
104
|
def _init_kv_manager(self) -> BaseKVManager:
|
98
|
-
|
105
|
+
kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
|
106
|
+
kv_args = kv_args_class()
|
99
107
|
kv_args.engine_rank = self.tp_rank
|
108
|
+
kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
|
109
|
+
kv_args.prefill_pp_size = self.pp_size
|
100
110
|
kv_data_ptrs, kv_data_lens, kv_item_lens = (
|
101
111
|
self.token_to_kv_pool.get_contiguous_buf_infos()
|
102
112
|
)
|
@@ -115,12 +125,12 @@ class PrefillBootstrapQueue:
|
|
115
125
|
kv_args.kv_data_lens = kv_data_lens
|
116
126
|
kv_args.kv_item_lens = kv_item_lens
|
117
127
|
|
118
|
-
# Define req -> input ids buffer
|
119
128
|
kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
|
120
129
|
self.metadata_buffers.get_buf_infos()
|
121
130
|
)
|
122
131
|
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
123
132
|
kv_args.gpu_id = self.scheduler.gpu_id
|
133
|
+
|
124
134
|
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
125
135
|
kv_manager = kv_manager_class(
|
126
136
|
kv_args,
|
@@ -130,23 +140,39 @@ class PrefillBootstrapQueue:
|
|
130
140
|
)
|
131
141
|
return kv_manager
|
132
142
|
|
133
|
-
def add(self, req: Req) -> None:
|
134
|
-
if req
|
135
|
-
|
143
|
+
def add(self, req: Req, num_kv_heads: int) -> None:
|
144
|
+
if self._check_if_req_exceed_kv_capacity(req):
|
145
|
+
return
|
146
|
+
|
147
|
+
if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
|
136
148
|
kv_sender_class = get_kv_class(TransferBackend.FAKE, KVClassType.SENDER)
|
137
149
|
else:
|
138
150
|
kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
|
151
|
+
|
152
|
+
dest_tp_ranks = [self.tp_rank]
|
153
|
+
|
139
154
|
req.disagg_kv_sender = kv_sender_class(
|
140
155
|
mgr=self.kv_manager,
|
141
156
|
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
|
142
157
|
bootstrap_room=req.bootstrap_room,
|
158
|
+
dest_tp_ranks=dest_tp_ranks,
|
159
|
+
pp_rank=self.pp_rank,
|
143
160
|
)
|
144
161
|
self._process_req(req)
|
145
162
|
self.queue.append(req)
|
146
163
|
|
147
|
-
def extend(self, reqs: List[Req]) -> None:
|
164
|
+
def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
|
148
165
|
for req in reqs:
|
149
|
-
self.add(req)
|
166
|
+
self.add(req, num_kv_heads)
|
167
|
+
|
168
|
+
def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
|
169
|
+
if len(req.origin_input_ids) > self.max_total_num_tokens:
|
170
|
+
message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
|
171
|
+
logger.error(message)
|
172
|
+
prepare_abort(req, message)
|
173
|
+
self.scheduler.stream_output([req], req.return_logprob)
|
174
|
+
return True
|
175
|
+
return False
|
150
176
|
|
151
177
|
def _process_req(self, req: Req) -> None:
|
152
178
|
"""
|
@@ -154,19 +180,40 @@ class PrefillBootstrapQueue:
|
|
154
180
|
"""
|
155
181
|
req.sampling_params.max_new_tokens = 1
|
156
182
|
|
157
|
-
def pop_bootstrapped(
|
158
|
-
|
183
|
+
def pop_bootstrapped(
|
184
|
+
self,
|
185
|
+
return_failed_reqs: bool = False,
|
186
|
+
rids_to_check: Optional[List[str]] = None,
|
187
|
+
) -> List[Req]:
|
188
|
+
"""
|
189
|
+
pop the reqs which has finished bootstrapping
|
190
|
+
|
191
|
+
return_failed_reqs: For PP, on rank 0, also return the failed reqs to notify the next rank
|
192
|
+
rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank.
|
193
|
+
"""
|
194
|
+
|
159
195
|
bootstrapped_reqs = []
|
196
|
+
failed_reqs = []
|
160
197
|
indices_to_remove = set()
|
161
198
|
|
162
199
|
if len(self.queue) == 0:
|
163
|
-
|
200
|
+
if return_failed_reqs is False:
|
201
|
+
return []
|
202
|
+
else:
|
203
|
+
return [], []
|
164
204
|
|
165
205
|
polls = poll_and_all_reduce(
|
166
206
|
[req.disagg_kv_sender for req in self.queue], self.gloo_group
|
167
207
|
)
|
168
|
-
|
169
208
|
for i, (req, poll) in enumerate(zip(self.queue, polls)):
|
209
|
+
|
210
|
+
if rids_to_check is not None:
|
211
|
+
# if req not in reqs_info_to_check, skip
|
212
|
+
if req.rid not in rids_to_check:
|
213
|
+
continue
|
214
|
+
# Either waiting for input or failed
|
215
|
+
assert poll == KVPoll.WaitingForInput or poll == KVPoll.Failed
|
216
|
+
|
170
217
|
if poll == KVPoll.Bootstrapping:
|
171
218
|
continue
|
172
219
|
elif poll == KVPoll.Failed:
|
@@ -181,9 +228,10 @@ class PrefillBootstrapQueue:
|
|
181
228
|
)
|
182
229
|
self.scheduler.stream_output([req], req.return_logprob)
|
183
230
|
indices_to_remove.add(i)
|
231
|
+
failed_reqs.append(req)
|
184
232
|
continue
|
185
233
|
|
186
|
-
# KV.WaitingForInput
|
234
|
+
# KV.WaitingForInput - init here
|
187
235
|
num_kv_indices = len(req.origin_input_ids)
|
188
236
|
if self.req_to_metadata_buffer_idx_allocator.available_size() == 0:
|
189
237
|
break
|
@@ -192,9 +240,9 @@ class PrefillBootstrapQueue:
|
|
192
240
|
self.req_to_metadata_buffer_idx_allocator.alloc()
|
193
241
|
)
|
194
242
|
assert req.metadata_buffer_index is not None
|
243
|
+
|
195
244
|
num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
|
196
245
|
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
|
197
|
-
|
198
246
|
bootstrapped_reqs.append(req)
|
199
247
|
indices_to_remove.add(i)
|
200
248
|
|
@@ -202,7 +250,10 @@ class PrefillBootstrapQueue:
|
|
202
250
|
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
|
203
251
|
]
|
204
252
|
|
205
|
-
|
253
|
+
if return_failed_reqs is False:
|
254
|
+
return bootstrapped_reqs
|
255
|
+
else:
|
256
|
+
return bootstrapped_reqs, failed_reqs
|
206
257
|
|
207
258
|
|
208
259
|
class SchedulerDisaggregationPrefillMixin:
|
@@ -211,7 +262,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|
211
262
|
"""
|
212
263
|
|
213
264
|
@torch.no_grad()
|
214
|
-
def event_loop_normal_disagg_prefill(self: Scheduler):
|
265
|
+
def event_loop_normal_disagg_prefill(self: Scheduler) -> None:
|
215
266
|
"""A normal scheduler loop for prefill worker in disaggregation mode."""
|
216
267
|
|
217
268
|
while True:
|
@@ -223,13 +274,8 @@ class SchedulerDisaggregationPrefillMixin:
|
|
223
274
|
self.process_prefill_chunk()
|
224
275
|
batch = self.get_new_batch_prefill()
|
225
276
|
|
226
|
-
|
227
|
-
|
228
|
-
self.server_args.enable_dp_attention
|
229
|
-
or self.server_args.enable_sp_layernorm
|
230
|
-
):
|
231
|
-
batch, _ = self.prepare_dp_attn_batch(batch)
|
232
|
-
|
277
|
+
if require_mlp_sync(self.server_args):
|
278
|
+
batch, _ = self.prepare_mlp_sync_batch(batch)
|
233
279
|
self.cur_batch = batch
|
234
280
|
|
235
281
|
if batch:
|
@@ -242,6 +288,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|
242
288
|
if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
|
243
289
|
self.check_memory()
|
244
290
|
self.new_token_ratio = self.init_new_token_ratio
|
291
|
+
self.maybe_sleep_on_idle()
|
245
292
|
|
246
293
|
self.last_batch = batch
|
247
294
|
# HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
|
@@ -249,7 +296,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|
249
296
|
self.running_batch.batch_is_full = False
|
250
297
|
|
251
298
|
@torch.no_grad()
|
252
|
-
def event_loop_overlap_disagg_prefill(self: Scheduler):
|
299
|
+
def event_loop_overlap_disagg_prefill(self: Scheduler) -> None:
|
253
300
|
self.result_queue = deque()
|
254
301
|
|
255
302
|
while True:
|
@@ -261,15 +308,9 @@ class SchedulerDisaggregationPrefillMixin:
|
|
261
308
|
self.process_prefill_chunk()
|
262
309
|
batch = self.get_new_batch_prefill()
|
263
310
|
|
264
|
-
|
265
|
-
|
266
|
-
self.server_args.enable_dp_attention
|
267
|
-
or self.server_args.enable_sp_layernorm
|
268
|
-
):
|
269
|
-
batch, _ = self.prepare_dp_attn_batch(batch)
|
270
|
-
|
311
|
+
if require_mlp_sync(self.server_args):
|
312
|
+
batch, _ = self.prepare_mlp_sync_batch(batch)
|
271
313
|
self.cur_batch = batch
|
272
|
-
|
273
314
|
if batch:
|
274
315
|
result = self.run_batch(batch)
|
275
316
|
self.result_queue.append((batch.copy(), result))
|
@@ -286,6 +327,9 @@ class SchedulerDisaggregationPrefillMixin:
|
|
286
327
|
|
287
328
|
if self.last_batch:
|
288
329
|
tmp_batch, tmp_result = self.result_queue.popleft()
|
330
|
+
tmp_batch.next_batch_sampling_info = (
|
331
|
+
self.tp_worker.cur_sampling_info if batch else None
|
332
|
+
)
|
289
333
|
self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
|
290
334
|
|
291
335
|
if len(self.disagg_prefill_inflight_queue) > 0:
|
@@ -294,6 +338,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|
294
338
|
if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
|
295
339
|
self.check_memory()
|
296
340
|
self.new_token_ratio = self.init_new_token_ratio
|
341
|
+
self.maybe_sleep_on_idle()
|
297
342
|
|
298
343
|
self.last_batch = batch
|
299
344
|
# HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
|
@@ -307,7 +352,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|
307
352
|
launch_done: Optional[threading.Event] = None,
|
308
353
|
) -> None:
|
309
354
|
"""
|
310
|
-
Transfer kv for prefill completed requests and add it into
|
355
|
+
Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
|
311
356
|
Adapted from process_batch_result_prefill
|
312
357
|
"""
|
313
358
|
(
|
@@ -323,7 +368,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|
323
368
|
)
|
324
369
|
|
325
370
|
logprob_pt = 0
|
326
|
-
# Transfer kv for prefill completed requests and add it into
|
371
|
+
# Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
|
327
372
|
if self.enable_overlap:
|
328
373
|
# wait
|
329
374
|
logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
|
@@ -340,6 +385,8 @@ class SchedulerDisaggregationPrefillMixin:
|
|
340
385
|
logits_output.input_token_logprobs = tuple(
|
341
386
|
logits_output.input_token_logprobs.tolist()
|
342
387
|
)
|
388
|
+
|
389
|
+
hidden_state_offset = 0
|
343
390
|
for i, (req, next_token_id) in enumerate(
|
344
391
|
zip(batch.reqs, next_token_ids, strict=True)
|
345
392
|
):
|
@@ -349,6 +396,16 @@ class SchedulerDisaggregationPrefillMixin:
|
|
349
396
|
req.output_ids.append(next_token_id)
|
350
397
|
self.tree_cache.cache_unfinished_req(req) # update the tree and lock
|
351
398
|
self.disagg_prefill_inflight_queue.append(req)
|
399
|
+
if logits_output.hidden_states is not None:
|
400
|
+
last_hidden_index = (
|
401
|
+
hidden_state_offset + extend_input_len_per_req[i] - 1
|
402
|
+
)
|
403
|
+
req.hidden_states_tensor = (
|
404
|
+
logits_output.hidden_states[last_hidden_index].cpu().clone()
|
405
|
+
)
|
406
|
+
hidden_state_offset += extend_input_len_per_req[i]
|
407
|
+
else:
|
408
|
+
req.hidden_states_tensor = None
|
352
409
|
if req.return_logprob:
|
353
410
|
assert extend_logprob_start_len_per_req is not None
|
354
411
|
assert extend_input_len_per_req is not None
|
@@ -395,11 +452,15 @@ class SchedulerDisaggregationPrefillMixin:
|
|
395
452
|
# We need to remove the sync in the following function for overlap schedule.
|
396
453
|
self.set_next_batch_sampling_info_done(batch)
|
397
454
|
|
398
|
-
def process_disagg_prefill_inflight_queue(
|
455
|
+
def process_disagg_prefill_inflight_queue(
|
456
|
+
self: Scheduler, rids_to_check: Optional[List[str]] = None
|
457
|
+
) -> List[Req]:
|
399
458
|
"""
|
400
459
|
Poll the requests in the middle of transfer. If done, return the request.
|
460
|
+
rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank.
|
401
461
|
"""
|
402
|
-
|
462
|
+
if len(self.disagg_prefill_inflight_queue) == 0:
|
463
|
+
return []
|
403
464
|
|
404
465
|
done_reqs = []
|
405
466
|
|
@@ -411,6 +472,14 @@ class SchedulerDisaggregationPrefillMixin:
|
|
411
472
|
undone_reqs: List[Req] = []
|
412
473
|
# Check .poll() for the reqs in disagg_prefill_inflight_queue. If Success, respond to the client and remove it from the queue
|
413
474
|
for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
|
475
|
+
|
476
|
+
if rids_to_check is not None:
|
477
|
+
if req.rid not in rids_to_check:
|
478
|
+
undone_reqs.append(req)
|
479
|
+
continue
|
480
|
+
|
481
|
+
assert poll == KVPoll.Success or poll == KVPoll.Failed
|
482
|
+
|
414
483
|
if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]:
|
415
484
|
undone_reqs.append(req)
|
416
485
|
elif poll == KVPoll.Success: # transfer done
|
@@ -432,11 +501,8 @@ class SchedulerDisaggregationPrefillMixin:
|
|
432
501
|
req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
|
433
502
|
)
|
434
503
|
done_reqs.append(req)
|
435
|
-
|
436
|
-
|
437
|
-
self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(
|
438
|
-
req.metadata_buffer_index
|
439
|
-
)
|
504
|
+
else:
|
505
|
+
assert False, f"Unexpected polling state {poll=}"
|
440
506
|
|
441
507
|
# Stream requests which have finished transfer
|
442
508
|
self.stream_output(
|
@@ -444,9 +510,32 @@ class SchedulerDisaggregationPrefillMixin:
|
|
444
510
|
any(req.return_logprob for req in done_reqs),
|
445
511
|
None,
|
446
512
|
)
|
513
|
+
for req in done_reqs:
|
514
|
+
req: Req
|
515
|
+
self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
|
516
|
+
req.metadata_buffer_index = -1
|
447
517
|
|
448
518
|
self.disagg_prefill_inflight_queue = undone_reqs
|
449
519
|
|
520
|
+
return done_reqs
|
521
|
+
|
522
|
+
def get_transferred_rids(self: Scheduler) -> List[str]:
|
523
|
+
"""
|
524
|
+
Used by PP, get the transferred rids but **do not pop**
|
525
|
+
"""
|
526
|
+
polls = poll_and_all_reduce(
|
527
|
+
[req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
|
528
|
+
self.tp_worker.get_tp_group().cpu_group,
|
529
|
+
)
|
530
|
+
|
531
|
+
transferred_rids: List[str] = []
|
532
|
+
|
533
|
+
for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
|
534
|
+
if poll == KVPoll.Success or poll == KVPoll.Failed:
|
535
|
+
transferred_rids.append(req.rid)
|
536
|
+
|
537
|
+
return transferred_rids
|
538
|
+
|
450
539
|
def process_prefill_chunk(self: Scheduler) -> None:
|
451
540
|
if self.last_batch and self.last_batch.forward_mode.is_extend():
|
452
541
|
if self.chunked_req:
|