sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff shows the content changes between two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/decode.py (new file)
@@ -0,0 +1,495 @@
"""
Life cycle of a request in the decode server

1. PreallocQueue:
    a. Initialize a receiver for each request
    b. The request handshakes first, and KV is pre-allocated once KV space is available.
    c. Move the request to TransferQueue.

2. TransferQueue:
    a. Poll the receiver to check the transfer state
    b. If the transfer has finished, move the request to the waiting queue

3. WaitingQueue:
    a. Use the requests in the queue to construct a PrebuiltExtendBatch
    b. Skip the prefill forward but only populate metadata

4. RunningBatch:
    a. Merge the resolved PrebuiltExtendBatch into the running batch to run decoding
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple

import torch
from torch.distributed import ProcessGroup

from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVReceiver
from sglang.srt.disaggregation.utils import (
    ReqToMetadataIdxAllocator,
    poll_and_all_reduce,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.configs.model_config import ModelConfig
    from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
    from sglang.srt.managers.scheduler import Scheduler
    from sglang.srt.server_args import ServerArgs


@dataclass
class DecodeRequest:
    req: Req
    kv_receiver: KVReceiver
    waiting_for_input: bool = False
    metadata_buffer_index: int = -1


class DecodePreallocQueue:
    """
    Store the requests that are preallocating.
    """

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
        metadata_buffers: List[torch.Tensor],
        aux_dtype: torch.dtype,
        scheduler: Scheduler,
        transfer_queue: DecodeTransferQueue,
        tree_cache: BasePrefixCache,
        gloo_group: ProcessGroup,
        tp_rank: int,
        tp_size: int,
        bootstrap_port: int,
    ):
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
        self.aux_dtype = aux_dtype
        self.metadata_buffers = metadata_buffers
        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
        self.scheduler = scheduler
        self.transfer_queue = transfer_queue
        self.tree_cache = tree_cache  # this is always a chunk cache
        self.gloo_group = gloo_group
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.bootstrap_port = bootstrap_port

        self.num_reserved_decode_tokens = 512

        # Queue for requests pending pre-allocation
        self.queue: List[DecodeRequest] = []
        self.kv_manager = self._init_kv_manager()

    def _init_kv_manager(self) -> KVManager:
        kv_args = KVArgs()
        kv_args.engine_rank = self.tp_rank
        kv_data_ptrs, kv_data_lens, kv_item_lens = (
            self.token_to_kv_pool.get_contiguous_buf_infos()
        )

        kv_args.kv_data_ptrs = kv_data_ptrs
        kv_args.kv_data_lens = kv_data_lens
        kv_args.kv_item_lens = kv_item_lens

        kv_args.aux_data_ptrs = [
            output_id_tensor.data_ptr() for output_id_tensor in self.metadata_buffers
        ]
        kv_args.aux_data_lens = [
            metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
        ]
        kv_args.aux_item_lens = [
            metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
        ]
        kv_args.ib_device = "mock-ib-device"
        kv_manager = KVManager(kv_args)
        return kv_manager

    def add(self, req: Req) -> None:
        """Add a request to the pending queue."""

        kv_receiver = KVReceiver(
            mgr=self.kv_manager,
            bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
            bootstrap_room=req.bootstrap_room,
        )
        self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))

    def extend(self, reqs: List[Req]) -> None:
        """Add multiple requests to the pending queue."""
        for req in reqs:
            self.add(req)

    def _update_handshake_waiters(self) -> None:
        if not self.queue:
            return

        if all(decode_req.waiting_for_input for decode_req in self.queue):
            return

        polls = poll_and_all_reduce(
            [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
        )

        for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
            if poll == KVPoll.Bootstrapping:
                pass
            elif poll == KVPoll.WaitingForInput:
                decode_req.waiting_for_input = True
            elif poll == KVPoll.Failed:
                raise Exception("Handshake failed")

    def pop_preallocated(self) -> List[DecodeRequest]:
        """Pop the preallocated requests from the pending queue (FIFO)."""
        self._update_handshake_waiters()

        preallocated_reqs = []
        indices_to_remove = set()
        allocatable_tokens = self._allocatable_tokens()

        for i, decode_req in enumerate(self.queue):
            if not decode_req.waiting_for_input:
                continue

            if self.req_to_token_pool.available_size() <= 0:
                break

            if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
                break

            required_tokens_for_request = (
                len(decode_req.req.origin_input_ids) + self.num_reserved_decode_tokens
            )

            if required_tokens_for_request > allocatable_tokens:
                break

            allocatable_tokens -= required_tokens_for_request
            self._pre_alloc(decode_req.req)

            kv_indices = (
                self.req_to_token_pool.req_to_token[decode_req.req.req_pool_idx][
                    : len(decode_req.req.origin_input_ids)
                ]
                .cpu()
                .numpy()
            )

            decode_req.metadata_buffer_index = (
                self.req_to_metadata_buffer_idx_allocator.alloc()
            )
            assert decode_req.metadata_buffer_index is not None
            decode_req.kv_receiver.init(kv_indices, decode_req.metadata_buffer_index)
            preallocated_reqs.append(decode_req)
            indices_to_remove.add(i)

        self.queue = [
            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
        ]

        return preallocated_reqs

    def _allocatable_tokens(self) -> int:
        allocatable_tokens = (
            self.token_to_kv_pool_allocator.available_size()
            - self.num_reserved_decode_tokens
            * (
                len(self.scheduler.running_batch.reqs)
                + len(self.transfer_queue.queue)
                + len(self.scheduler.waiting_queue)
            )
        )

        # Note: if the last fake extend just finished and we enter `pop_preallocated` immediately in the
        # next iteration, the extend batch is not in any queue, so we need to explicitly add its token slots here
        if (
            self.scheduler.last_batch
            and self.scheduler.last_batch.forward_mode.is_extend()
        ):
            allocatable_tokens -= self.num_reserved_decode_tokens * len(
                self.scheduler.last_batch.reqs
            )

        return allocatable_tokens

    def _pre_alloc(self, req: Req) -> torch.Tensor:
        """Pre-allocate the memory for req_to_token and token_kv_pool"""
        req_pool_indices = self.req_to_token_pool.alloc(1)

        assert req_pool_indices is not None

        req.req_pool_idx = req_pool_indices[0]
        kv_loc = self.token_to_kv_pool_allocator.alloc(
            len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
        )

        assert kv_loc is not None

        self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)

        # populate metadata
        req.fill_ids = req.origin_input_ids + req.output_ids
        req.extend_input_len = len(req.origin_input_ids)

        return kv_loc


class DecodeTransferQueue:
    """
    Store the requests that are polling KV.
    """

    def __init__(
        self,
        gloo_group: ProcessGroup,
        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
        metadata_buffers: torch.Tensor,
    ):
        self.queue: List[DecodeRequest] = []
        self.gloo_group = gloo_group
        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
        self.metadata_buffers = metadata_buffers

    def add(self, req_conn: DecodeRequest) -> None:
        self.queue.append(req_conn)

    def extend(self, req_conns) -> None:
        self.queue.extend(req_conns)

    def pop_transferred(self) -> List[Req]:
        if not self.queue:
            return []

        polls = poll_and_all_reduce(
            [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
        )

        transferred_reqs = []
        indices_to_remove = set()
        for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
            if poll == KVPoll.Failed:
                raise Exception("Transfer failed")
            elif poll == KVPoll.Success:
                # pop and push it to waiting queue
                idx = decode_req.metadata_buffer_index
                assert len(decode_req.req.output_ids) == 0
                output_id_buffer = self.metadata_buffers[0]
                # the last dimension is padded by the same values.
                output_id = output_id_buffer[idx][0].item()
                assert len(decode_req.req.output_ids) == 0
                assert decode_req.req.transferred_output_id is None
                decode_req.req.transferred_output_id = output_id
                transferred_reqs.append(decode_req.req)
                indices_to_remove.add(i)
            elif poll in [
                KVPoll.Bootstrapping,
                KVPoll.WaitingForInput,
                KVPoll.Transferring,
            ]:
                pass
            else:
                raise ValueError(f"Unexpected poll case: {poll}")

        for i in indices_to_remove:
            idx = self.queue[i].metadata_buffer_index
            assert idx != -1
            self.req_to_metadata_buffer_idx_allocator.free(idx)

        self.queue = [
            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
        ]

        return transferred_reqs


class ScheduleBatchDisaggregationDecodeMixin:

    def prepare_for_prebuilt_extend(self: ScheduleBatch):
        """
        Prepare a prebuilt extend by populating metadata.
        Adapted from .prepare_for_extend().
        """

        self.forward_mode = ForwardMode.EXTEND
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []
        req_pool_indices = []

        # Pre-calculate total size
        total_size = sum(req.extend_input_len for req in reqs)
        out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)

        # Fill the tensor in one pass
        offset = 0
        for i, req in enumerate(reqs):
            req_pool_indices.append(req.req_pool_idx)

            chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
                : req.extend_input_len
            ]
            assert (
                offset + req.extend_input_len <= total_size
            ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
            out_cache_loc[offset : offset + req.extend_input_len] = chunk
            offset += req.extend_input_len

            pre_len = len(req.prefix_indices)
            seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
            seq_lens.append(seq_len)
            if len(req.output_ids) == 0:
                assert (
                    seq_len - pre_len == req.extend_input_len
                ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False
            pre_lens.append(pre_len)
            req.extend_logprob_start_len = 0

        extend_input_logprob_token_ids = None

        # Set fields
        self.input_ids = torch.tensor(
            sum(input_ids, []), dtype=torch.int32, device=self.device
        )
        self.req_pool_indices = torch.tensor(
            req_pool_indices, dtype=torch.int64, device=self.device
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
        self.out_cache_loc = out_cache_loc
        self.seq_lens_sum = sum(seq_lens)
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def process_prebuilt_extend(
        self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
    ):
        """Assign the buffered last input id to the schedule batch"""
        self.output_ids = []
        for req in self.reqs:
            if req.output_ids and len(req.output_ids) > 0:
                # resumed retracted req
                self.output_ids.append(req.output_ids[-1])
            else:
                assert req.transferred_output_id is not None
                req.output_ids.append(req.transferred_output_id)
                self.output_ids.append(req.transferred_output_id)
            self.tree_cache.cache_unfinished_req(req)
        self.output_ids = torch.tensor(self.output_ids, device=self.device)


class SchedulerDisaggregationDecodeMixin:

    def get_next_disagg_decode_batch_to_run(
        self: Scheduler,
    ) -> Optional[Tuple[ScheduleBatch, bool]]:
        """Create a fake completed prefill batch if possible and merge it with the running batch"""
        # Merge the prefill batch into the running batch
        last_batch = self.last_batch
        if last_batch and last_batch.forward_mode.is_extend():
            # chunked prefill doesn't happen in the decode instance.
            assert self.chunked_req is None
            # Filter finished batches.
            last_batch.filter_batch()
            if not last_batch.is_empty():
                if self.running_batch.is_empty():
                    self.running_batch = last_batch
                else:
                    # merge running_batch with prefill batch
                    self.running_batch.merge_batch(last_batch)

        new_prebuilt_batch = self.get_new_prebuilt_batch()

        ret: Optional[ScheduleBatch] = None
        if new_prebuilt_batch:
            ret = new_prebuilt_batch
        else:
            if self.running_batch.is_empty():
                ret = None
            else:
                self.running_batch = self.update_running_batch(self.running_batch)
                ret = self.running_batch if not self.running_batch.is_empty() else None

        return ret

    def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
        """Create a ScheduleBatch for a fake completed prefill"""
        if len(self.waiting_queue) == 0:
            return None

        curr_batch_size = self.running_batch.batch_size()

        batch_size = min(self.req_to_token_pool.size, self.max_running_requests)

        num_not_used_batch = batch_size - curr_batch_size

        # pop reqs from the waiting queue
        can_run_list: List[Req] = []
        waiting_queue: List[Req] = []

        for i in range(len(self.waiting_queue)):
            req = self.waiting_queue[i]
            # we can add at most `num_not_used_batch` new requests to the running queue
            if i < num_not_used_batch:
                can_run_list.append(req)
                req.init_next_round_input(self.tree_cache)
            else:
                waiting_queue.append(req)

        self.waiting_queue = waiting_queue
        if len(can_run_list) == 0:
            return None
        # local import to avoid circular import
        from sglang.srt.managers.schedule_batch import ScheduleBatch

        # construct a schedule batch with those requests and mark it as decode
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )

        # construct the fake completed prefill
        new_batch.prepare_for_prebuilt_extend()
        new_batch.process_prebuilt_extend(self.server_args, self.model_config)

        return new_batch

    def process_decode_queue(self: Scheduler):
        req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
        self.disagg_decode_transfer_queue.extend(req_conns)
        alloc_reqs = (
            self.disagg_decode_transfer_queue.pop_transferred()
        )  # the requests whose KV has arrived
        self.waiting_queue.extend(alloc_reqs)