sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/decode.py (added, +495 lines)
@@ -0,0 +1,495 @@
+"""
+Life cycle of a request in the decode server
+
+1. PreallocQueue:
+    a. Initialize a receiver for each request.
+    b. The request handshakes first, and KV is pre-allocated once KV space is available.
+    c. Move the request to the TransferQueue.
+
+2. TransferQueue:
+    a. Poll the receiver to check the transfer state.
+    b. If the transfer has finished, move the request to the waiting queue.
+
+3. WaitingQueue:
+    a. Use the requests in the queue to construct a PrebuiltExtendBatch.
+    b. Skip the prefill forward pass; only populate metadata.
+
+4. RunningBatch:
+    a. Merge the resolved PrebuiltExtendBatch into the running batch to run decoding.
+"""
+
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+import torch
+from torch.distributed import ProcessGroup
+
+from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVReceiver
+from sglang.srt.disaggregation.utils import (
+    ReqToMetadataIdxAllocator,
+    poll_and_all_reduce,
+)
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from sglang.srt.configs.model_config import ModelConfig
+    from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+    from sglang.srt.managers.scheduler import Scheduler
+    from sglang.srt.server_args import ServerArgs
+
+
+@dataclass
+class DecodeRequest:
+    req: Req
+    kv_receiver: KVReceiver
+    waiting_for_input: bool = False
+    metadata_buffer_index: int = -1
+
+
+class DecodePreallocQueue:
+    """
+    Store the requests that are preallocating.
+    """
+
+    def __init__(
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
+        metadata_buffers: List[torch.Tensor],
+        aux_dtype: torch.dtype,
+        scheduler: Scheduler,
+        transfer_queue: DecodeTransferQueue,
+        tree_cache: BasePrefixCache,
+        gloo_group: ProcessGroup,
+        tp_rank: int,
+        tp_size: int,
+        bootstrap_port: int,
+    ):
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
+        self.aux_dtype = aux_dtype
+        self.metadata_buffers = metadata_buffers
+        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
+        self.scheduler = scheduler
+        self.transfer_queue = transfer_queue
+        self.tree_cache = tree_cache  # this is always a chunk cache
+        self.gloo_group = gloo_group
+        self.tp_rank = tp_rank
+        self.tp_size = tp_size
+        self.bootstrap_port = bootstrap_port
+
+        self.num_reserved_decode_tokens = 512
+
+        # Queue for requests pending pre-allocation
+        self.queue: List[DecodeRequest] = []
+        self.kv_manager = self._init_kv_manager()
+
+    def _init_kv_manager(self) -> KVManager:
+        kv_args = KVArgs()
+        kv_args.engine_rank = self.tp_rank
+        kv_data_ptrs, kv_data_lens, kv_item_lens = (
+            self.token_to_kv_pool.get_contiguous_buf_infos()
+        )
+
+        kv_args.kv_data_ptrs = kv_data_ptrs
+        kv_args.kv_data_lens = kv_data_lens
+        kv_args.kv_item_lens = kv_item_lens
+
+        kv_args.aux_data_ptrs = [
+            output_id_tensor.data_ptr() for output_id_tensor in self.metadata_buffers
+        ]
+        kv_args.aux_data_lens = [
+            metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
+        ]
+        kv_args.aux_item_lens = [
+            metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
+        ]
+        kv_args.ib_device = "mock-ib-device"
+        kv_manager = KVManager(kv_args)
+        return kv_manager
+
+    def add(self, req: Req) -> None:
+        """Add a request to the pending queue."""
+
+        kv_receiver = KVReceiver(
+            mgr=self.kv_manager,
+            bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
+            bootstrap_room=req.bootstrap_room,
+        )
+        self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
+
+    def extend(self, reqs: List[Req]) -> None:
+        """Add multiple requests to the pending queue."""
+        for req in reqs:
+            self.add(req)
+
+    def _update_handshake_waiters(self) -> None:
+        if not self.queue:
+            return
+
+        if all(decode_req.waiting_for_input for decode_req in self.queue):
+            return
+
+        polls = poll_and_all_reduce(
+            [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
+        )
+
+        for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
+            if poll == KVPoll.Bootstrapping:
+                pass
+            elif poll == KVPoll.WaitingForInput:
+                decode_req.waiting_for_input = True
+            elif poll == KVPoll.Failed:
+                raise Exception("Handshake failed")
+
+    def pop_preallocated(self) -> List[DecodeRequest]:
+        """Pop the preallocated requests from the pending queue (FIFO)."""
+        self._update_handshake_waiters()
+
+        preallocated_reqs = []
+        indices_to_remove = set()
+        allocatable_tokens = self._allocatable_tokens()
+
+        for i, decode_req in enumerate(self.queue):
+            if not decode_req.waiting_for_input:
+                continue
+
+            if self.req_to_token_pool.available_size() <= 0:
+                break
+
+            if self.req_to_metadata_buffer_idx_allocator.available_size() <= 0:
+                break
+
+            required_tokens_for_request = (
+                len(decode_req.req.origin_input_ids) + self.num_reserved_decode_tokens
+            )
+
+            if required_tokens_for_request > allocatable_tokens:
+                break
+
+            allocatable_tokens -= required_tokens_for_request
+            self._pre_alloc(decode_req.req)
+
+            kv_indices = (
+                self.req_to_token_pool.req_to_token[decode_req.req.req_pool_idx][
+                    : len(decode_req.req.origin_input_ids)
+                ]
+                .cpu()
+                .numpy()
+            )
+
+            decode_req.metadata_buffer_index = (
+                self.req_to_metadata_buffer_idx_allocator.alloc()
+            )
+            assert decode_req.metadata_buffer_index is not None
+            decode_req.kv_receiver.init(kv_indices, decode_req.metadata_buffer_index)
+            preallocated_reqs.append(decode_req)
+            indices_to_remove.add(i)
+
+        self.queue = [
+            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
+        ]
+
+        return preallocated_reqs
+
+    def _allocatable_tokens(self) -> int:
+        allocatable_tokens = (
+            self.token_to_kv_pool_allocator.available_size()
+            - self.num_reserved_decode_tokens
+            * (
+                len(self.scheduler.running_batch.reqs)
+                + len(self.transfer_queue.queue)
+                + len(self.scheduler.waiting_queue)
+            )
+        )
+
+        # Note: if the last fake extend batch has just finished and we enter
+        # `pop_preallocated` immediately in the next iteration, that extend batch
+        # is not in any queue, so we must explicitly account for its token slots here.
+        if (
+            self.scheduler.last_batch
+            and self.scheduler.last_batch.forward_mode.is_extend()
+        ):
+            allocatable_tokens -= self.num_reserved_decode_tokens * len(
+                self.scheduler.last_batch.reqs
+            )
+
+        return allocatable_tokens
+
+    def _pre_alloc(self, req: Req) -> torch.Tensor:
+        """Pre-allocate the memory for req_to_token and token_kv_pool"""
+        req_pool_indices = self.req_to_token_pool.alloc(1)
+
+        assert req_pool_indices is not None
+
+        req.req_pool_idx = req_pool_indices[0]
+        kv_loc = self.token_to_kv_pool_allocator.alloc(
+            len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
+        )
+
+        assert kv_loc is not None
+
+        self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
+
+        # populate metadata
+        req.fill_ids = req.origin_input_ids + req.output_ids
+        req.extend_input_len = len(req.origin_input_ids)
+
+        return kv_loc
+
+
+class DecodeTransferQueue:
+    """
+    Store the requests that are polling the KV transfer state.
+    """
+
+    def __init__(
+        self,
+        gloo_group: ProcessGroup,
+        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
+        metadata_buffers: torch.Tensor,
+    ):
+        self.queue: List[DecodeRequest] = []
+        self.gloo_group = gloo_group
+        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
+        self.metadata_buffers = metadata_buffers
+
+    def add(self, req_conn: DecodeRequest) -> None:
+        self.queue.append(req_conn)
+
+    def extend(self, req_conns) -> None:
+        self.queue.extend(req_conns)
+
+    def pop_transferred(self) -> List[Req]:
+        if not self.queue:
+            return []
+
+        polls = poll_and_all_reduce(
+            [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
+        )
+
+        transferred_reqs = []
+        indices_to_remove = set()
+        for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
+            if poll == KVPoll.Failed:
+                raise Exception("Transfer failed")
+            elif poll == KVPoll.Success:
+                # pop the request and push it to the waiting queue
+                idx = decode_req.metadata_buffer_index
+                assert len(decode_req.req.output_ids) == 0
+                output_id_buffer = self.metadata_buffers[0]
+                # the last dimension is padded with the same values
+                output_id = output_id_buffer[idx][0].item()
+                assert len(decode_req.req.output_ids) == 0
+                assert decode_req.req.transferred_output_id is None
+                decode_req.req.transferred_output_id = output_id
+                transferred_reqs.append(decode_req.req)
+                indices_to_remove.add(i)
+            elif poll in [
+                KVPoll.Bootstrapping,
+                KVPoll.WaitingForInput,
+                KVPoll.Transferring,
+            ]:
+                pass
+            else:
+                raise ValueError(f"Unexpected poll case: {poll}")
+
+        for i in indices_to_remove:
+            idx = self.queue[i].metadata_buffer_index
+            assert idx != -1
+            self.req_to_metadata_buffer_idx_allocator.free(idx)
+
+        self.queue = [
+            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
+        ]
+
+        return transferred_reqs
+
+
+class ScheduleBatchDisaggregationDecodeMixin:
+
+    def prepare_for_prebuilt_extend(self: ScheduleBatch):
+        """
+        Prepare a prebuilt extend batch by populating metadata.
+        Adapted from .prepare_for_extend().
+        """
+
+        self.forward_mode = ForwardMode.EXTEND
+        reqs = self.reqs
+        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
+        extend_num_tokens = sum(len(ids) for ids in input_ids)
+        seq_lens = []
+        pre_lens = []
+        req_pool_indices = []
+
+        # Pre-calculate total size
+        total_size = sum(req.extend_input_len for req in reqs)
+        out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)
+
+        # Fill the tensor in one pass
+        offset = 0
+        for i, req in enumerate(reqs):
+            req_pool_indices.append(req.req_pool_idx)
+
+            chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                : req.extend_input_len
+            ]
+            assert (
+                offset + req.extend_input_len <= total_size
+            ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
+            out_cache_loc[offset : offset + req.extend_input_len] = chunk
+            offset += req.extend_input_len
+
+            pre_len = len(req.prefix_indices)
+            seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
+            seq_lens.append(seq_len)
+            if len(req.output_ids) == 0:
+                assert (
+                    seq_len - pre_len == req.extend_input_len
+                ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"
+
+            req.cached_tokens += pre_len - req.already_computed
+            req.already_computed = seq_len
+            req.is_retracted = False
+            pre_lens.append(pre_len)
+            req.extend_logprob_start_len = 0
+
+        extend_input_logprob_token_ids = None
+
+        # Set fields
+        self.input_ids = torch.tensor(
+            sum(input_ids, []), dtype=torch.int32, device=self.device
+        )
+        self.req_pool_indices = torch.tensor(
+            req_pool_indices, dtype=torch.int64, device=self.device
+        )
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
+        self.out_cache_loc = out_cache_loc
+        self.seq_lens_sum = sum(seq_lens)
+        self.extend_num_tokens = extend_num_tokens
+        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
+        self.extend_lens = [r.extend_input_len for r in reqs]
+        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
+        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
+
+        # Build sampling info
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+            self,
+            self.model_config.vocab_size,
+        )
+
+    def process_prebuilt_extend(
+        self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
+    ):
+        """Assign the buffered last input id to the schedule batch."""
+        self.output_ids = []
+        for req in self.reqs:
+            if req.output_ids and len(req.output_ids) > 0:
+                # resumed retracted req
+                self.output_ids.append(req.output_ids[-1])
+            else:
+                assert req.transferred_output_id is not None
+                req.output_ids.append(req.transferred_output_id)
+                self.output_ids.append(req.transferred_output_id)
+            self.tree_cache.cache_unfinished_req(req)
+        self.output_ids = torch.tensor(self.output_ids, device=self.device)
+
+
+class SchedulerDisaggregationDecodeMixin:
+
+    def get_next_disagg_decode_batch_to_run(
+        self: Scheduler,
+    ) -> Optional[Tuple[ScheduleBatch, bool]]:
+        """Create a fake completed prefill batch if possible and merge it with the running batch."""
+        # Merge the prefill batch into the running batch
+        last_batch = self.last_batch
+        if last_batch and last_batch.forward_mode.is_extend():
+            # chunked prefill doesn't happen in a decode instance
+            assert self.chunked_req is None
+            # Filter out finished requests
+            last_batch.filter_batch()
+            if not last_batch.is_empty():
+                if self.running_batch.is_empty():
+                    self.running_batch = last_batch
+                else:
+                    # merge running_batch with the prefill batch
+                    self.running_batch.merge_batch(last_batch)
+
+        new_prebuilt_batch = self.get_new_prebuilt_batch()
+
+        ret: Optional[ScheduleBatch] = None
+        if new_prebuilt_batch:
+            ret = new_prebuilt_batch
+        else:
+            if self.running_batch.is_empty():
+                ret = None
+            else:
+                self.running_batch = self.update_running_batch(self.running_batch)
+                ret = self.running_batch if not self.running_batch.is_empty() else None
+
+        return ret
+
+    def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
+        """Create a ScheduleBatch for the fake completed prefill."""
+        if len(self.waiting_queue) == 0:
+            return None
+
+        curr_batch_size = self.running_batch.batch_size()
+        batch_size = min(self.req_to_token_pool.size, self.max_running_requests)
+        num_not_used_batch = batch_size - curr_batch_size
+
+        # pop requests from the waiting queue
+        can_run_list: List[Req] = []
+        waiting_queue: List[Req] = []
+
+        for i in range(len(self.waiting_queue)):
+            req = self.waiting_queue[i]
+            # at most `num_not_used_batch` new requests can join the running batch
+            if i < num_not_used_batch:
+                can_run_list.append(req)
+                req.init_next_round_input(self.tree_cache)
+            else:
+                waiting_queue.append(req)
+
+        self.waiting_queue = waiting_queue
+        if len(can_run_list) == 0:
+            return None
+
+        # local import to avoid a circular import
+        from sglang.srt.managers.schedule_batch import ScheduleBatch
+
+        # construct a schedule batch with those requests and mark it as decode
+        new_batch = ScheduleBatch.init_new(
+            can_run_list,
+            self.req_to_token_pool,
+            self.token_to_kv_pool_allocator,
+            self.tree_cache,
+            self.model_config,
+            self.enable_overlap,
+            self.spec_algorithm,
+            self.server_args.enable_custom_logit_processor,
+        )
+
+        # construct the fake completed prefill
+        new_batch.prepare_for_prebuilt_extend()
+        new_batch.process_prebuilt_extend(self.server_args, self.model_config)
+
+        return new_batch
+
+    def process_decode_queue(self: Scheduler):
+        req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
+        self.disagg_decode_transfer_queue.extend(req_conns)
+        # the requests whose KV cache has arrived
+        alloc_reqs = self.disagg_decode_transfer_queue.pop_transferred()
+        self.waiting_queue.extend(alloc_reqs)
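
The module docstring at the top of decode.py describes a four-stage pipeline: PreallocQueue -> TransferQueue -> WaitingQueue -> RunningBatch. The toy sketch below walks requests through those stages; every class in it is a hypothetical stand-in invented for illustration, not sglang's real KVReceiver or scheduler, and the real code additionally gates stage 1 -> 2 on available KV and metadata slots (see pop_preallocated above).

# Toy model of the request life cycle from decode.py's module docstring.
# All classes below are hypothetical stand-ins, not sglang's real types.
from enum import Enum, auto


class Poll(Enum):  # stand-in for sglang's KVPoll states
    BOOTSTRAPPING = auto()
    WAITING_FOR_INPUT = auto()
    TRANSFERRING = auto()
    SUCCESS = auto()


class ToyReceiver:
    """Advances one handshake/transfer state per poll() call."""

    def __init__(self) -> None:
        self._states = [
            Poll.BOOTSTRAPPING,
            Poll.WAITING_FOR_INPUT,
            Poll.TRANSFERRING,
            Poll.SUCCESS,
        ]
        self._step = 0

    def poll(self) -> Poll:
        state = self._states[min(self._step, len(self._states) - 1)]
        self._step += 1
        return state


def decode_event_loop(requests):
    prealloc = [(req, ToyReceiver()) for req in requests]  # 1. PreallocQueue
    transfer = []  # 2. TransferQueue
    waiting = []   # 3. WaitingQueue
    running = []   # 4. RunningBatch

    while prealloc or transfer or waiting:
        # 1 -> 2: once the handshake leaves Bootstrapping, the real code
        # pre-allocates req_to_token/KV slots and hands the receiver onward.
        still = []
        for req, recv in prealloc:
            if recv.poll() != Poll.BOOTSTRAPPING:
                transfer.append((req, recv))
            else:
                still.append((req, recv))
        prealloc = still

        # 2 -> 3: once the KV transfer succeeds, the request is schedulable.
        still = []
        for req, recv in transfer:
            if recv.poll() == Poll.SUCCESS:
                waiting.append(req)
            else:
                still.append((req, recv))
        transfer = still

        # 3 -> 4: build a "prebuilt extend" batch (metadata only, no prefill
        # forward pass) and merge it into the running batch for decoding.
        running.extend(waiting)
        waiting = []

    return running


print(decode_event_loop(["req-0", "req-1"]))  # ['req-0', 'req-1']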
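
The admission math in DecodePreallocQueue is also easy to check by hand: _allocatable_tokens reserves num_reserved_decode_tokens (512) KV slots for every request that is already running, transferring, or waiting, and pop_preallocated then admits a new request only while its whole prompt plus its own 512-token reservation still fits. With made-up numbers:

# Hypothetical numbers, purely to illustrate the budget computed in
# `_allocatable_tokens` and the admission test in `pop_preallocated`.
num_reserved_decode_tokens = 512
available = 100_000                       # token_to_kv_pool_allocator.available_size()
running, transferring, waiting = 8, 4, 2  # sizes of the three queues

allocatable = available - num_reserved_decode_tokens * (
    running + transferring + waiting
)
print(allocatable)  # 100000 - 512 * 14 = 92832

# A request with a 3000-token prompt needs its prompt plus its own reservation:
required = 3000 + num_reserved_decode_tokens
print(required <= allocatable)  # True -> this request can be pre-allocated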