sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -0
- sglang/check_env.py +3 -3
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/kimi_vl.py +38 -0
- sglang/srt/configs/kimi_vl_moonvit.py +32 -0
- sglang/srt/configs/model_config.py +15 -0
- sglang/srt/conversation.py +122 -1
- sglang/srt/entrypoints/engine.py +44 -22
- sglang/srt/function_call_parser.py +97 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +107 -82
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
- sglang/srt/layers/attention/flashmla_backend.py +3 -0
- sglang/srt/layers/dp_attention.py +5 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -6
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +1 -1
- sglang/srt/layers/utils.py +35 -0
- sglang/srt/lora/layers.py +35 -9
- sglang/srt/lora/lora_manager.py +84 -35
- sglang/srt/managers/data_parallel_controller.py +52 -34
- sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
- sglang/srt/managers/schedule_batch.py +25 -15
- sglang/srt/managers/scheduler.py +263 -59
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
- sglang/srt/managers/tp_worker.py +51 -16
- sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
- sglang/srt/mem_cache/memory_pool.py +70 -36
- sglang/srt/model_executor/cuda_graph_runner.py +82 -19
- sglang/srt/model_executor/forward_batch_info.py +31 -1
- sglang/srt/model_executor/model_runner.py +115 -57
- sglang/srt/models/deepseek_nextn.py +1 -257
- sglang/srt/models/deepseek_v2.py +78 -18
- sglang/srt/models/kimi_vl.py +308 -0
- sglang/srt/models/kimi_vl_moonvit.py +639 -0
- sglang/srt/models/llama.py +92 -30
- sglang/srt/models/llama4.py +2 -1
- sglang/srt/models/llama_eagle.py +4 -1
- sglang/srt/models/llama_eagle3.py +4 -1
- sglang/srt/models/qwen2_moe.py +8 -3
- sglang/srt/models/qwen2_vl.py +0 -12
- sglang/srt/models/qwen3_moe.py +8 -3
- sglang/srt/openai_api/adapter.py +34 -22
- sglang/srt/openai_api/protocol.py +11 -1
- sglang/srt/server_args.py +67 -22
- sglang/srt/speculative/eagle_worker.py +3 -2
- sglang/srt/utils.py +88 -9
- sglang/test/runners.py +4 -0
- sglang/test/test_utils.py +29 -0
- sglang/version.py +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +61 -51
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora_manager.py
CHANGED
@@ -72,6 +72,23 @@ class LoRAManager:
         self.init_loras()
         self.init_lora_memory_pool()
 
+    def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
+        self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
+        with torch.device("cuda"):
+            self.cuda_graph_batch_info = LoRABatchInfo(
+                bs=self.max_bs_in_cuda_graph,
+                seg_lens=torch.zeros(self.max_bs_in_cuda_graph, dtype=torch.int32),
+                seg_indptr=torch.zeros(
+                    self.max_bs_in_cuda_graph + 1, dtype=torch.int32
+                ),
+                max_len=0,
+                weight_indices=torch.zeros(
+                    self.max_bs_in_cuda_graph, dtype=torch.int32
+                ),
+                lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
+                scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
+            )
+
     def init_loras(self):
         # Config of each LoRA adapter
         self.configs: Dict[str, LoRAConfig] = {}
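The new `init_cuda_graph_batch_info` hook allocates every piece of LoRA batch metadata once, at its maximum size, so a captured CUDA graph can keep pointing at the same storage. A minimal sketch of that preallocation idea, using a stand-in dataclass rather than sglang's actual `LoRABatchInfo` (CPU tensors so it runs anywhere):

```python
from dataclasses import dataclass

import torch


@dataclass
class BatchInfoSketch:            # stand-in for sglang's LoRABatchInfo
    bs: int
    seg_lens: torch.Tensor        # per-request token counts, padded to max_bs
    seg_indptr: torch.Tensor      # prefix sums of seg_lens, length max_bs + 1
    max_len: int
    weight_indices: torch.Tensor  # which LoRA buffer slot each request uses


def init_fixed_buffers(max_bs: int) -> BatchInfoSketch:
    # Allocate once at the maximum batch size; a captured CUDA graph can then
    # keep reading these tensors while later batches overwrite only a prefix.
    return BatchInfoSketch(
        bs=max_bs,
        seg_lens=torch.zeros(max_bs, dtype=torch.int32),
        seg_indptr=torch.zeros(max_bs + 1, dtype=torch.int32),
        max_len=0,
        weight_indices=torch.zeros(max_bs, dtype=torch.int32),
    )


info = init_fixed_buffers(max_bs=8)
print(info.seg_lens.shape, info.seg_indptr.shape)  # torch.Size([8]) torch.Size([9])
```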
@@ -136,43 +153,75 @@ class LoRAManager:
         assert len(cur_uids) <= self.max_loras_per_batch
         self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
 
-        #
-        if cur_uids == set([None]):
-            return
-
-        # set up batch info shared by all lora moruldes
+        # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
-        seg_lens = (
-            forward_batch.extend_seq_lens
-            if forward_batch.forward_mode.is_extend()
-            else torch.ones(bs, device=self.device)
-        )
-        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
-        seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
-        max_len = int(torch.max(seg_lens))
-        weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
 
+        if hasattr(self, "max_bs_in_cuda_graph") and bs <= self.max_bs_in_cuda_graph:
+            # Do in-place updates when CUDA graph is enabled. Note that
+            # if CUDA graph is enabled, the batch whose bs <= max_bs_in_cuda_graph
+            # will also use these preallocated buffers, no matter whether
+            # the batch can use CUDA graph or not.
+            self.cuda_graph_batch_info.bs = bs
+            if forward_batch.forward_mode.is_extend():
+                self.cuda_graph_batch_info.seg_lens[:bs].copy_(
+                    forward_batch.extend_seq_lens
+                )
+            else:
+                self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
+            torch.cumsum(
+                self.cuda_graph_batch_info.seg_lens[:bs],
+                dim=0,
+                out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
+            )
+            self.cuda_graph_batch_info.max_len = int(
+                torch.max(self.cuda_graph_batch_info.seg_lens[:bs])
+            )
+
+            for i, lora_path in enumerate(forward_batch.lora_paths):
+                self.cuda_graph_batch_info.weight_indices[i] = (
+                    self.memory_pool.get_buffer_id(lora_path)
+                )
+                if lora_path is not None:
+                    lora = self.loras[lora_path]
+                    self.cuda_graph_batch_info.lora_ranks[
+                        self.cuda_graph_batch_info.weight_indices[i]
+                    ] = lora.config.hf_config["r"]
+                    self.cuda_graph_batch_info.scalings[
+                        self.cuda_graph_batch_info.weight_indices[i]
+                    ] = lora.scaling
+            batch_info = self.cuda_graph_batch_info
+        else:
+            seg_lens = (
+                forward_batch.extend_seq_lens
+                if forward_batch.forward_mode.is_extend()
+                else torch.ones(bs, device=self.device)
+            )
+            seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
+            seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
+            max_len = int(torch.max(seg_lens))
+            weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
+
+            lora_ranks = torch.empty(
+                (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
+            )
+            scalings = torch.empty(
+                (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
+            )
+            for i, lora_path in enumerate(forward_batch.lora_paths):
+                weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
+                if lora_path is not None:
+                    lora = self.loras[lora_path]
+                    lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+                    scalings[weight_indices[i]] = lora.scaling
+            batch_info = LoRABatchInfo(
+                bs=bs,
+                seg_lens=seg_lens,
+                seg_indptr=seg_indptr,
+                max_len=max_len,
+                weight_indices=weight_indices,
+                lora_ranks=lora_ranks,
+                scalings=scalings,
+            )
         self.lora_backend.set_batch_info(batch_info)
 
         # call set_lora_info for each lora modules
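The in-place branch above exists because a replayed CUDA graph reads from fixed tensor addresses: batch metadata has to be written into the preallocated buffers rather than into freshly allocated tensors. A rough CPU-only illustration of that refresh pattern (not sglang's API, just the tensor idiom):

```python
import torch

max_bs = 8
seg_lens = torch.zeros(max_bs, dtype=torch.int32)
seg_indptr = torch.zeros(max_bs + 1, dtype=torch.int32)


def refresh(extend_seq_lens: torch.Tensor) -> int:
    """Overwrite a prefix of the preallocated buffers without reallocating."""
    bs = extend_seq_lens.numel()
    assert bs <= max_bs
    seg_lens[:bs].copy_(extend_seq_lens)                          # in-place copy
    torch.cumsum(seg_lens[:bs], dim=0, out=seg_indptr[1 : bs + 1])  # prefix sums in place
    return int(seg_lens[:bs].max())


max_len = refresh(torch.tensor([3, 1, 4], dtype=torch.int32))
print(seg_indptr[:4].tolist(), max_len)  # [0, 3, 4, 8] 4
```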
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -181,44 +181,62 @@ class DataParallelController:
             enable=server_args.enable_memory_saver
         )
 
-        # Launch tensor parallel scheduler processes
         scheduler_pipe_readers = []
+
+        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
+        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
         tp_rank_range = range(
-            tp_size_per_node * server_args.node_rank,
-            tp_size_per_node * (server_args.node_rank + 1),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
+        )
+
+        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
+        pp_rank_range = range(
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
         )
+
+        for pp_rank in pp_rank_range:
+            for tp_rank in tp_rank_range:
+                rank_port_args = port_args
+
+                if server_args.enable_dp_attention:
+                    # dp attention has different sharding logic
+                    _, _, dp_rank = compute_dp_attention_world_info(
+                        server_args.enable_dp_attention,
+                        tp_rank,
+                        server_args.tp_size,
+                        server_args.dp_size,
+                    )
+                    # compute zmq ports for this dp rank
+                    rank_port_args = PortArgs.init_new(server_args, dp_rank)
+                    # Data parallelism resues the tensor parallelism group,
+                    # so all dp ranks should use the same nccl port.
+                    rank_port_args.nccl_port = port_args.nccl_port
+
+                reader, writer = mp.Pipe(duplex=False)
+                gpu_id = (
+                    server_args.base_gpu_id
+                    + base_gpu_id
+                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
+                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
                 )
-            with memory_saver_adapter.configure_subprocess():
-                proc.start()
-            self.scheduler_procs.append(proc)
-            scheduler_pipe_readers.append(reader)
+                proc = mp.Process(
+                    target=run_scheduler_process,
+                    args=(
+                        server_args,
+                        rank_port_args,
+                        gpu_id,
+                        tp_rank,
+                        pp_rank,
+                        dp_rank,
+                        writer,
+                    ),
+                )
+                with memory_saver_adapter.configure_subprocess():
+                    proc.start()
+                self.scheduler_procs.append(proc)
+                scheduler_pipe_readers.append(reader)
 
         # Wait for model to finish loading
         scheduler_info = []
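The reworked launch loop assigns each `(pp_rank, tp_rank)` pair to a local GPU. The index arithmetic can be checked in isolation; the sketch below mirrors it in plain Python with made-up values standing in for `ServerArgs` fields (2 nodes, `pp_size=2`, `tp_size=8`):

```python
# Hypothetical values standing in for ServerArgs fields.
nnodes, node_rank = 2, 1
tp_size, pp_size = 8, 2
base_gpu_id, gpu_id_step = 0, 1

nnodes_per_tp_group = max(nnodes // pp_size, 1)   # nodes sharing one TP group
tp_size_per_node = tp_size // nnodes_per_tp_group
tp_rank_range = range(
    tp_size_per_node * (node_rank % nnodes_per_tp_group),
    tp_size_per_node * (node_rank % nnodes_per_tp_group + 1),
)

pp_size_per_node = max(pp_size // nnodes, 1)
pp_rank_range = range(
    pp_size_per_node * (node_rank // nnodes_per_tp_group),
    pp_size_per_node * (node_rank // nnodes_per_tp_group + 1),
)

for pp_rank in pp_rank_range:
    for tp_rank in tp_rank_range:
        # Same index computation as the diff, minus the extra DP offset.
        gpu_id = (
            base_gpu_id
            + (pp_rank % pp_size_per_node) * tp_size_per_node
            + (tp_rank % tp_size_per_node) * gpu_id_step
        )
        print(pp_rank, tp_rank, "->", gpu_id)
```

With these values, node 1 hosts pipeline stage 1, and its eight TP ranks land on GPUs 0 through 7 of that node.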
sglang/srt/managers/multimodal_processors/kimi_vl.py
ADDED
@@ -0,0 +1,73 @@
+import asyncio
+import math
+from typing import List, Union
+
+import torch
+from PIL import Image
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor as SGLangBaseProcessor,
+)
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.kimi_vl import KimiVLForConditionalGeneration
+
+
+# Compatible with KimiVLForConditionalGeneration
+class KimiVLImageProcessor(SGLangBaseProcessor):
+    models = [KimiVLForConditionalGeneration]
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        self.IMAGE_TOKEN = "<|media_pad|>"
+        self.im_token_id = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
+
+        self.im_start = "<|media_start|>"
+        self.im_start_id = _processor.tokenizer.convert_tokens_to_ids(self.im_start)
+
+        self.im_end = "<|media_end|>"
+        self.im_end_id = _processor.tokenizer.convert_tokens_to_ids(self.im_end)
+
+        self.im_content = "<|media_content|>"
+        self.im_content_id = _processor.tokenizer.convert_tokens_to_ids(self.im_content)
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+        *args,
+        **kwargs,
+    ):
+        if not image_data:
+            return None
+        if isinstance(image_data, str):
+            image_data = [image_data]
+
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            image_data=image_data,
+            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
+            max_req_input_len=max_req_input_len,
+        )
+        ret = self.process_mm_data(
+            input_text=base_output.input_text,
+            images=base_output.images,
+        )
+        return {
+            "input_ids": ret["input_ids"].flatten().tolist(),
+            "mm_items": [
+                MultimodalDataItem(
+                    pixel_values=ret["pixel_values"],
+                    image_grid_thws=ret["image_grid_hws"],
+                    modality=Modality.IMAGE,
+                )
+            ],
+            "im_token_id": self.im_token_id,
+            "im_start_id": self.im_start_id,
+            "im_end_id": self.im_end_id,
+            "im_content_id": self.im_content_id,
+        }
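The `models = [KimiVLForConditionalGeneration]` attribute is how the processor advertises which model classes it handles, so it can be picked up automatically. A hypothetical, self-contained sketch of that style of dispatch (illustrative classes only, not sglang's real registry code):

```python
# Hypothetical model and processor classes illustrating dispatch by `models`.
class KimiVLModel: ...
class LlavaModel: ...


class KimiVLProcessor:
    models = [KimiVLModel]


class LlavaProcessor:
    models = [LlavaModel]


PROCESSORS = [KimiVLProcessor, LlavaProcessor]


def find_processor(model_cls):
    """Return the first processor class that lists model_cls in its `models`."""
    for proc_cls in PROCESSORS:
        if model_cls in proc_cls.models:
            return proc_cls
    raise ValueError(f"no multimodal processor registered for {model_cls.__name__}")


print(find_processor(KimiVLModel).__name__)  # KimiVLProcessor
```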
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -66,23 +66,24 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 # Put some global args for easy access
 global_server_args_dict = {
     "attention_backend": ServerArgs.attention_backend,
-    "
-    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
-    "torchao_config": ServerArgs.torchao_config,
-    "enable_nan_detection": ServerArgs.enable_nan_detection,
-    "enable_dp_attention": ServerArgs.enable_dp_attention,
-    "enable_ep_moe": ServerArgs.enable_ep_moe,
-    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
+    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "deepep_mode": ServerArgs.deepep_mode,
     "device": ServerArgs.device,
-    "
-    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
+    "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
+    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
+    "enable_dp_attention": ServerArgs.enable_dp_attention,
+    "enable_ep_moe": ServerArgs.enable_ep_moe,
+    "enable_nan_detection": ServerArgs.enable_nan_detection,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
+    "max_micro_batch_size": ServerArgs.max_micro_batch_size,
     "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
-    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
-    "
+    "sampling_backend": ServerArgs.sampling_backend,
+    "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
+    "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
+    "torchao_config": ServerArgs.torchao_config,
+    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
 }
 
 logger = logging.getLogger(__name__)
@@ -728,6 +729,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Events
     launch_done: Optional[threading.Event] = None
 
+    # For chunked prefill in PP
+    chunked_req: Optional[Req] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
     next_batch_sampling_info: SamplingBatchInfo = None
@@ -761,7 +765,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # For extend and mixed chunekd prefill
     prefix_lens: List[int] = None
     extend_lens: List[int] = None
-    extend_num_tokens: int = None
+    extend_num_tokens: Optional[int] = None
     decoding_reqs: List[Req] = None
     extend_logprob_start_lens: List[int] = None
     # It comes empty list if logprob is not required.
@@ -803,6 +807,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
         enable_custom_logit_processor: bool,
+        chunked_req: Optional[Req] = None,
     ):
         return_logprob = any(req.return_logprob for req in reqs)
 
@@ -820,6 +825,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             spec_algorithm=spec_algorithm,
             enable_custom_logit_processor=enable_custom_logit_processor,
             return_hidden_states=any(req.return_hidden_states for req in reqs),
+            chunked_req=chunked_req,
         )
 
     def batch_size(self):
@@ -1236,7 +1242,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
     def retract_decode(self, server_args: ServerArgs):
         """Retract the decoding requests when there is not enough memory."""
-        sorted_indices =
+        sorted_indices = list(range(len(self.reqs)))
 
         # TODO(lsyin): improve retraction policy for radix cache
         # For spec decoding, filter_batch API can only filter
@@ -1413,15 +1419,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
     def filter_batch(
         self,
-        chunked_req_to_exclude: Optional[Req] = None,
+        chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
         keep_indices: Optional[List[int]] = None,
     ):
         if keep_indices is None:
+            if isinstance(chunked_req_to_exclude, Req):
+                chunked_req_to_exclude = [chunked_req_to_exclude]
+            elif chunked_req_to_exclude is None:
+                chunked_req_to_exclude = []
             keep_indices = [
                 i
                 for i in range(len(self.reqs))
                 if not self.reqs[i].finished()
-                and self.reqs[i]
+                and not self.reqs[i] in chunked_req_to_exclude
            ]
 
         if keep_indices is None or len(keep_indices) == 0: