sglang 0.3.4.post1__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_latency.py +3 -3
- sglang/bench_server_latency.py +2 -3
- sglang/bench_serving.py +92 -0
- sglang/global_config.py +9 -3
- sglang/lang/chat_template.py +50 -25
- sglang/lang/interpreter.py +9 -1
- sglang/lang/ir.py +11 -2
- sglang/launch_server.py +1 -1
- sglang/srt/configs/model_config.py +76 -15
- sglang/srt/constrained/__init__.py +18 -0
- sglang/srt/constrained/bnf_cache.py +61 -0
- sglang/srt/constrained/fsm_cache.py +10 -3
- sglang/srt/constrained/grammar.py +190 -0
- sglang/srt/hf_transformers_utils.py +20 -5
- sglang/srt/layers/attention/flashinfer_backend.py +5 -5
- sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
- sglang/srt/layers/fused_moe/fused_moe.py +4 -3
- sglang/srt/layers/fused_moe/layer.py +28 -0
- sglang/srt/layers/logits_processor.py +5 -5
- sglang/srt/layers/quantization/base_config.py +16 -1
- sglang/srt/layers/rotary_embedding.py +15 -48
- sglang/srt/layers/sampler.py +51 -39
- sglang/srt/layers/vocab_parallel_embedding.py +486 -0
- sglang/srt/managers/data_parallel_controller.py +8 -7
- sglang/srt/managers/detokenizer_manager.py +11 -9
- sglang/srt/managers/image_processor.py +4 -3
- sglang/srt/managers/io_struct.py +80 -78
- sglang/srt/managers/schedule_batch.py +46 -52
- sglang/srt/managers/schedule_policy.py +24 -13
- sglang/srt/managers/scheduler.py +145 -82
- sglang/srt/managers/tokenizer_manager.py +236 -334
- sglang/srt/managers/tp_worker.py +5 -5
- sglang/srt/managers/tp_worker_overlap_thread.py +58 -21
- sglang/srt/mem_cache/flush_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +10 -3
- sglang/srt/model_executor/cuda_graph_runner.py +34 -23
- sglang/srt/model_executor/forward_batch_info.py +6 -9
- sglang/srt/model_executor/model_runner.py +10 -19
- sglang/srt/models/baichuan.py +4 -4
- sglang/srt/models/chatglm.py +4 -4
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +5 -5
- sglang/srt/models/deepseek.py +4 -4
- sglang/srt/models/deepseek_v2.py +4 -4
- sglang/srt/models/exaone.py +4 -4
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -1
- sglang/srt/models/gpt2.py +287 -0
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +4 -4
- sglang/srt/models/internlm2.py +4 -4
- sglang/srt/models/llama.py +15 -7
- sglang/srt/models/llama_embedding.py +2 -10
- sglang/srt/models/llama_reward.py +5 -0
- sglang/srt/models/minicpm.py +4 -4
- sglang/srt/models/minicpm3.py +4 -4
- sglang/srt/models/mixtral.py +7 -5
- sglang/srt/models/mixtral_quant.py +4 -4
- sglang/srt/models/mllama.py +5 -5
- sglang/srt/models/olmo.py +4 -4
- sglang/srt/models/olmoe.py +4 -4
- sglang/srt/models/qwen.py +4 -4
- sglang/srt/models/qwen2.py +4 -4
- sglang/srt/models/qwen2_moe.py +4 -4
- sglang/srt/models/qwen2_vl.py +4 -8
- sglang/srt/models/stablelm.py +4 -4
- sglang/srt/models/torch_native_llama.py +4 -4
- sglang/srt/models/xverse.py +4 -4
- sglang/srt/models/xverse_moe.py +4 -4
- sglang/srt/openai_api/adapter.py +52 -66
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -13
- sglang/srt/sampling/sampling_params.py +5 -7
- sglang/srt/server.py +41 -33
- sglang/srt/server_args.py +34 -5
- sglang/srt/utils.py +40 -56
- sglang/test/run_eval.py +2 -0
- sglang/test/runners.py +2 -1
- sglang/test/srt/sampling/penaltylib/utils.py +1 -0
- sglang/test/test_utils.py +151 -6
- sglang/utils.py +62 -1
- sglang/version.py +1 -1
- sglang-0.3.5.dist-info/METADATA +344 -0
- sglang-0.3.5.dist-info/RECORD +152 -0
- {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
- sglang-0.3.4.post1.dist-info/METADATA +0 -900
- sglang-0.3.4.post1.dist-info/RECORD +0 -148
- {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
- {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py
CHANGED
@@ -56,49 +56,47 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

-
-    is_single: bool = True
-
-    def post_init(self):
+    def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")

-
+        # Derive the batch size
         if self.text is not None:
             if isinstance(self.text, str):
                 self.is_single = True
                 self.batch_size = 1
             else:
+                self.is_single = False
                 self.batch_size = len(self.text)
         else:
             if isinstance(self.input_ids[0], int):
                 self.is_single = True
                 self.batch_size = 1
             else:
+                self.is_single = False
                 self.batch_size = len(self.input_ids)

+        # Handle parallel sampling
+        # When parallel sampling is used, we always treat the input as a batch.
         if self.sampling_params is None:
             self.parallel_sample_num = 1
         elif isinstance(self.sampling_params, dict):
             self.parallel_sample_num = self.sampling_params.get("n", 1)
         else:  # isinstance(self.sampling_params, list):
             self.parallel_sample_num = self.sampling_params[0].get("n", 1)
-            for sp in self.sampling_params:
-
-                assert self.parallel_sample_num == sp.get(
-                    "n", 1
-                ), "The parallel_sample_num should be the same for all samples in sample params."
-
-        if self.parallel_sample_num > 1:
-            if self.is_single:
-                self.is_single = False
-                if self.text is not None:
-                    self.text = [self.text]
-                if self.input_ids is not None:
-                    self.input_ids = [self.input_ids]
+            assert all(self.parallel_sample_num == sampling_params.get("n", 1) for sampling_params in self.sampling_params), (
+                "The parallel_sample_num should be the same for all samples in sample params.")

+        if self.parallel_sample_num > 1 and self.is_single:
+            self.is_single = False
+            if self.text is not None:
+                self.text = [self.text]
+            if self.input_ids is not None:
+                self.input_ids = [self.input_ids]
+
+        # Fill in default arguments
         if self.is_single:
             if self.sampling_params is None:
                 self.sampling_params = {}
@@ -114,9 +112,8 @@ class GenerateReqInput:
             if self.parallel_sample_num == 1:
                 num = self.batch_size
             else:
-                #
-
-                num = self.batch_size + self.parallel_sample_num * self.batch_size
+                # Expand parallel_sample_num
+                num = self.batch_size * self.parallel_sample_num

             if self.image_data is None:
                 self.image_data = [None] * num
@@ -129,14 +126,11 @@ class GenerateReqInput:
                 self.sampling_params = [{}] * num
             elif not isinstance(self.sampling_params, list):
                 self.sampling_params = [self.sampling_params] * num
-            else:
-                assert self.parallel_sample_num == 1

             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(num)]
             else:
                 assert isinstance(self.rid, list), "The rid should be a list."
-                assert self.parallel_sample_num == 1

             if self.return_logprob is None:
                 self.return_logprob = [False] * num
@@ -159,6 +153,26 @@ class GenerateReqInput:
             else:
                 assert self.parallel_sample_num == 1

+    def regenerate_rid(self):
+        self.rid = uuid.uuid4().hex
+        return self.rid
+
+    def __getitem__(self, i):
+        return GenerateReqInput(
+            text=self.text[i] if self.text is not None else None,
+            input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            image_data=self.image_data[i],
+            sampling_params=self.sampling_params[i],
+            rid=self.rid[i],
+            return_logprob=self.return_logprob[i],
+            logprob_start_len=self.logprob_start_len[i],
+            top_logprobs_num=self.top_logprobs_num[i],
+            return_text_in_logprobs=self.return_text_in_logprobs,
+            stream=self.stream,
+            modalities=self.modalities[i] if self.modalities else None,
+            lora_path=self.lora_path[i] if self.lora_path is not None else None,
+        )
+

 @dataclass
 class TokenizedGenerateReqInput:
@@ -196,85 +210,61 @@ class EmbeddingReqInput:
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None

-    def post_init(self):
+    def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")

+        # Derive the batch size
         if self.text is not None:
-
+            if isinstance(self.text, str):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.is_single = False
+                self.batch_size = len(self.text)
         else:
-
+            if isinstance(self.input_ids[0], int):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.is_single = False
+                self.batch_size = len(self.input_ids)

+        # Fill in default arguments
         if self.is_single:
             if self.rid is None:
                 self.rid = uuid.uuid4().hex
             if self.sampling_params is None:
                 self.sampling_params = {}
-            self.sampling_params["max_new_tokens"] = 1
+            self.sampling_params["max_new_tokens"] = 0
         else:
-            # support select operation
-            self.batch_size = (
-                len(self.text) if self.text is not None else len(self.input_ids)
-            )
             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
             else:
-
-
+                assert isinstance(self.rid, list), "The rid should be a list."
+
             if self.sampling_params is None:
                 self.sampling_params = [{}] * self.batch_size
             for i in range(self.batch_size):
-                self.sampling_params[i]["max_new_tokens"] = 1
-
-
-@dataclass
-class TokenizedEmbeddingReqInput:
-    # The request id
-    rid: str
-    # The input text
-    input_text: str
-    # The input token ids
-    input_ids: List[int]
-    # Dummy sampling params for compatibility
-    sampling_params: SamplingParams
+                self.sampling_params[i]["max_new_tokens"] = 0

+    def regenerate_rid(self):
+        self.rid = uuid.uuid4().hex
+        return self.rid

-
-
-
-
-
-
-
-    sampling_params: Union[List[Dict], Dict] = None
-
-    def post_init(self):
-        self.is_single = isinstance(self.conv[0], dict)
-
-        if self.is_single:
-            if self.rid is None:
-                self.rid = uuid.uuid4().hex
-            if self.sampling_params is None:
-                self.sampling_params = {}
-            self.sampling_params["max_new_tokens"] = 1
-        else:
-            # support select operation
-            self.batch_size = len(self.conv)
-            if self.rid is None:
-                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
-            else:
-                if not isinstance(self.rid, list):
-                    raise ValueError("The rid should be a list.")
-            if self.sampling_params is None:
-                self.sampling_params = [{}] * self.batch_size
-            for i in range(self.batch_size):
-                self.sampling_params[i]["max_new_tokens"] = 1
+    def __getitem__(self, i):
+        return EmbeddingReqInput(
+            text=self.text[i] if self.text is not None else None,
+            input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            sampling_params=self.sampling_params[i],
+            rid=self.rid[i],
+        )


 @dataclass
-class
+class TokenizedEmbeddingReqInput:
     # The request id
     rid: str
     # The input text
@@ -294,6 +284,8 @@ class BatchTokenIDOut:
     decoded_texts: List[str]
     decode_ids: List[int]
     read_offsets: List[int]
+    # Only used when `--skip-tokenizer-init`
+    output_ids: Optional[List[int]]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
@@ -353,3 +345,13 @@ class AbortReq:
 class ProfileReq(Enum):
     START_PROFILE = 1
     STOP_PROFILE = 2
+
+
+@dataclass
+class GetMemPoolSizeReq:
+    pass
+
+
+@dataclass
+class GetMemPoolSizeReqOutput:
+    size: int
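The GenerateReqInput changes above fold batch-size derivation, parallel-sampling expansion, and default filling into one `normalize_batch_and_arguments` call, and add `__getitem__` so callers can peel off per-request slices. Below is a minimal, self-contained sketch of that flow; `MiniGenerateReqInput` is an illustrative stand-in, not the real dataclass, which carries many more fields (image_data, logprob options, lora_path) as the `__getitem__` body in the hunk shows.

```python
import uuid
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class MiniGenerateReqInput:
    # Illustrative stand-in for sglang.srt.managers.io_struct.GenerateReqInput.
    text: Optional[Union[List[str], str]] = None
    sampling_params: Optional[Union[List[dict], dict]] = None
    rid: Optional[Union[List[str], str]] = None

    def normalize_batch_and_arguments(self):
        # Derive the batch size from the input text.
        if isinstance(self.text, str):
            self.is_single, self.batch_size = True, 1
        else:
            self.is_single, self.batch_size = False, len(self.text)

        # Parallel sampling (n > 1) always turns the input into a batch.
        params = self.sampling_params if self.sampling_params is not None else {}
        self.parallel_sample_num = (
            params.get("n", 1) if isinstance(params, dict) else params[0].get("n", 1)
        )
        if self.parallel_sample_num > 1 and self.is_single:
            self.is_single = False
            self.text = [self.text]

        # Fill in default arguments, one entry per expanded sub-request.
        num = self.batch_size * self.parallel_sample_num
        if not self.is_single:
            if not isinstance(self.sampling_params, list):
                self.sampling_params = [params] * num
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(num)]

    def regenerate_rid(self):
        self.rid = uuid.uuid4().hex
        return self.rid

    def __getitem__(self, i):
        # Slice out the i-th sub-request, mirroring the new __getitem__ in the hunk.
        return MiniGenerateReqInput(
            text=self.text[i], sampling_params=self.sampling_params[i], rid=self.rid[i]
        )


req = MiniGenerateReqInput(text="Hello", sampling_params={"n": 2})
req.normalize_batch_and_arguments()
print(req.batch_size, req.parallel_sample_num, len(req.rid))  # 1 2 2
sub = req[0]
print(sub.text, sub.sampling_params)  # Hello {'n': 2}
```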
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -37,8 +37,7 @@ import torch

 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained import
-from sglang.srt.constrained.jump_forward import JumpForwardMap
+from sglang.srt.constrained.grammar import Grammar
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -212,9 +211,6 @@ class Req:
         # this does not include the jump forward tokens.
         self.completion_tokens_wo_jump_forward = 0

-        # The number of cached tokens, that were already cached in the KV store
-        self.cached_tokens = 0
-
         # For vision inputs
         self.image_inputs: Optional[ImageInputs] = None

@@ -222,7 +218,10 @@ class Req:
         self.prefix_indices = []
         self.extend_input_len = 0
         self.last_node = None
-        self.
+        self.is_being_chunked = 0
+
+        # For retraction
+        self.is_retracted = False

         # Logprobs (arguments)
         self.return_logprob = False
@@ -243,13 +242,14 @@ class Req:
         # The relative logprob_start_len in an extend batch
         self.extend_logprob_start_len = 0

-        # Embedding
+        # Embedding (return values)
         self.embedding = None

         # Constrained decoding
-        self.
-
-
+        self.grammar: Optional[Grammar] = None
+
+        # The number of cached tokens, that were already cached in the KV cache
+        self.cached_tokens = 0

         # For Qwen2-VL
         self.mrope_position_delta = []  # use mutable object
@@ -334,15 +334,20 @@ class Req:

         last_token_id = self.output_ids[-1]

-        matched_eos =
+        matched_eos = False

+        # Check stop token ids
+        if self.sampling_params.stop_token_ids:
+            matched_eos = last_token_id in self.sampling_params.stop_token_ids
         if self.tokenizer is not None:
             matched_eos |= last_token_id == self.tokenizer.eos_token_id
-
+            if self.tokenizer.additional_stop_token_ids:
+                matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
         if matched_eos and not self.sampling_params.ignore_eos:
             self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
             return

+        # Check stop strings
         if len(self.sampling_params.stop_strs) > 0:
             tail_str = self.tokenizer.decode(
                 self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
@@ -354,6 +359,8 @@ class Req:
             return

     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
+        assert self.grammar is not None and self.tokenizer is not None
+
         if self.origin_input_text is None:
             # Recovering text can only use unpadded ids
             self.origin_input_text = self.tokenizer.decode(
@@ -393,7 +400,8 @@ class Req:
                 self.surr_offset = self.read_offset - i
                 break

-
+        # update the inner state of the grammar
+        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)

         if self.return_logprob:
             # For fast-forward part's logprobs
@@ -463,8 +471,8 @@ class ScheduleBatch:
     # Stream
     has_stream: bool = False

-    # Has
-
+    # Has grammar
+    has_grammar: bool = False

     # device
     device: str = "cuda"
@@ -472,7 +480,7 @@ class ScheduleBatch:
     @classmethod
     def init_new(
         cls,
-        reqs,
+        reqs: List[Req],
         req_to_token_pool,
         token_to_kv_pool,
         tree_cache,
@@ -486,7 +494,7 @@ class ScheduleBatch:
             model_config=model_config,
             return_logprob=any(req.return_logprob for req in reqs),
             has_stream=any(req.stream for req in reqs),
-
+            has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
         )

@@ -514,7 +522,12 @@ class ScheduleBatch:
         out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

         if out_cache_loc is None:
-
+            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
+            logger.error(
+                f"{phase_str} out of memory. Try to lower your batch size.\n"
+                f"Try to allocate {num_tokens} tokens.\n"
+                f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
+            )
             if self.tree_cache is not None:
                 self.tree_cache.pretty_print()
             exit(1)
@@ -551,7 +564,7 @@ class ScheduleBatch:
                 seq_lens[i] -= encoder_len

             if len(req.prefix_indices) < encoder_len:
-                # NOTE: the encoder part should considered as a whole
+                # NOTE: the encoder part should be considered as a whole
                 assert len(req.prefix_indices) == 0
                 input_ids[i] = input_ids[i][encoder_len:]
                 encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
@@ -638,6 +651,7 @@ class ScheduleBatch:

             req.extend_logprob_start_len = extend_logprob_start_len
             pt += req.extend_input_len
+            req.is_retracted = False

         # Set fields
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
@@ -770,6 +784,7 @@ class ScheduleBatch:
             req.prefix_indices = []
             req.last_node = None
             req.extend_input_len = 0
+            req.is_retracted = True

             # For incremental logprobs
             req.last_update_decode_tokens = 0
@@ -793,26 +808,10 @@ class ScheduleBatch:
         keep_indices = set(i for i in range(len(self.reqs)))

         for i, req in enumerate(self.reqs):
-            if req.
-
-
-
-                if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
-                    suffix_bytes = []
-                    continuation_range = range(0x80, 0xC0)
-                    cur_state = req.regex_fsm_state
-                    while (
-                        len(jump_forward_bytes)
-                        and jump_forward_bytes[0][0] in continuation_range
-                    ):
-                        # continuation bytes
-                        byte_edge = jump_forward_bytes.pop(0)
-                        suffix_bytes.append(byte_edge[0])
-                        cur_state = byte_edge[1]
-
-                    suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
-                    suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)
-
+            if req.grammar is not None:
+                jump_helper = req.grammar.try_jump(req.tokenizer)
+                if jump_helper.can_jump():
+                    suffix_ids = jump_helper.suffix_ids
                     # Current ids, for cache and revert
                     cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                     cur_output_ids = req.output_ids
@@ -826,10 +825,8 @@ class ScheduleBatch:
                     (
                         jump_forward_str,
                         next_state,
-                    ) = req.
+                    ) = req.grammar.jump_forward_str_state(jump_helper)

-                    # Make the incrementally decoded text part of jump_forward_str
-                    # so that the UTF-8 will not corrupt
                     jump_forward_str = new_text + jump_forward_str
                     if not req.jump_forward_and_retokenize(
                         jump_forward_str, next_state
@@ -896,7 +893,7 @@ class ScheduleBatch:

     def filter_batch(
         self,
-
+        being_chunked_req: Optional[Req] = None,
         keep_indices: Optional[List[int]] = None,
     ):
         if keep_indices is None:
@@ -904,7 +901,7 @@ class ScheduleBatch:
                 i
                 for i in range(len(self.reqs))
                 if not self.reqs[i].finished()
-                and self.reqs[i] is not
+                and self.reqs[i] is not being_chunked_req
             ]

         if keep_indices is None or len(keep_indices) == 0:
@@ -936,7 +933,7 @@ class ScheduleBatch:
            self.top_logprobs_nums = None

         self.has_stream = any(req.stream for req in self.reqs)
-        self.
+        self.has_grammar = any(req.grammar for req in self.reqs)

         self.sampling_info.filter_batch(keep_indices, new_indices)

@@ -969,7 +966,7 @@ class ScheduleBatch:

         self.return_logprob = self.return_logprob or other.return_logprob
         self.has_stream = self.has_stream or other.has_stream
-        self.
+        self.has_grammar = self.has_grammar or other.has_grammar

     def get_model_worker_batch(self):
         if self.forward_mode.is_decode():
@@ -979,13 +976,10 @@ class ScheduleBatch:
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens

-        if self.
-            self.sampling_info.
-            self.sampling_info.regex_fsm_states = [
-                req.regex_fsm_state for req in self.reqs
-            ]
+        if self.has_grammar:
+            self.sampling_info.grammars = [req.grammar for req in self.reqs]
         else:
-            self.sampling_info.
+            self.sampling_info.grammars = None

         global bid
         bid += 1
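The schedule_batch.py changes replace the regex-FSM plumbing with a single `Grammar` object per request (new `sglang/srt/constrained/grammar.py` in the file list above). Only the call surface is visible from this diff: `try_jump(tokenizer)` returns a helper exposing `can_jump()` and `suffix_ids`, `jump_forward_str_state(helper)` returns a `(jump_forward_str, next_state)` pair, and `jump_and_retokenize(...)` re-syncs the matcher after the output ids are rewritten. The stub below only sketches that surface under those assumptions; the helper class name, its extra fields, and the no-op behavior are illustrative, not the real implementation.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class JumpHelper:
    # Shape inferred from the hunk: the scheduler only reads can_jump() and suffix_ids.
    suffix_ids: List[int] = field(default_factory=list)
    state: int = -1

    def can_jump(self) -> bool:
        return len(self.suffix_ids) > 0 or self.state >= 0


class NoJumpGrammar:
    """Stand-in with the surface ScheduleBatch relies on; not the real Grammar class."""

    def try_jump(self, tokenizer) -> JumpHelper:
        # A grammar that never fast-forwards; the scheduler simply skips it.
        return JumpHelper()

    def jump_forward_str_state(self, helper: JumpHelper) -> Tuple[str, int]:
        return "", helper.state

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ) -> None:
        # Re-sync internal matcher state after the request's output ids were rewritten.
        pass


def maybe_jump(grammar, tokenizer) -> Optional[List[int]]:
    # The jump-forward check from the hunk, written against the stub.
    helper = grammar.try_jump(tokenizer)
    return helper.suffix_ids if helper.can_jump() else None


print(maybe_jump(NoJumpGrammar(), tokenizer=None))  # None
```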
sglang/srt/managers/schedule_policy.py
CHANGED
@@ -30,7 +30,9 @@ from sglang.srt.mem_cache.radix_cache import TreeNode
 # This can prevent the server from being too conservative.
 # Note that this only clips the estimation in the scheduler but does not change the stop
 # condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
-
+CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
+    os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
+)


 class SchedulePolicy:
@@ -43,9 +45,15 @@ class SchedulePolicy:
         self.tree_cache = tree_cache

     def calc_priority(self, waiting_queue: List[Req]):
+        if len(waiting_queue) > 128 and self.policy == "lpm":
+            # Turn off the expensive prefix matching and sorting when the #queue is large.
+            policy = "fcfs"
+        else:
+            policy = self.policy
+
         # Compute matched prefix length
         prefix_computed = False
-        if self.policy == "lpm" or self.policy == "dfs-weight":
+        if policy == "lpm" or policy == "dfs-weight":
             for r in waiting_queue:
                 # NOTE: the prefix_indices must always be aligned with last_node
                 r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
@@ -54,18 +62,18 @@ class SchedulePolicy:

             prefix_computed = True

-        if self.policy == "lpm":
+        if policy == "lpm":
             # Longest Prefix Match
             waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
-        elif self.policy == "fcfs":
+        elif policy == "fcfs":
             # first come first serve
             pass
-        elif self.policy == "lof":
+        elif policy == "lof":
             # longest output first
             waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-        elif self.policy == "random":
+        elif policy == "random":
             random.shuffle(waiting_queue)
-        elif self.policy == "dfs-weight":
+        elif policy == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in waiting_queue:
                 last_node_to_reqs[req.last_node].append(req)
@@ -83,7 +91,7 @@ class SchedulePolicy:
                 waiting_queue,
             )
         else:
-            raise ValueError(f"Unknown schedule_policy: {
+            raise ValueError(f"Unknown schedule_policy: {policy=}")

         return prefix_computed

@@ -146,7 +154,7 @@ class PrefillAdder:
            [
                min(
                    (r.sampling_params.max_new_tokens - len(r.output_ids)),
-
+                    CLIP_MAX_NEW_TOKENS_ESTIMATION,
                )
                * self.new_token_ratio
                for r in running_batch.reqs
@@ -186,7 +194,7 @@ class PrefillAdder:
            len(req.prefix_indices),
            req.extend_input_len,
            (
-                min(req.sampling_params.max_new_tokens,
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION)
                if not truncated
                else 0
            ),
@@ -258,7 +266,7 @@ class PrefillAdder:
            self._prefill_one_req(
                0,
                req.extend_input_len,
-                min(req.sampling_params.max_new_tokens,
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
            )
        else:
            # Chunked prefill
@@ -276,7 +284,7 @@ class PrefillAdder:
            return self.add_one_req_ignore_eos(req)

        total_tokens = req.extend_input_len + min(
-            req.sampling_params.max_new_tokens,
+            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
        )
        input_tokens = req.extend_input_len
        prefix_len = len(req.prefix_indices)
@@ -302,7 +310,10 @@ class PrefillAdder:
            self._prefill_one_req(
                prefix_len,
                input_tokens,
-                min(
+                min(
+                    req.sampling_params.max_new_tokens,
+                    CLIP_MAX_NEW_TOKENS_ESTIMATION,
+                ),
            )
        else:
            # Chunked prefill
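schedule_policy.py renames the estimation clip to `CLIP_MAX_NEW_TOKENS_ESTIMATION`, makes it overridable through the `SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION` environment variable, and lets `calc_priority` fall back from `lpm` to `fcfs` once the waiting queue grows past 128 entries. The sketch below restates those two behaviors; the helper functions are illustrative, only the arithmetic, the threshold, and the environment variable come from the diff.

```python
import os

# Same pattern as the hunk: the scheduler's token estimate is clipped, and the clip
# can be overridden via an environment variable (4096 is the default shown above).
CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
    os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
)


def estimated_decode_tokens(max_new_tokens: int, generated: int, new_token_ratio: float) -> float:
    # Mirrors the PrefillAdder arithmetic: remaining budget, clipped, scaled by the
    # ratio of requests expected to actually run to completion.
    return min(max_new_tokens - generated, CLIP_MAX_NEW_TOKENS_ESTIMATION) * new_token_ratio


def effective_policy(configured: str, queue_len: int) -> str:
    # calc_priority now falls back from "lpm" to "fcfs" once the waiting queue
    # exceeds 128 requests, skipping the per-request prefix matching and sort.
    if queue_len > 128 and configured == "lpm":
        return "fcfs"
    return configured


print(estimated_decode_tokens(max_new_tokens=32768, generated=100, new_token_ratio=0.5))  # 2048.0
print(effective_policy("lpm", queue_len=500))  # fcfs
```

Setting a lower value (e.g. `export SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION=1024`) only shrinks the scheduler's per-request decode estimate; as the comment in the hunk notes, it does not change the actual stop condition.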
|