sglang 0.3.4.post2__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_latency.py +3 -3
- sglang/bench_server_latency.py +2 -3
- sglang/bench_serving.py +92 -0
- sglang/global_config.py +9 -3
- sglang/lang/chat_template.py +50 -25
- sglang/lang/interpreter.py +9 -1
- sglang/lang/ir.py +11 -2
- sglang/launch_server.py +1 -1
- sglang/srt/configs/model_config.py +51 -13
- sglang/srt/constrained/__init__.py +18 -0
- sglang/srt/constrained/bnf_cache.py +61 -0
- sglang/srt/constrained/grammar.py +190 -0
- sglang/srt/hf_transformers_utils.py +6 -5
- sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
- sglang/srt/layers/fused_moe/fused_moe.py +4 -3
- sglang/srt/layers/fused_moe/layer.py +28 -0
- sglang/srt/layers/quantization/base_config.py +16 -1
- sglang/srt/layers/vocab_parallel_embedding.py +486 -0
- sglang/srt/managers/data_parallel_controller.py +7 -6
- sglang/srt/managers/detokenizer_manager.py +9 -11
- sglang/srt/managers/image_processor.py +4 -3
- sglang/srt/managers/io_struct.py +70 -78
- sglang/srt/managers/schedule_batch.py +33 -49
- sglang/srt/managers/schedule_policy.py +24 -13
- sglang/srt/managers/scheduler.py +137 -80
- sglang/srt/managers/tokenizer_manager.py +224 -336
- sglang/srt/managers/tp_worker.py +5 -5
- sglang/srt/mem_cache/flush_cache.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/model_runner.py +8 -17
- sglang/srt/models/baichuan.py +4 -4
- sglang/srt/models/chatglm.py +4 -4
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +5 -5
- sglang/srt/models/deepseek.py +4 -4
- sglang/srt/models/deepseek_v2.py +4 -4
- sglang/srt/models/exaone.py +4 -4
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -1
- sglang/srt/models/gpt2.py +287 -0
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +4 -4
- sglang/srt/models/internlm2.py +4 -4
- sglang/srt/models/llama.py +15 -7
- sglang/srt/models/llama_embedding.py +2 -10
- sglang/srt/models/llama_reward.py +5 -0
- sglang/srt/models/minicpm.py +4 -4
- sglang/srt/models/minicpm3.py +4 -4
- sglang/srt/models/mixtral.py +7 -5
- sglang/srt/models/mixtral_quant.py +4 -4
- sglang/srt/models/mllama.py +5 -5
- sglang/srt/models/olmo.py +4 -4
- sglang/srt/models/olmoe.py +4 -4
- sglang/srt/models/qwen.py +4 -4
- sglang/srt/models/qwen2.py +4 -4
- sglang/srt/models/qwen2_moe.py +4 -4
- sglang/srt/models/qwen2_vl.py +4 -8
- sglang/srt/models/stablelm.py +4 -4
- sglang/srt/models/torch_native_llama.py +4 -4
- sglang/srt/models/xverse.py +4 -4
- sglang/srt/models/xverse_moe.py +4 -4
- sglang/srt/openai_api/adapter.py +52 -66
- sglang/srt/sampling/sampling_batch_info.py +7 -13
- sglang/srt/server.py +31 -35
- sglang/srt/server_args.py +34 -5
- sglang/srt/utils.py +40 -56
- sglang/test/runners.py +2 -1
- sglang/test/test_utils.py +73 -25
- sglang/utils.py +62 -1
- sglang/version.py +1 -1
- sglang-0.3.5.dist-info/METADATA +344 -0
- {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/RECORD +77 -73
- {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
- sglang-0.3.4.post2.dist-info/METADATA +0 -899
- {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
- {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py
CHANGED
@@ -56,49 +56,47 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
-
-    is_single: bool = True
-
-    def post_init(self):
+    def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")
 
-
+        # Derive the batch size
         if self.text is not None:
             if isinstance(self.text, str):
                 self.is_single = True
                 self.batch_size = 1
             else:
+                self.is_single = False
                 self.batch_size = len(self.text)
         else:
             if isinstance(self.input_ids[0], int):
                 self.is_single = True
                 self.batch_size = 1
             else:
+                self.is_single = False
                 self.batch_size = len(self.input_ids)
 
+        # Handle parallel sampling
+        # When parallel sampling is used, we always treat the input as a batch.
         if self.sampling_params is None:
             self.parallel_sample_num = 1
         elif isinstance(self.sampling_params, dict):
             self.parallel_sample_num = self.sampling_params.get("n", 1)
         else:  # isinstance(self.sampling_params, list):
             self.parallel_sample_num = self.sampling_params[0].get("n", 1)
-            for
-
-                assert self.parallel_sample_num == sp.get(
-                    "n", 1
-                ), "The parallel_sample_num should be the same for all samples in sample params."
-
-        if self.parallel_sample_num > 1:
-            if self.is_single:
-                self.is_single = False
-                if self.text is not None:
-                    self.text = [self.text]
-                if self.input_ids is not None:
-                    self.input_ids = [self.input_ids]
+            assert all(self.parallel_sample_num == sampling_params.get("n", 1) for sampling_params in self.sampling_params), (
+                "The parallel_sample_num should be the same for all samples in sample params.")
 
+        if self.parallel_sample_num > 1 and self.is_single:
+            self.is_single = False
+            if self.text is not None:
+                self.text = [self.text]
+            if self.input_ids is not None:
+                self.input_ids = [self.input_ids]
+
+        # Fill in default arguments
         if self.is_single:
             if self.sampling_params is None:
                 self.sampling_params = {}
@@ -114,9 +112,8 @@ class GenerateReqInput:
             if self.parallel_sample_num == 1:
                 num = self.batch_size
             else:
-                #
-
-                num = self.batch_size + self.parallel_sample_num * self.batch_size
+                # Expand parallel_sample_num
+                num = self.batch_size * self.parallel_sample_num
 
             if self.image_data is None:
                 self.image_data = [None] * num
@@ -129,14 +126,11 @@ class GenerateReqInput:
                 self.sampling_params = [{}] * num
             elif not isinstance(self.sampling_params, list):
                 self.sampling_params = [self.sampling_params] * num
-            else:
-                assert self.parallel_sample_num == 1
 
             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(num)]
             else:
                 assert isinstance(self.rid, list), "The rid should be a list."
-                assert self.parallel_sample_num == 1
 
             if self.return_logprob is None:
                 self.return_logprob = [False] * num
@@ -159,6 +153,26 @@ class GenerateReqInput:
             else:
                 assert self.parallel_sample_num == 1
 
+    def regenerate_rid(self):
+        self.rid = uuid.uuid4().hex
+        return self.rid
+
+    def __getitem__(self, i):
+        return GenerateReqInput(
+            text=self.text[i] if self.text is not None else None,
+            input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            image_data=self.image_data[i],
+            sampling_params=self.sampling_params[i],
+            rid=self.rid[i],
+            return_logprob=self.return_logprob[i],
+            logprob_start_len=self.logprob_start_len[i],
+            top_logprobs_num=self.top_logprobs_num[i],
+            return_text_in_logprobs=self.return_text_in_logprobs,
+            stream=self.stream,
+            modalities=self.modalities[i] if self.modalities else None,
+            lora_path=self.lora_path[i] if self.lora_path is not None else None,
+        )
+
 
 @dataclass
 class TokenizedGenerateReqInput:
@@ -196,85 +210,61 @@ class EmbeddingReqInput:
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
 
-    def
+    def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")
 
+        # Derive the batch size
         if self.text is not None:
-
+            if isinstance(self.text, str):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.is_single = False
+                self.batch_size = len(self.text)
         else:
-
+            if isinstance(self.input_ids[0], int):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.is_single = False
+                self.batch_size = len(self.input_ids)
 
+        # Fill in default arguments
         if self.is_single:
             if self.rid is None:
                 self.rid = uuid.uuid4().hex
             if self.sampling_params is None:
                 self.sampling_params = {}
-            self.sampling_params["max_new_tokens"] =
+            self.sampling_params["max_new_tokens"] = 0
         else:
-            # support select operation
-            self.batch_size = (
-                len(self.text) if self.text is not None else len(self.input_ids)
-            )
             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
             else:
-
-
+                assert isinstance(self.rid, list), "The rid should be a list."
+
             if self.sampling_params is None:
                 self.sampling_params = [{}] * self.batch_size
             for i in range(self.batch_size):
-                self.sampling_params[i]["max_new_tokens"] =
+                self.sampling_params[i]["max_new_tokens"] = 0
 
+    def regenerate_rid(self):
+        self.rid = uuid.uuid4().hex
+        return self.rid
 
-
-
-
-
-
-
-
-    input_ids: List[int]
-    # Dummy sampling params for compatibility
-    sampling_params: SamplingParams
+    def __getitem__(self, i):
+        return EmbeddingReqInput(
+            text=self.text[i] if self.text is not None else None,
+            input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            sampling_params=self.sampling_params[i],
+            rid=self.rid[i],
+        )
 
 
 @dataclass
-class
-    # The input prompt in the chat format. It can be a single prompt or a batch of prompts.
-    conv: Union[List[List[Dict]], List[Dict]]
-    # The request id.
-    rid: Optional[Union[List[str], str]] = None
-    # Dummy sampling params for compatibility
-    sampling_params: Union[List[Dict], Dict] = None
-
-    def post_init(self):
-        self.is_single = isinstance(self.conv[0], dict)
-
-        if self.is_single:
-            if self.rid is None:
-                self.rid = uuid.uuid4().hex
-            if self.sampling_params is None:
-                self.sampling_params = {}
-            self.sampling_params["max_new_tokens"] = 1
-        else:
-            # support select operation
-            self.batch_size = len(self.conv)
-            if self.rid is None:
-                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
-            else:
-                if not isinstance(self.rid, list):
-                    raise ValueError("The rid should be a list.")
-            if self.sampling_params is None:
-                self.sampling_params = [{}] * self.batch_size
-            for i in range(self.batch_size):
-                self.sampling_params[i]["max_new_tokens"] = 1
-
-
-@dataclass
-class TokenizedRewardReqInput:
+class TokenizedEmbeddingReqInput:
     # The request id
     rid: str
     # The input text
@@ -294,6 +284,8 @@ class BatchTokenIDOut:
     decoded_texts: List[str]
     decode_ids: List[int]
     read_offsets: List[int]
+    # Only used when `--skip-tokenizer-init`
+    output_ids: Optional[List[int]]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
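The batching helpers added above (`normalize_batch_and_arguments`, `__getitem__`, `regenerate_rid`) replace the old `post_init` flow. Below is a minimal usage sketch; it assumes sglang 0.3.5 is installed, uses only the field and method names visible in the hunks above, and the prompts and sampling parameters are made up for the example.

```python
# Illustrative only; not taken from the package's own code or tests.
from sglang.srt.managers.io_struct import GenerateReqInput

req = GenerateReqInput(
    text=["What is the capital of France?", "Write a haiku about GPUs."],
    sampling_params={"max_new_tokens": 32, "n": 1},
)
req.normalize_batch_and_arguments()  # replaces the old post_init()

# __getitem__ peels off an independent single-request object, which is what
# lets a batch be fanned out one request at a time.
for i in range(req.batch_size):
    single = req[i]
    print(single.rid, single.text)

# regenerate_rid() assigns a fresh request id, e.g. before resubmitting.
print(req[0].regenerate_rid())
```

Indexing a normalized batch yields a self-contained per-request object, which is presumably what lets the reworked tokenizer manager hand requests off individually without re-deriving defaults.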
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -37,8 +37,7 @@ import torch
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained import
-from sglang.srt.constrained.jump_forward import JumpForwardMap
+from sglang.srt.constrained.grammar import Grammar
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -212,9 +211,6 @@ class Req:
         # this does not include the jump forward tokens.
         self.completion_tokens_wo_jump_forward = 0
 
-        # The number of cached tokens, that were already cached in the KV store
-        self.cached_tokens = 0
-
         # For vision inputs
         self.image_inputs: Optional[ImageInputs] = None
 
@@ -222,7 +218,10 @@
         self.prefix_indices = []
         self.extend_input_len = 0
         self.last_node = None
-        self.
+        self.is_being_chunked = 0
+
+        # For retraction
+        self.is_retracted = False
 
         # Logprobs (arguments)
         self.return_logprob = False
@@ -243,13 +242,14 @@
         # The relative logprob_start_len in an extend batch
         self.extend_logprob_start_len = 0
 
-        # Embedding
+        # Embedding (return values)
         self.embedding = None
 
         # Constrained decoding
-        self.
-
-
+        self.grammar: Optional[Grammar] = None
+
+        # The number of cached tokens, that were already cached in the KV cache
+        self.cached_tokens = 0
 
         # For Qwen2-VL
         self.mrope_position_delta = []  # use mutable object
@@ -359,6 +359,8 @@
             return
 
     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
+        assert self.grammar is not None and self.tokenizer is not None
+
         if self.origin_input_text is None:
             # Recovering text can only use unpadded ids
             self.origin_input_text = self.tokenizer.decode(
@@ -398,7 +400,8 @@
                 self.surr_offset = self.read_offset - i
                 break
 
-
+        # update the inner state of the grammar
+        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)
 
         if self.return_logprob:
             # For fast-forward part's logprobs
@@ -468,8 +471,8 @@ class ScheduleBatch:
     # Stream
     has_stream: bool = False
 
-    # Has
-
+    # Has grammar
+    has_grammar: bool = False
 
     # device
     device: str = "cuda"
@@ -477,7 +480,7 @@
     @classmethod
     def init_new(
         cls,
-        reqs,
+        reqs: List[Req],
         req_to_token_pool,
         token_to_kv_pool,
         tree_cache,
@@ -491,7 +494,7 @@
             model_config=model_config,
             return_logprob=any(req.return_logprob for req in reqs),
             has_stream=any(req.stream for req in reqs),
-
+            has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
         )
 
@@ -561,7 +564,7 @@
             seq_lens[i] -= encoder_len
 
             if len(req.prefix_indices) < encoder_len:
-                # NOTE: the encoder part should considered as a whole
+                # NOTE: the encoder part should be considered as a whole
                 assert len(req.prefix_indices) == 0
                 input_ids[i] = input_ids[i][encoder_len:]
                 encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
@@ -648,6 +651,7 @@
 
             req.extend_logprob_start_len = extend_logprob_start_len
             pt += req.extend_input_len
+            req.is_retracted = False
 
         # Set fields
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
@@ -780,6 +784,7 @@
                 req.prefix_indices = []
                 req.last_node = None
                 req.extend_input_len = 0
+                req.is_retracted = True
 
                 # For incremental logprobs
                 req.last_update_decode_tokens = 0
@@ -803,26 +808,10 @@
         keep_indices = set(i for i in range(len(self.reqs)))
 
         for i, req in enumerate(self.reqs):
-            if req.
-
-
-
-                if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
-                    suffix_bytes = []
-                    continuation_range = range(0x80, 0xC0)
-                    cur_state = req.regex_fsm_state
-                    while (
-                        len(jump_forward_bytes)
-                        and jump_forward_bytes[0][0] in continuation_range
-                    ):
-                        # continuation bytes
-                        byte_edge = jump_forward_bytes.pop(0)
-                        suffix_bytes.append(byte_edge[0])
-                        cur_state = byte_edge[1]
-
-                    suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
-                    suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)
-
+            if req.grammar is not None:
+                jump_helper = req.grammar.try_jump(req.tokenizer)
+                if jump_helper.can_jump():
+                    suffix_ids = jump_helper.suffix_ids
                     # Current ids, for cache and revert
                     cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                     cur_output_ids = req.output_ids
@@ -836,10 +825,8 @@
                     (
                         jump_forward_str,
                         next_state,
-                    ) = req.
+                    ) = req.grammar.jump_forward_str_state(jump_helper)
 
-                    # Make the incrementally decoded text part of jump_forward_str
-                    # so that the UTF-8 will not corrupt
                     jump_forward_str = new_text + jump_forward_str
                     if not req.jump_forward_and_retokenize(
                         jump_forward_str, next_state
@@ -906,7 +893,7 @@
 
     def filter_batch(
         self,
-
+        being_chunked_req: Optional[Req] = None,
         keep_indices: Optional[List[int]] = None,
     ):
         if keep_indices is None:
@@ -914,7 +901,7 @@
                 i
                 for i in range(len(self.reqs))
                 if not self.reqs[i].finished()
-                and self.reqs[i] is not
+                and self.reqs[i] is not being_chunked_req
            ]
 
        if keep_indices is None or len(keep_indices) == 0:
@@ -946,7 +933,7 @@
        self.top_logprobs_nums = None
 
        self.has_stream = any(req.stream for req in self.reqs)
-        self.
+        self.has_grammar = any(req.grammar for req in self.reqs)
 
        self.sampling_info.filter_batch(keep_indices, new_indices)
 
@@ -979,7 +966,7 @@
 
        self.return_logprob = self.return_logprob or other.return_logprob
        self.has_stream = self.has_stream or other.has_stream
-        self.
+        self.has_grammar = self.has_grammar or other.has_grammar
 
    def get_model_worker_batch(self):
        if self.forward_mode.is_decode():
@@ -989,13 +976,10 @@
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens
 
-        if self.
-            self.sampling_info.
-            self.sampling_info.regex_fsm_states = [
-                req.regex_fsm_state for req in self.reqs
-            ]
+        if self.has_grammar:
+            self.sampling_info.grammars = [req.grammar for req in self.reqs]
        else:
-            self.sampling_info.
+            self.sampling_info.grammars = None
 
        global bid
        bid += 1
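The hunks above replace the per-request regex-FSM plus jump-forward-map pair with a single `grammar` object. The stub below only sketches the call sequence implied by the diff (`try_jump`, `can_jump`, `jump_forward_str_state`, `jump_and_retokenize`); the `Grammar` and `JumpHelper` classes here are simplified stand-ins, not the real `sglang.srt.constrained.grammar` API.

```python
# Control-flow sketch pieced together from the hunks above; illustrative only.
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class JumpHelper:
    # Tokens that can be appended immediately without running the model.
    suffix_ids: List[int] = field(default_factory=list)

    def can_jump(self) -> bool:
        return len(self.suffix_ids) > 0


class Grammar:
    """Stand-in for the wrapper that hides whether a regex FSM or a BNF
    grammar backend drives constrained decoding."""

    def try_jump(self, tokenizer) -> JumpHelper:
        # A real backend inspects its current state; this stub never jumps.
        return JumpHelper()

    def jump_forward_str_state(self, helper: JumpHelper) -> Tuple[str, int]:
        return "", 0

    def jump_and_retokenize(self, old_output_ids, new_output_ids, next_state) -> None:
        pass


def check_for_jump_forward(reqs, tokenizer):
    """Mirrors the rewritten loop in ScheduleBatch.check_for_jump_forward."""
    for req in reqs:
        if req.grammar is None:
            continue
        jump_helper = req.grammar.try_jump(tokenizer)
        if not jump_helper.can_jump():
            continue
        # suffix_ids now comes from the helper instead of being re-derived
        # from raw FSM byte edges as in 0.3.4.
        suffix_ids = jump_helper.suffix_ids
        jump_forward_str, next_state = req.grammar.jump_forward_str_state(jump_helper)
        # ... retokenize the request, then let the grammar catch up:
        req.grammar.jump_and_retokenize(req.output_ids, req.output_ids, next_state)
```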
sglang/srt/managers/schedule_policy.py
CHANGED
@@ -30,7 +30,9 @@ from sglang.srt.mem_cache.radix_cache import TreeNode
 # This can prevent the server from being too conservative.
 # Note that this only clips the estimation in the scheduler but does not change the stop
 # condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
-
+CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
+    os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
+)
 
 
 class SchedulePolicy:
@@ -43,9 +45,15 @@ class SchedulePolicy:
         self.tree_cache = tree_cache
 
     def calc_priority(self, waiting_queue: List[Req]):
+        if len(waiting_queue) > 128 and self.policy == "lpm":
+            # Turn off the expensive prefix matching and sorting when the #queue is large.
+            policy = "fcfs"
+        else:
+            policy = self.policy
+
         # Compute matched prefix length
         prefix_computed = False
-        if
+        if policy == "lpm" or policy == "dfs-weight":
             for r in waiting_queue:
                 # NOTE: the prefix_indices must always be aligned with last_node
                 r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
@@ -54,18 +62,18 @@
 
             prefix_computed = True
 
-        if
+        if policy == "lpm":
             # Longest Prefix Match
             waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
-        elif
+        elif policy == "fcfs":
             # first come first serve
             pass
-        elif
+        elif policy == "lof":
             # longest output first
             waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-        elif
+        elif policy == "random":
             random.shuffle(waiting_queue)
-        elif
+        elif policy == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in waiting_queue:
                 last_node_to_reqs[req.last_node].append(req)
@@ -83,7 +91,7 @@
                 waiting_queue,
             )
         else:
-            raise ValueError(f"Unknown schedule_policy: {
+            raise ValueError(f"Unknown schedule_policy: {policy=}")
 
         return prefix_computed
 
@@ -146,7 +154,7 @@ class PrefillAdder:
                 [
                     min(
                         (r.sampling_params.max_new_tokens - len(r.output_ids)),
-
+                        CLIP_MAX_NEW_TOKENS_ESTIMATION,
                    )
                    * self.new_token_ratio
                    for r in running_batch.reqs
@@ -186,7 +194,7 @@
            len(req.prefix_indices),
            req.extend_input_len,
            (
-                min(req.sampling_params.max_new_tokens,
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION)
                if not truncated
                else 0
            ),
@@ -258,7 +266,7 @@
            self._prefill_one_req(
                0,
                req.extend_input_len,
-                min(req.sampling_params.max_new_tokens,
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
            )
        else:
            # Chunked prefill
@@ -276,7 +284,7 @@
            return self.add_one_req_ignore_eos(req)
 
        total_tokens = req.extend_input_len + min(
-            req.sampling_params.max_new_tokens,
+            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
        )
        input_tokens = req.extend_input_len
        prefix_len = len(req.prefix_indices)
@@ -302,7 +310,10 @@
            self._prefill_one_req(
                prefix_len,
                input_tokens,
-                min(
+                min(
+                    req.sampling_params.max_new_tokens,
+                    CLIP_MAX_NEW_TOKENS_ESTIMATION,
+                ),
            )
        else:
            # Chunked prefill
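Two scheduler knobs are introduced above: the clipped max-new-tokens estimate now reads the `SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION` environment variable (default 4096), and `calc_priority` falls back from `lpm` to `fcfs` once more than 128 requests are waiting. The self-contained sketch below only restates that behavior; it is illustrative and not the package code.

```python
import os

# Matches the default shown in the hunk above; override before launching the
# server, e.g. SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION=8192 python -m sglang.launch_server ...
CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
    os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
)


def effective_policy(configured_policy: str, waiting_queue_len: int) -> str:
    """Reproduces the fallback in SchedulePolicy.calc_priority: longest-prefix-match
    is skipped for large queues because prefix matching and sorting get expensive."""
    if waiting_queue_len > 128 and configured_policy == "lpm":
        return "fcfs"
    return configured_policy


assert effective_policy("lpm", 200) == "fcfs"
assert effective_policy("lpm", 50) == "lpm"
assert effective_policy("dfs-weight", 200) == "dfs-weight"
```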