sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/decode.py

@@ -21,20 +21,19 @@ Life cycle of a request in the decode server
 from __future__ import annotations
 
 import logging
-import os
 from collections import deque
 from dataclasses import dataclass
 from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
-import numpy as np
 import torch
 from torch.distributed import ProcessGroup
 
-from sglang.srt.
+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
+from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
 from sglang.srt.disaggregation.utils import (
+    FAKE_BOOTSTRAP_HOST,
     DisaggregationMode,
-    FakeBootstrapHost,
     KVClassType,
     MetadataBuffers,
     ReqToMetadataIdxAllocator,
@@ -46,10 +45,12 @@ from sglang.srt.disaggregation.utils import (
     prepare_abort,
 )
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import
+from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.utils import require_mlp_sync
 
 logger = logging.getLogger(__name__)
 
@@ -86,7 +87,7 @@ class DecodeReqToTokenPool:
         self.max_context_len = max_context_len
         self.device = device
         self.pre_alloc_size = pre_alloc_size
-        with memory_saver_adapter.region():
+        with memory_saver_adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE):
             self.req_to_token = torch.zeros(
                 (size + pre_alloc_size, max_context_len),
                 dtype=torch.int32,
@@ -135,7 +136,7 @@ class DecodePreallocQueue:
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator:
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         draft_token_to_kv_pool: Optional[KVCache],
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
         metadata_buffers: MetadataBuffers,
@@ -145,7 +146,12 @@ class DecodePreallocQueue:
         gloo_group: ProcessGroup,
         tp_rank: int,
         tp_size: int,
+        dp_size: int,
+        gpu_id: int,
         bootstrap_port: int,
+        max_total_num_tokens: int,
+        prefill_pp_size: int,
+        num_reserved_decode_tokens: int,
         transfer_backend: TransferBackend,
     ):
         self.req_to_token_pool = req_to_token_pool
@@ -161,25 +167,33 @@ class DecodePreallocQueue:
         self.gloo_group = gloo_group
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.dp_size = dp_size
+        self.gpu_id = gpu_id
         self.bootstrap_port = bootstrap_port
-
-        self.
-
-
-
+        self.max_total_num_tokens = max_total_num_tokens
+        self.prefill_pp_size = prefill_pp_size
+        self.num_reserved_decode_tokens = num_reserved_decode_tokens
+        self.transfer_backend = transfer_backend
         # Queue for requests pending pre-allocation
         self.queue: List[DecodeRequest] = []
-        self.
+        self.retracted_queue: List[Req] = []
+        self.prefill_pp_size = prefill_pp_size
         self.kv_manager = self._init_kv_manager()
 
     def _init_kv_manager(self) -> BaseKVManager:
-
-        kv_args
+        kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
+        kv_args = kv_args_class()
+
+        attn_tp_size = self.tp_size // self.dp_size
+        kv_args.engine_rank = self.tp_rank % (attn_tp_size)
+        kv_args.decode_tp_size = attn_tp_size
+        kv_args.prefill_pp_size = self.prefill_pp_size
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
             self.token_to_kv_pool.get_contiguous_buf_infos()
         )
-
         if self.draft_token_to_kv_pool is not None:
+            # We should also transfer draft model kv cache. The indices are
+            # always shared with a target model.
             draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
                 self.draft_token_to_kv_pool.get_contiguous_buf_infos()
             )
@@ -194,6 +208,7 @@ class DecodePreallocQueue:
         kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
             self.metadata_buffers.get_buf_infos()
         )
+
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
@@ -205,27 +220,84 @@ class DecodePreallocQueue:
         )
         return kv_manager
 
-    def add(self, req: Req) -> None:
+    def add(self, req: Req, is_retracted: bool = False) -> None:
         """Add a request to the pending queue."""
-        if req
-
-
+        if self._check_if_req_exceed_kv_capacity(req):
+            return
+
+        if is_retracted:
+            self.retracted_queue.append(req)
         else:
-
-
+            if req.bootstrap_host == FAKE_BOOTSTRAP_HOST:
+                kv_receiver_class = get_kv_class(
+                    TransferBackend.FAKE, KVClassType.RECEIVER
+                )
+            else:
+                kv_receiver_class = get_kv_class(
+                    self.transfer_backend, KVClassType.RECEIVER
+                )
+
+            kv_receiver = kv_receiver_class(
+                mgr=self.kv_manager,
+                bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
+                bootstrap_room=req.bootstrap_room,
+                data_parallel_rank=req.data_parallel_rank,
             )
-            kv_receiver = kv_receiver_class(
-                mgr=self.kv_manager,
-                bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
-                bootstrap_room=req.bootstrap_room,
-                data_parallel_rank=req.data_parallel_rank,
-            )
-            self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
 
-
+            self.queue.append(
+                DecodeRequest(req=req, kv_receiver=kv_receiver, waiting_for_input=False)
+            )
+
+    def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
+        if len(req.origin_input_ids) > self.max_total_num_tokens:
+            message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
+            logger.error(message)
+            prepare_abort(req, message)
+            self.scheduler.stream_output([req], req.return_logprob)
+            return True
+        return False
+
+    def extend(self, reqs: List[Req], is_retracted: bool = False) -> None:
         """Add a request to the pending queue."""
         for req in reqs:
-            self.add(req)
+            self.add(req, is_retracted=is_retracted)
+
+    def resume_retracted_reqs(self) -> List[Req]:
+        # TODO refactor the scheduling part, reuse with the unified engine logic as much as possible
+
+        # allocate memory
+        resumed_reqs = []
+        indices_to_remove = set()
+        allocatable_tokens = self._allocatable_tokens(count_retracted=False)
+
+        for i, req in enumerate(self.retracted_queue):
+            if self.req_to_token_pool.available_size() <= 0:
+                break
+
+            required_tokens_for_request = (
+                len(req.origin_input_ids)
+                + len(req.output_ids)
+                + self.num_reserved_decode_tokens
+            )
+            if required_tokens_for_request > allocatable_tokens:
+                break
+
+            resumed_reqs.append(req)
+            indices_to_remove.add(i)
+            req.is_retracted = False
+            self._pre_alloc(req)
+            allocatable_tokens -= required_tokens_for_request
+
+            # load from cpu, release the cpu copy
+            req.load_kv_cache(self.req_to_token_pool, self.token_to_kv_pool_allocator)
+
+        self.retracted_queue = [
+            entry
+            for i, entry in enumerate(self.retracted_queue)
+            if i not in indices_to_remove
+        ]
+
+        return resumed_reqs
 
     def _update_handshake_waiters(self) -> None:
         if not self.queue:
@@ -255,6 +327,8 @@ class DecodePreallocQueue:
                     error_message,
                     status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
                 )
+            else:
+                raise ValueError(f"Unexpected poll case: {poll}")
 
     def pop_preallocated(self) -> List[DecodeRequest]:
         """Pop the preallocated requests from the pending queue (FIFO)."""
@@ -262,8 +336,16 @@ class DecodePreallocQueue:
 
         preallocated_reqs = []
         indices_to_remove = set()
-        allocatable_tokens = self._allocatable_tokens()
 
+        # We need to make sure that the sum of inflight tokens and allocatable tokens is greater than maximum input+output length of each inflight request
+        # Otherwise it is possible for one request running decode out of memory, while all other requests are in the transfer queue that cannot be retracted.
+        retractable_tokens = sum(
+            len(r.origin_input_ids) + len(r.output_ids)
+            for r in self.scheduler.running_batch.reqs
+        )
+        allocatable_tokens = self._allocatable_tokens(
+            retractable_tokens=retractable_tokens, count_retracted=True
+        )
         # First, remove all failed requests from the queue
         for i, decode_req in enumerate(self.queue):
             if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
@@ -272,6 +354,7 @@ class DecodePreallocQueue:
                 )
                 indices_to_remove.add(i)
 
+        # Then, preallocate the remaining requests if possible
        for i, decode_req in enumerate(self.queue):
            if i in indices_to_remove:
                continue
@@ -285,10 +368,23 @@ class DecodePreallocQueue:
             if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
                 break
 
+            # Memory estimation: don't add if the projected memory cannot be met
+            # TODO: add new_token ratio
+            origin_input_len = len(decode_req.req.origin_input_ids)
             required_tokens_for_request = (
-
+                origin_input_len + self.num_reserved_decode_tokens
             )
 
+            if (
+                max(
+                    required_tokens_for_request,
+                    origin_input_len
+                    + decode_req.req.sampling_params.max_new_tokens
+                    - retractable_tokens,
+                )
+                > allocatable_tokens
+            ):
+                break
             if required_tokens_for_request > allocatable_tokens:
                 break
 
@@ -301,7 +397,6 @@ class DecodePreallocQueue:
                 ]
                 .cpu()
                 .numpy()
-                .astype(np.int64)
             )
 
             decode_req.metadata_buffer_index = (
@@ -321,15 +416,35 @@ class DecodePreallocQueue:
 
         return preallocated_reqs
 
-    def _allocatable_tokens(
-
-
-
+    def _allocatable_tokens(
+        self, retractable_tokens: Optional[int] = None, count_retracted: bool = True
+    ) -> int:
+        need_space_for_single_req = (
+            max(
+                [
+                    x.sampling_params.max_new_tokens
+                    + len(x.origin_input_ids)
+                    - retractable_tokens
+                    for x in self.scheduler.running_batch.reqs
+                ]
+            )
+            if retractable_tokens is not None
+            and len(self.scheduler.running_batch.reqs) > 0
+            else 0
+        )
+
+        available_size = self.token_to_kv_pool_allocator.available_size()
+
+        allocatable_tokens = available_size - max(
+            # preserve some space for future decode
+            self.num_reserved_decode_tokens
             * (
                 len(self.scheduler.running_batch.reqs)
                 + len(self.transfer_queue.queue)
                 + len(self.scheduler.waiting_queue)
-            )
+            ),
+            # make sure each request can finish if reach max_tokens with all other requests retracted
+            need_space_for_single_req,
         )
 
         # Note: if the last fake extend just finishes, and we enter `pop_preallocated` immediately in the next iteration
@@ -342,15 +457,27 @@ class DecodePreallocQueue:
                 self.scheduler.last_batch.reqs
             )
 
+        if count_retracted:
+            allocatable_tokens -= sum(
+                [
+                    len(req.origin_input_ids)
+                    + len(req.output_ids)
+                    + self.num_reserved_decode_tokens
+                    for req in self.retracted_queue
+                ]
+            )
         return allocatable_tokens
 
     def _pre_alloc(self, req: Req) -> torch.Tensor:
         """Pre-allocate the memory for req_to_token and token_kv_pool"""
         req_pool_indices = self.req_to_token_pool.alloc(1)
 
-        assert
+        assert (
+            req_pool_indices is not None
+        ), "req_pool_indices is full! There is a bug in memory estimation."
 
         req.req_pool_idx = req_pool_indices[0]
+
         if self.token_to_kv_pool_allocator.page_size == 1:
             kv_loc = self.token_to_kv_pool_allocator.alloc(
                 len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
@@ -375,7 +502,10 @@ class DecodePreallocQueue:
                 ),
                 extend_num_tokens=num_tokens,
             )
-
+
+        assert (
+            kv_loc is not None
+        ), "KV cache is full! There is a bug in memory estimation."
 
         self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
 
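
To make the memory estimation above concrete, here is a simplified, standalone sketch of the token budget that _allocatable_tokens and pop_preallocated now enforce (the names and numbers below are illustrative only, not the sglang API):

    # free_kv: free slots in the KV-cache allocator
    # reserved: num_reserved_decode_tokens kept per in-flight request
    # inflight: running + transferring + waiting requests
    # single_req_need: tokens the largest request may still need even if every
    #                  other request is retracted (prompt + max_new_tokens - retractable)
    def allocatable_tokens(free_kv, reserved, inflight, single_req_need):
        # Reserve room for future decode steps, but never hand out so much that
        # the largest single request could no longer finish.
        return free_kv - max(reserved * inflight, single_req_need)

    # Example: 10_000 free slots, 512 reserved tokens for each of 6 in-flight
    # requests, and one request that may still need 4_000 tokens.
    print(allocatable_tokens(10_000, 512, 6, 4_000))  # 10_000 - max(3_072, 4_000) = 6_000

A request is admitted from the prealloc queue only when both its immediate need (prompt plus reserved decode tokens) and its worst-case need (prompt plus max_new_tokens minus what retraction could free) fit in this budget; tokens owed to retracted requests waiting to resume are subtracted from the budget as well.
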
@@ -395,6 +525,7 @@ class DecodeTransferQueue:
         self,
         gloo_group: ProcessGroup,
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
+        tp_rank: int,
         metadata_buffers: MetadataBuffers,
         scheduler: Scheduler,
         tree_cache: BasePrefixCache,
@@ -402,9 +533,11 @@ class DecodeTransferQueue:
         self.queue: List[DecodeRequest] = []
         self.gloo_group = gloo_group
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
+        self.tp_rank = tp_rank
         self.metadata_buffers = metadata_buffers
         self.scheduler = scheduler
         self.tree_cache = tree_cache
+        self.spec_algorithm = scheduler.spec_algorithm
 
     def add(self, decode_req: DecodeRequest) -> None:
         self.queue.append(decode_req)
@@ -412,10 +545,9 @@ class DecodeTransferQueue:
     def extend(self, decode_reqs: List[DecodeRequest]) -> None:
         self.queue.extend(decode_reqs)
 
-    def pop_transferred(self) -> List[
+    def pop_transferred(self) -> List[Req]:
         if not self.queue:
             return []
-
         polls = poll_and_all_reduce(
             [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
         )
@@ -424,7 +556,7 @@ class DecodeTransferQueue:
         indices_to_remove = set()
         for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
             if poll == KVPoll.Failed:
-                error_message = f"Decode transfer failed for request rank={self.
+                error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
                 try:
                     decode_req.kv_receiver.failure_exception()
                 except Exception as e:
@@ -447,6 +579,7 @@ class DecodeTransferQueue:
                 idx = decode_req.metadata_buffer_index
                 (
                     output_id,
+                    output_hidden_states,
                     output_token_logprobs_val,
                     output_token_logprobs_idx,
                     output_top_logprobs_val,
@@ -454,7 +587,8 @@ class DecodeTransferQueue:
                 ) = self.metadata_buffers.get_buf(idx)
 
                 decode_req.req.output_ids.append(output_id[0].item())
-
+                if not self.spec_algorithm.is_none():
+                    decode_req.req.hidden_states_tensor = output_hidden_states
                 if decode_req.req.return_logprob:
                     decode_req.req.output_token_logprobs_val.append(
                         output_token_logprobs_val[0].item()
@@ -499,15 +633,6 @@ class DecodeTransferQueue:
 
 class SchedulerDisaggregationDecodeMixin:
 
-    def _prepare_idle_batch_and_run(self, batch, delay_process=False):
-        batch, _ = self.prepare_dp_attn_batch(batch)
-        result = None
-        if batch:
-            result = self.run_batch(batch)
-            if not delay_process:
-                self.process_batch_result(batch, result)
-        return batch, result
-
     @torch.no_grad()
     def event_loop_normal_disagg_decode(self: Scheduler):
         """A normal scheduler loop for decode worker in disaggregation mode."""
@@ -520,10 +645,7 @@ class SchedulerDisaggregationDecodeMixin:
             batch = self.get_next_disagg_decode_batch_to_run()
             self.cur_batch = batch
 
-
-                self.server_args.enable_dp_attention
-                or self.server_args.enable_sp_layernorm
-            )
+            prepare_mlp_sync_flag = require_mlp_sync(self.server_args)
 
             if batch:
                 # Generate fake extend output.
@@ -532,24 +654,26 @@ class SchedulerDisaggregationDecodeMixin:
                     self.stream_output(
                         batch.reqs, any(req.return_logprob for req in batch.reqs)
                     )
-                    if
+                    if prepare_mlp_sync_flag:
                         self._prepare_idle_batch_and_run(None)
                 else:
-                    if
-                        self.
+                    if prepare_mlp_sync_flag:
+                        self.prepare_mlp_sync_batch(batch)
                     result = self.run_batch(batch)
                     self.process_batch_result(batch, result)
-            elif
+            elif prepare_mlp_sync_flag:
                 batch, _ = self._prepare_idle_batch_and_run(None)
 
             if batch is None and (
-                len(self.
+                len(self.waiting_queue)
+                + len(self.disagg_decode_transfer_queue.queue)
                 + len(self.disagg_decode_prealloc_queue.queue)
                 == 0
             ):
                 # When the server is idle, do self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
+                self.maybe_sleep_on_idle()
 
             self.last_batch = batch
 
@@ -568,10 +692,7 @@ class SchedulerDisaggregationDecodeMixin:
             self.cur_batch = batch
             last_batch_in_queue = False
 
-
-                self.server_args.enable_dp_attention
-                or self.server_args.enable_sp_layernorm
-            )
+            prepare_mlp_sync_flag = require_mlp_sync(self.server_args)
 
             if batch:
                 # Generate fake extend output.
@@ -580,7 +701,7 @@ class SchedulerDisaggregationDecodeMixin:
                     self.stream_output(
                         batch.reqs, any(req.return_logprob for req in batch.reqs)
                     )
-                    if
+                    if prepare_mlp_sync_flag:
                         batch_, result = self._prepare_idle_batch_and_run(
                             None, delay_process=True
                         )
@@ -588,8 +709,8 @@ class SchedulerDisaggregationDecodeMixin:
                             result_queue.append((batch_.copy(), result))
                             last_batch_in_queue = True
                 else:
-                    if
-                        self.
+                    if prepare_mlp_sync_flag:
+                        self.prepare_mlp_sync_batch(batch)
                     result = self.run_batch(batch)
                     result_queue.append((batch.copy(), result))
 
@@ -604,7 +725,7 @@ class SchedulerDisaggregationDecodeMixin:
                     self.set_next_batch_sampling_info_done(tmp_batch)
                 last_batch_in_queue = True
 
-            elif
+            elif prepare_mlp_sync_flag:
                 batch, result = self._prepare_idle_batch_and_run(
                     None, delay_process=True
                 )
@@ -621,17 +742,28 @@ class SchedulerDisaggregationDecodeMixin:
                 self.process_batch_result(tmp_batch, tmp_result)
 
             if batch is None and (
-                len(self.
+                len(self.waiting_queue)
+                + len(self.disagg_decode_transfer_queue.queue)
                 + len(self.disagg_decode_prealloc_queue.queue)
                 == 0
             ):
                 # When the server is idle, do self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
+                self.maybe_sleep_on_idle()
 
             self.last_batch = batch
             self.last_batch_in_queue = last_batch_in_queue
 
+    def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
+        batch, _ = self.prepare_mlp_sync_batch(batch)
+        result = None
+        if batch:
+            result = self.run_batch(batch)
+            if not delay_process:
+                self.process_batch_result(batch, result)
+        return batch, result
+
     def get_next_disagg_decode_batch_to_run(
         self: Scheduler,
     ) -> Optional[Tuple[ScheduleBatch, bool]]:
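
Both decode event loops now gate DP-attention synchronization behind the new require_mlp_sync helper instead of checking server flags inline. A hedged sketch of the condition it replaces, based only on the flags removed in these hunks (the actual helper lives in sglang/srt/utils.py and may check more than this):

    def require_mlp_sync(server_args) -> bool:
        # Roughly: MLP/batch synchronization is needed whenever DP attention
        # or SP layernorm is enabled.
        return server_args.enable_dp_attention or server_args.enable_sp_layernorm
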
@@ -714,6 +846,13 @@ class SchedulerDisaggregationDecodeMixin:
         return new_batch
 
     def process_decode_queue(self: Scheduler):
+        # try to resume retracted requests if there are enough space for another `num_reserved_decode_tokens` decode steps
+        resumed_reqs = self.disagg_decode_prealloc_queue.resume_retracted_reqs()
+        self.waiting_queue.extend(resumed_reqs)
+        if len(self.disagg_decode_prealloc_queue.retracted_queue) > 0:
+            # if there are still retracted requests, we do not allocate new requests
+            return
+
         req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
         self.disagg_decode_transfer_queue.extend(req_conns)
         alloc_reqs = (
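
The scheduling priority this hunk introduces, as a plain-Python sketch (simplified names, not the sglang API): retracted requests are resumed before any new prefilled request is pre-allocated, so a memory-starved decode server drains its retracted queue first.

    def process_decode_queue_sketch(prealloc_queue, transfer_queue, waiting_queue):
        # 1. Resume retracted requests whenever enough KV space has been freed.
        waiting_queue.extend(prealloc_queue.resume_retracted_reqs())
        # 2. While any retracted request is still waiting, admit no new work.
        if prealloc_queue.retracted_queue:
            return
        # 3. Only then pre-allocate newly arrived requests for KV transfer.
        transfer_queue.extend(prealloc_queue.pop_preallocated())
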

sglang/srt/disaggregation/decode_schedule_batch_mixin.py

@@ -126,15 +126,16 @@ class ScheduleBatchDisaggregationDecodeMixin:
         )
         topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
 
+        hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
+        hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
+
         # local import to avoid circular import
         from sglang.srt.speculative.eagle_utils import EagleDraftInput
 
         spec_info = EagleDraftInput(
             topk_p=topk_p,
             topk_index=topk_index,
-            hidden_states=
-                (b, model_config.hidden_size), device=self.device
-            ),
+            hidden_states=hidden_states,
             verified_id=self.output_ids,
         )
         spec_info.prepare_for_extend(self)
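
With speculative decoding enabled, the decode server now receives each request's last hidden state from the prefill side (via the output_hidden_states metadata buffer added above) and stacks the per-request tensors into the batch tensor expected by EagleDraftInput, instead of allocating a placeholder. A minimal shape illustration, assuming each request carries a (hidden_size,) tensor (sizes below are hypothetical):

    import torch

    batch_size, hidden_size = 3, 4096
    per_req_hidden = [torch.randn(hidden_size) for _ in range(batch_size)]

    # torch.stack along dim=0 yields the (batch_size, hidden_size) tensor used
    # for EagleDraftInput.hidden_states.
    hidden_states = torch.stack(per_req_hidden, dim=0)
    assert hidden_states.shape == (batch_size, hidden_size)
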

sglang/srt/disaggregation/fake/__init__.py

@@ -1 +1 @@
-from .conn import FakeKVReceiver, FakeKVSender
+from sglang.srt.disaggregation.fake.conn import FakeKVReceiver, FakeKVSender

sglang/srt/disaggregation/fake/conn.py

@@ -1,5 +1,5 @@
 import logging
-from typing import
+from typing import List, Optional
 
 import numpy as np
 import numpy.typing as npt
@@ -8,7 +8,6 @@ from sglang.srt.disaggregation.base.conn import (
     BaseKVManager,
     BaseKVReceiver,
     BaseKVSender,
-    KVArgs,
     KVPoll,
 )
 
@@ -17,7 +16,14 @@ logger = logging.getLogger(__name__)
 
 # For warmup reqs, we don't kv transfer, we use the fake sender and receiver
 class FakeKVSender(BaseKVSender):
-    def __init__(
+    def __init__(
+        self,
+        mgr: BaseKVManager,
+        bootstrap_addr: str,
+        bootstrap_room: int,
+        dest_tp_ranks: List[int],
+        pp_rank: int,
+    ):
         self.has_sent = False
 
     def poll(self) -> KVPoll:
@@ -26,7 +32,7 @@ class FakeKVSender(BaseKVSender):
             return KVPoll.WaitingForInput
         else:
             # Assume transfer completed instantly
-            logger.
+            logger.debug("FakeKVSender poll success")
             return KVPoll.Success
 
     def init(
@@ -34,17 +40,17 @@ class FakeKVSender(BaseKVSender):
         kv_indices: list[int],
         aux_index: Optional[int] = None,
     ):
-        logger.
+        logger.debug(
             f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}"
         )
         pass
 
     def send(
         self,
-        kv_indices: npt.NDArray[np.
+        kv_indices: npt.NDArray[np.int32],
     ):
         self.has_sent = True
-        logger.
+        logger.debug(f"FakeKVSender send with kv_indices: {kv_indices}")
 
     def failure_exception(self):
         raise Exception("Fake KVSender Exception")
@@ -66,12 +72,12 @@ class FakeKVReceiver(BaseKVReceiver):
             return KVPoll.WaitingForInput
         else:
             # Assume transfer completed instantly
-            logger.
+            logger.debug("FakeKVReceiver poll success")
             return KVPoll.Success
 
     def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
         self.has_init = True
-        logger.
+        logger.debug(
             f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
         )
 