sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
@@ -24,6 +24,7 @@ import logging
|
|
24
24
|
import os
|
25
25
|
from collections import deque
|
26
26
|
from dataclasses import dataclass
|
27
|
+
from http import HTTPStatus
|
27
28
|
from typing import TYPE_CHECKING, List, Optional, Tuple
|
28
29
|
|
29
30
|
import numpy as np
|
@@ -35,24 +36,25 @@ from sglang.srt.disaggregation.utils import (
|
|
35
36
|
DisaggregationMode,
|
36
37
|
FakeBootstrapHost,
|
37
38
|
KVClassType,
|
39
|
+
MetadataBuffers,
|
38
40
|
ReqToMetadataIdxAllocator,
|
39
41
|
TransferBackend,
|
40
42
|
get_kv_class,
|
43
|
+
is_mla_backend,
|
41
44
|
kv_to_page_indices,
|
42
45
|
poll_and_all_reduce,
|
46
|
+
prepare_abort,
|
43
47
|
)
|
48
|
+
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
|
44
49
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
45
50
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
46
51
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
47
|
-
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
48
52
|
|
49
53
|
logger = logging.getLogger(__name__)
|
50
54
|
|
51
55
|
if TYPE_CHECKING:
|
52
|
-
from sglang.srt.
|
53
|
-
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
56
|
+
from sglang.srt.managers.schedule_batch import Req
|
54
57
|
from sglang.srt.managers.scheduler import Scheduler
|
55
|
-
from sglang.srt.server_args import ServerArgs
|
56
58
|
|
57
59
|
|
58
60
|
@dataclass
|
@@ -72,9 +74,9 @@ class DecodePreallocQueue:
|
|
72
74
|
self,
|
73
75
|
req_to_token_pool: ReqToTokenPool,
|
74
76
|
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
77
|
+
draft_token_to_kv_pool: Optional[KVCache],
|
75
78
|
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
|
76
|
-
metadata_buffers:
|
77
|
-
aux_dtype: torch.dtype,
|
79
|
+
metadata_buffers: MetadataBuffers,
|
78
80
|
scheduler: Scheduler,
|
79
81
|
transfer_queue: DecodeTransferQueue,
|
80
82
|
tree_cache: BasePrefixCache,
|
@@ -87,7 +89,8 @@ class DecodePreallocQueue:
|
|
87
89
|
self.req_to_token_pool = req_to_token_pool
|
88
90
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
89
91
|
self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
|
90
|
-
self.
|
92
|
+
self.draft_token_to_kv_pool = draft_token_to_kv_pool
|
93
|
+
self.is_mla_backend = is_mla_backend(self.token_to_kv_pool)
|
91
94
|
self.metadata_buffers = metadata_buffers
|
92
95
|
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
|
93
96
|
self.scheduler = scheduler
|
@@ -114,24 +117,29 @@ class DecodePreallocQueue:
|
|
114
117
|
self.token_to_kv_pool.get_contiguous_buf_infos()
|
115
118
|
)
|
116
119
|
|
120
|
+
if self.draft_token_to_kv_pool is not None:
|
121
|
+
draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
|
122
|
+
self.draft_token_to_kv_pool.get_contiguous_buf_infos()
|
123
|
+
)
|
124
|
+
kv_data_ptrs += draft_kv_data_ptrs
|
125
|
+
kv_data_lens += draft_kv_data_lens
|
126
|
+
kv_item_lens += draft_kv_item_lens
|
127
|
+
|
117
128
|
kv_args.kv_data_ptrs = kv_data_ptrs
|
118
129
|
kv_args.kv_data_lens = kv_data_lens
|
119
130
|
kv_args.kv_item_lens = kv_item_lens
|
120
131
|
|
121
|
-
kv_args.aux_data_ptrs =
|
122
|
-
|
123
|
-
|
124
|
-
kv_args.aux_data_lens = [
|
125
|
-
metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
|
126
|
-
]
|
127
|
-
kv_args.aux_item_lens = [
|
128
|
-
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
|
129
|
-
]
|
132
|
+
kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
|
133
|
+
self.metadata_buffers.get_buf_infos()
|
134
|
+
)
|
130
135
|
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
131
136
|
kv_args.gpu_id = self.scheduler.gpu_id
|
132
137
|
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
133
138
|
kv_manager = kv_manager_class(
|
134
|
-
kv_args,
|
139
|
+
kv_args,
|
140
|
+
DisaggregationMode.DECODE,
|
141
|
+
self.scheduler.server_args,
|
142
|
+
self.is_mla_backend,
|
135
143
|
)
|
136
144
|
return kv_manager
|
137
145
|
|
@@ -173,7 +181,17 @@ class DecodePreallocQueue:
|
|
173
181
|
elif poll == KVPoll.WaitingForInput:
|
174
182
|
decode_req.waiting_for_input = True
|
175
183
|
elif poll == KVPoll.Failed:
|
176
|
-
|
184
|
+
error_message = f"Decode handshake failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
|
185
|
+
try:
|
186
|
+
decode_req.kv_receiver.failure_exception()
|
187
|
+
except Exception as e:
|
188
|
+
error_message += f" with exception {e}"
|
189
|
+
logger.error(error_message)
|
190
|
+
prepare_abort(
|
191
|
+
decode_req.req,
|
192
|
+
error_message,
|
193
|
+
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
194
|
+
)
|
177
195
|
|
178
196
|
def pop_preallocated(self) -> List[DecodeRequest]:
|
179
197
|
"""Pop the preallocated requests from the pending queue (FIFO)."""
|
@@ -183,7 +201,18 @@ class DecodePreallocQueue:
|
|
183
201
|
indices_to_remove = set()
|
184
202
|
allocatable_tokens = self._allocatable_tokens()
|
185
203
|
|
204
|
+
# First, remove all failed requests from the queue
|
186
205
|
for i, decode_req in enumerate(self.queue):
|
206
|
+
if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
|
207
|
+
self.scheduler.stream_output(
|
208
|
+
[decode_req.req], decode_req.req.return_logprob
|
209
|
+
)
|
210
|
+
indices_to_remove.add(i)
|
211
|
+
|
212
|
+
for i, decode_req in enumerate(self.queue):
|
213
|
+
if i in indices_to_remove:
|
214
|
+
continue
|
215
|
+
|
187
216
|
if not decode_req.waiting_for_input:
|
188
217
|
continue
|
189
218
|
|
@@ -303,18 +332,22 @@ class DecodeTransferQueue:
|
|
303
332
|
self,
|
304
333
|
gloo_group: ProcessGroup,
|
305
334
|
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
|
306
|
-
metadata_buffers:
|
335
|
+
metadata_buffers: MetadataBuffers,
|
336
|
+
scheduler: Scheduler,
|
337
|
+
tree_cache: BasePrefixCache,
|
307
338
|
):
|
308
339
|
self.queue: List[DecodeRequest] = []
|
309
340
|
self.gloo_group = gloo_group
|
310
341
|
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
|
311
342
|
self.metadata_buffers = metadata_buffers
|
343
|
+
self.scheduler = scheduler
|
344
|
+
self.tree_cache = tree_cache
|
312
345
|
|
313
|
-
def add(self,
|
314
|
-
self.queue.append(
|
346
|
+
def add(self, decode_req: DecodeRequest) -> None:
|
347
|
+
self.queue.append(decode_req)
|
315
348
|
|
316
|
-
def extend(self,
|
317
|
-
self.queue.extend(
|
349
|
+
def extend(self, decode_reqs: List[DecodeRequest]) -> None:
|
350
|
+
self.queue.extend(decode_reqs)
|
318
351
|
|
319
352
|
def pop_transferred(self) -> List[DecodeRequest]:
|
320
353
|
if not self.queue:
|
@@ -328,18 +361,56 @@ class DecodeTransferQueue:
|
|
328
361
|
indices_to_remove = set()
|
329
362
|
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
|
330
363
|
if poll == KVPoll.Failed:
|
331
|
-
|
364
|
+
error_message = f"Decode transfer failed for request {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
|
365
|
+
try:
|
366
|
+
decode_req.kv_receiver.failure_exception()
|
367
|
+
except Exception as e:
|
368
|
+
error_message += f" with exception {e}"
|
369
|
+
logger.error(error_message)
|
370
|
+
prepare_abort(
|
371
|
+
decode_req.req,
|
372
|
+
error_message,
|
373
|
+
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
374
|
+
)
|
375
|
+
self.scheduler.stream_output(
|
376
|
+
[decode_req.req], decode_req.req.return_logprob
|
377
|
+
)
|
378
|
+
# unlock the kv cache or it will have memory leak
|
379
|
+
self.tree_cache.cache_finished_req(decode_req.req)
|
380
|
+
indices_to_remove.add(i)
|
381
|
+
continue
|
332
382
|
elif poll == KVPoll.Success:
|
333
|
-
|
383
|
+
|
334
384
|
idx = decode_req.metadata_buffer_index
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
385
|
+
(
|
386
|
+
output_id,
|
387
|
+
output_token_logprobs_val,
|
388
|
+
output_token_logprobs_idx,
|
389
|
+
output_top_logprobs_val,
|
390
|
+
output_top_logprobs_idx,
|
391
|
+
) = self.metadata_buffers.get_buf(idx)
|
392
|
+
|
393
|
+
decode_req.req.output_ids.append(output_id[0].item())
|
394
|
+
|
395
|
+
if decode_req.req.return_logprob:
|
396
|
+
decode_req.req.output_token_logprobs_val.append(
|
397
|
+
output_token_logprobs_val[0].item()
|
398
|
+
)
|
399
|
+
decode_req.req.output_token_logprobs_idx.append(
|
400
|
+
output_token_logprobs_idx[0].item()
|
401
|
+
)
|
402
|
+
decode_req.req.output_top_logprobs_val.append(
|
403
|
+
output_top_logprobs_val[
|
404
|
+
: decode_req.req.top_logprobs_num
|
405
|
+
].tolist()
|
406
|
+
)
|
407
|
+
decode_req.req.output_top_logprobs_idx.append(
|
408
|
+
output_top_logprobs_idx[
|
409
|
+
: decode_req.req.top_logprobs_num
|
410
|
+
].tolist()
|
411
|
+
)
|
412
|
+
|
413
|
+
transferred_reqs.append(decode_req.req)
|
343
414
|
indices_to_remove.add(i)
|
344
415
|
elif poll in [
|
345
416
|
KVPoll.Bootstrapping,
|
@@ -362,95 +433,6 @@ class DecodeTransferQueue:
|
|
362
433
|
return transferred_reqs
|
363
434
|
|
364
435
|
|
365
|
-
class ScheduleBatchDisaggregationDecodeMixin:
|
366
|
-
|
367
|
-
def prepare_for_prebuilt_extend(self: ScheduleBatch):
|
368
|
-
"""
|
369
|
-
Prepare a prebuilt extend by populate metadata
|
370
|
-
Adapted from .prepare_for_extend().
|
371
|
-
"""
|
372
|
-
|
373
|
-
self.forward_mode = ForwardMode.EXTEND
|
374
|
-
reqs = self.reqs
|
375
|
-
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
|
376
|
-
extend_num_tokens = sum(len(ids) for ids in input_ids)
|
377
|
-
seq_lens = []
|
378
|
-
pre_lens = []
|
379
|
-
req_pool_indices = []
|
380
|
-
|
381
|
-
# Pre-calculate total size
|
382
|
-
total_size = sum(req.extend_input_len for req in reqs)
|
383
|
-
out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)
|
384
|
-
|
385
|
-
# Fill the tensor in one pass
|
386
|
-
offset = 0
|
387
|
-
for i, req in enumerate(reqs):
|
388
|
-
req_pool_indices.append(req.req_pool_idx)
|
389
|
-
|
390
|
-
chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
|
391
|
-
: req.extend_input_len
|
392
|
-
]
|
393
|
-
assert (
|
394
|
-
offset + req.extend_input_len <= total_size
|
395
|
-
), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
|
396
|
-
out_cache_loc[offset : offset + req.extend_input_len] = chunk
|
397
|
-
offset += req.extend_input_len
|
398
|
-
|
399
|
-
pre_len = len(req.prefix_indices)
|
400
|
-
seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
|
401
|
-
seq_lens.append(seq_len)
|
402
|
-
if len(req.output_ids) == 0:
|
403
|
-
assert (
|
404
|
-
seq_len - pre_len == req.extend_input_len
|
405
|
-
), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"
|
406
|
-
|
407
|
-
req.cached_tokens += pre_len - req.already_computed
|
408
|
-
req.already_computed = seq_len
|
409
|
-
req.is_retracted = False
|
410
|
-
pre_lens.append(pre_len)
|
411
|
-
req.extend_logprob_start_len = 0
|
412
|
-
|
413
|
-
extend_input_logprob_token_ids = None
|
414
|
-
|
415
|
-
# Set fields
|
416
|
-
self.input_ids = torch.tensor(
|
417
|
-
sum(input_ids, []), dtype=torch.int32, device=self.device
|
418
|
-
)
|
419
|
-
self.req_pool_indices = torch.tensor(
|
420
|
-
req_pool_indices, dtype=torch.int64, device=self.device
|
421
|
-
)
|
422
|
-
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
|
423
|
-
self.out_cache_loc = out_cache_loc
|
424
|
-
self.seq_lens_sum = sum(seq_lens)
|
425
|
-
self.extend_num_tokens = extend_num_tokens
|
426
|
-
self.prefix_lens = [len(r.prefix_indices) for r in reqs]
|
427
|
-
self.extend_lens = [r.extend_input_len for r in reqs]
|
428
|
-
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
429
|
-
self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
|
430
|
-
|
431
|
-
# Build sampling info
|
432
|
-
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
433
|
-
self,
|
434
|
-
self.model_config.vocab_size,
|
435
|
-
)
|
436
|
-
|
437
|
-
def process_prebuilt_extend(
|
438
|
-
self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
|
439
|
-
):
|
440
|
-
"""Assign the buffered last input id to schedule batch"""
|
441
|
-
self.output_ids = []
|
442
|
-
for req in self.reqs:
|
443
|
-
if req.output_ids and len(req.output_ids) > 0:
|
444
|
-
# resumed retracted req
|
445
|
-
self.output_ids.append(req.output_ids[-1])
|
446
|
-
else:
|
447
|
-
assert req.transferred_output_id is not None
|
448
|
-
req.output_ids.append(req.transferred_output_id)
|
449
|
-
self.output_ids.append(req.transferred_output_id)
|
450
|
-
self.tree_cache.cache_unfinished_req(req)
|
451
|
-
self.output_ids = torch.tensor(self.output_ids, device=self.device)
|
452
|
-
|
453
|
-
|
454
436
|
class SchedulerDisaggregationDecodeMixin:
|
455
437
|
|
456
438
|
def _prepare_idle_batch_and_run(self, batch, delay_process=False):
|
@@ -483,7 +465,9 @@ class SchedulerDisaggregationDecodeMixin:
|
|
483
465
|
# Generate fake extend output.
|
484
466
|
if batch.forward_mode.is_extend():
|
485
467
|
# Note: Logprobs should be handled on the prefill engine.
|
486
|
-
self.stream_output(
|
468
|
+
self.stream_output(
|
469
|
+
batch.reqs, any(req.return_logprob for req in batch.reqs)
|
470
|
+
)
|
487
471
|
if prepare_dp_attn_flag:
|
488
472
|
self._prepare_idle_batch_and_run(None)
|
489
473
|
else:
|
@@ -509,7 +493,7 @@ class SchedulerDisaggregationDecodeMixin:
|
|
509
493
|
def event_loop_overlap_disagg_decode(self: Scheduler):
|
510
494
|
result_queue = deque()
|
511
495
|
self.last_batch: Optional[ScheduleBatch] = None
|
512
|
-
self.last_batch_in_queue = False # last batch is
|
496
|
+
self.last_batch_in_queue = False # last batch is modified in-place, so we need another variable to track if it's extend
|
513
497
|
|
514
498
|
while True:
|
515
499
|
recv_reqs = self.recv_requests()
|
@@ -529,7 +513,9 @@ class SchedulerDisaggregationDecodeMixin:
|
|
529
513
|
# Generate fake extend output.
|
530
514
|
if batch.forward_mode.is_extend():
|
531
515
|
# Note: Logprobs should be handled on the prefill engine.
|
532
|
-
self.stream_output(
|
516
|
+
self.stream_output(
|
517
|
+
batch.reqs, any(req.return_logprob for req in batch.reqs)
|
518
|
+
)
|
533
519
|
if prepare_dp_attn_flag:
|
534
520
|
batch_, result = self._prepare_idle_batch_and_run(
|
535
521
|
None, delay_process=True
|
@@ -542,7 +528,18 @@ class SchedulerDisaggregationDecodeMixin:
|
|
542
528
|
self.prepare_dp_attn_batch(batch)
|
543
529
|
result = self.run_batch(batch)
|
544
530
|
result_queue.append((batch.copy(), result))
|
531
|
+
|
532
|
+
if (self.last_batch is None) or (not self.last_batch_in_queue):
|
533
|
+
# Create a dummy first batch to start the pipeline for overlap schedule.
|
534
|
+
# It is now used for triggering the sampling_info_done event.
|
535
|
+
tmp_batch = ScheduleBatch(
|
536
|
+
reqs=None,
|
537
|
+
forward_mode=ForwardMode.DUMMY_FIRST,
|
538
|
+
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
|
539
|
+
)
|
540
|
+
self.set_next_batch_sampling_info_done(tmp_batch)
|
545
541
|
last_batch_in_queue = True
|
542
|
+
|
546
543
|
elif prepare_dp_attn_flag:
|
547
544
|
batch, result = self._prepare_idle_batch_and_run(
|
548
545
|
None, delay_process=True
|
@@ -554,6 +551,9 @@ class SchedulerDisaggregationDecodeMixin:
|
|
554
551
|
# Process the results of the previous batch but skip if the last batch is extend
|
555
552
|
if self.last_batch and self.last_batch_in_queue:
|
556
553
|
tmp_batch, tmp_result = result_queue.popleft()
|
554
|
+
tmp_batch.next_batch_sampling_info = (
|
555
|
+
self.tp_worker.cur_sampling_info if batch else None
|
556
|
+
)
|
557
557
|
self.process_batch_result(tmp_batch, tmp_result)
|
558
558
|
|
559
559
|
if batch is None and (
|
@@ -602,6 +602,9 @@ class SchedulerDisaggregationDecodeMixin:
|
|
602
602
|
|
603
603
|
def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
|
604
604
|
"""Create a schedulebatch for fake completed prefill"""
|
605
|
+
if self.grammar_queue:
|
606
|
+
self.move_ready_grammar_requests()
|
607
|
+
|
605
608
|
if len(self.waiting_queue) == 0:
|
606
609
|
return None
|
607
610
|
|
@@ -627,8 +630,6 @@ class SchedulerDisaggregationDecodeMixin:
|
|
627
630
|
self.waiting_queue = waiting_queue
|
628
631
|
if len(can_run_list) == 0:
|
629
632
|
return None
|
630
|
-
# local import to avoid circular import
|
631
|
-
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
632
633
|
|
633
634
|
# construct a schedule batch with those requests and mark as decode
|
634
635
|
new_batch = ScheduleBatch.init_new(
|
@@ -650,15 +651,8 @@ class SchedulerDisaggregationDecodeMixin:
|
|
650
651
|
|
651
652
|
def process_decode_queue(self: Scheduler):
|
652
653
|
req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
|
653
|
-
|
654
|
-
def _num_pre_alloc(req):
|
655
|
-
return len(req.req.origin_input_ids) + max(len(req.req.output_ids) - 1, 0)
|
656
|
-
|
657
|
-
self.num_tokens_pre_allocated += sum(_num_pre_alloc(req) for req in req_conns)
|
658
654
|
self.disagg_decode_transfer_queue.extend(req_conns)
|
659
655
|
alloc_reqs = (
|
660
656
|
self.disagg_decode_transfer_queue.pop_transferred()
|
661
657
|
) # the requests which kv has arrived
|
662
|
-
self.
|
663
|
-
|
664
|
-
self.waiting_queue.extend([req.req for req in alloc_reqs])
|
658
|
+
self.waiting_queue.extend(alloc_reqs)
|
@@ -0,0 +1,142 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from typing import TYPE_CHECKING
|
5
|
+
|
6
|
+
import torch
|
7
|
+
|
8
|
+
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
9
|
+
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
10
|
+
|
11
|
+
logger = logging.getLogger(__name__)
|
12
|
+
|
13
|
+
if TYPE_CHECKING:
|
14
|
+
from sglang.srt.configs.model_config import ModelConfig
|
15
|
+
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
16
|
+
from sglang.srt.server_args import ServerArgs
|
17
|
+
|
18
|
+
|
19
|
+
class ScheduleBatchDisaggregationDecodeMixin:
|
20
|
+
|
21
|
+
def prepare_for_prebuilt_extend(self: ScheduleBatch):
|
22
|
+
"""
|
23
|
+
Prepare a prebuilt extend by populate metadata
|
24
|
+
Adapted from .prepare_for_extend().
|
25
|
+
"""
|
26
|
+
|
27
|
+
self.forward_mode = ForwardMode.EXTEND
|
28
|
+
reqs = self.reqs
|
29
|
+
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
|
30
|
+
extend_num_tokens = sum(len(ids) for ids in input_ids)
|
31
|
+
seq_lens = []
|
32
|
+
pre_lens = []
|
33
|
+
req_pool_indices = []
|
34
|
+
|
35
|
+
# Pre-calculate total size
|
36
|
+
total_size = sum(req.extend_input_len for req in reqs)
|
37
|
+
out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)
|
38
|
+
|
39
|
+
# Fill the tensor in one pass
|
40
|
+
offset = 0
|
41
|
+
for i, req in enumerate(reqs):
|
42
|
+
req_pool_indices.append(req.req_pool_idx)
|
43
|
+
|
44
|
+
chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
|
45
|
+
: req.extend_input_len
|
46
|
+
]
|
47
|
+
assert (
|
48
|
+
offset + req.extend_input_len <= total_size
|
49
|
+
), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
|
50
|
+
out_cache_loc[offset : offset + req.extend_input_len] = chunk
|
51
|
+
offset += req.extend_input_len
|
52
|
+
|
53
|
+
pre_len = len(req.prefix_indices)
|
54
|
+
seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
|
55
|
+
seq_lens.append(seq_len)
|
56
|
+
if len(req.output_ids) == 0:
|
57
|
+
assert (
|
58
|
+
seq_len - pre_len == req.extend_input_len
|
59
|
+
), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"
|
60
|
+
|
61
|
+
req.cached_tokens += pre_len - req.already_computed
|
62
|
+
req.already_computed = seq_len
|
63
|
+
req.is_retracted = False
|
64
|
+
pre_lens.append(pre_len)
|
65
|
+
req.extend_logprob_start_len = 0
|
66
|
+
|
67
|
+
extend_input_logprob_token_ids = None
|
68
|
+
|
69
|
+
# Set fields
|
70
|
+
self.input_ids = torch.tensor(
|
71
|
+
sum(input_ids, []), dtype=torch.int32, device=self.device
|
72
|
+
)
|
73
|
+
self.req_pool_indices = torch.tensor(
|
74
|
+
req_pool_indices, dtype=torch.int64, device=self.device
|
75
|
+
)
|
76
|
+
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
|
77
|
+
self.out_cache_loc = out_cache_loc
|
78
|
+
self.seq_lens_sum = sum(seq_lens)
|
79
|
+
|
80
|
+
if self.return_logprob:
|
81
|
+
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
82
|
+
self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
|
83
|
+
|
84
|
+
self.extend_num_tokens = extend_num_tokens
|
85
|
+
self.prefix_lens = [len(r.prefix_indices) for r in reqs]
|
86
|
+
self.extend_lens = [r.extend_input_len for r in reqs]
|
87
|
+
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
88
|
+
self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
|
89
|
+
|
90
|
+
# Build sampling info
|
91
|
+
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
92
|
+
self,
|
93
|
+
self.model_config.vocab_size,
|
94
|
+
)
|
95
|
+
|
96
|
+
def process_prebuilt_extend(
|
97
|
+
self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
|
98
|
+
):
|
99
|
+
"""Assign the buffered last input id to schedule batch"""
|
100
|
+
self.output_ids = []
|
101
|
+
for req in self.reqs:
|
102
|
+
self.output_ids.append(req.output_ids[-1])
|
103
|
+
self.tree_cache.cache_unfinished_req(req)
|
104
|
+
if req.grammar is not None:
|
105
|
+
req.grammar.accept_token(req.output_ids[-1])
|
106
|
+
req.grammar.finished = req.finished()
|
107
|
+
self.output_ids = torch.tensor(self.output_ids, device=self.device)
|
108
|
+
|
109
|
+
# Simulate the eagle run. We add mock data to hidden states for the
|
110
|
+
# ease of implementation now meaning the first token will have acc rate
|
111
|
+
# of 0.
|
112
|
+
if not self.spec_algorithm.is_none():
|
113
|
+
|
114
|
+
b = len(self.reqs)
|
115
|
+
topk_p = torch.arange(
|
116
|
+
b * server_args.speculative_eagle_topk,
|
117
|
+
0,
|
118
|
+
-1,
|
119
|
+
device=self.device,
|
120
|
+
dtype=torch.float32,
|
121
|
+
)
|
122
|
+
topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
|
123
|
+
topk_p /= b * server_args.speculative_eagle_topk
|
124
|
+
topk_index = torch.arange(
|
125
|
+
b * server_args.speculative_eagle_topk, device=self.device
|
126
|
+
)
|
127
|
+
topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
|
128
|
+
|
129
|
+
# local import to avoid circular import
|
130
|
+
from sglang.srt.speculative.eagle_utils import EagleDraftInput
|
131
|
+
|
132
|
+
spec_info = EagleDraftInput(
|
133
|
+
topk_p=topk_p,
|
134
|
+
topk_index=topk_index,
|
135
|
+
hidden_states=torch.ones(
|
136
|
+
(b, model_config.hidden_size), device=self.device
|
137
|
+
),
|
138
|
+
verified_id=self.output_ids,
|
139
|
+
)
|
140
|
+
spec_info.prepare_for_extend(self)
|
141
|
+
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
142
|
+
self.spec_info = spec_info
|
@@ -33,28 +33,18 @@ class FakeKVSender(BaseKVSender):
|
|
33
33
|
self,
|
34
34
|
kv_indices: list[int],
|
35
35
|
aux_index: Optional[int] = None,
|
36
|
-
dest_ranks: Optional[list[int]] = None,
|
37
36
|
):
|
38
37
|
logger.info(
|
39
|
-
f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}
|
38
|
+
f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}"
|
40
39
|
)
|
41
40
|
pass
|
42
41
|
|
43
42
|
def send(
|
44
43
|
self,
|
45
44
|
kv_indices: npt.NDArray[np.int64],
|
46
|
-
index_slice: slice,
|
47
|
-
is_last: bool,
|
48
45
|
):
|
49
|
-
|
50
|
-
|
51
|
-
)
|
52
|
-
if is_last:
|
53
|
-
self.has_sent = True
|
54
|
-
logger.info(f"FakeKVSender send success")
|
55
|
-
else:
|
56
|
-
self.has_sent = False
|
57
|
-
logger.info(f"FakeKVSender send fake transfering")
|
46
|
+
self.has_sent = True
|
47
|
+
logger.info(f"FakeKVSender send with kv_indices: {kv_indices}")
|
58
48
|
|
59
49
|
def failure_exception(self):
|
60
50
|
raise Exception("Fake KVSender Exception")
|