sglang 0.3.4.post1__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to its public registry. It is provided for informational purposes only.
Files changed (91)
  1. sglang/api.py +1 -1
  2. sglang/bench_latency.py +3 -3
  3. sglang/bench_server_latency.py +2 -3
  4. sglang/bench_serving.py +92 -0
  5. sglang/global_config.py +9 -3
  6. sglang/lang/chat_template.py +50 -25
  7. sglang/lang/interpreter.py +9 -1
  8. sglang/lang/ir.py +11 -2
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/configs/model_config.py +76 -15
  11. sglang/srt/constrained/__init__.py +18 -0
  12. sglang/srt/constrained/bnf_cache.py +61 -0
  13. sglang/srt/constrained/fsm_cache.py +10 -3
  14. sglang/srt/constrained/grammar.py +190 -0
  15. sglang/srt/hf_transformers_utils.py +20 -5
  16. sglang/srt/layers/attention/flashinfer_backend.py +5 -5
  17. sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
  18. sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
  19. sglang/srt/layers/fused_moe/fused_moe.py +4 -3
  20. sglang/srt/layers/fused_moe/layer.py +28 -0
  21. sglang/srt/layers/logits_processor.py +5 -5
  22. sglang/srt/layers/quantization/base_config.py +16 -1
  23. sglang/srt/layers/rotary_embedding.py +15 -48
  24. sglang/srt/layers/sampler.py +51 -39
  25. sglang/srt/layers/vocab_parallel_embedding.py +486 -0
  26. sglang/srt/managers/data_parallel_controller.py +8 -7
  27. sglang/srt/managers/detokenizer_manager.py +11 -9
  28. sglang/srt/managers/image_processor.py +4 -3
  29. sglang/srt/managers/io_struct.py +80 -78
  30. sglang/srt/managers/schedule_batch.py +46 -52
  31. sglang/srt/managers/schedule_policy.py +24 -13
  32. sglang/srt/managers/scheduler.py +145 -82
  33. sglang/srt/managers/tokenizer_manager.py +236 -334
  34. sglang/srt/managers/tp_worker.py +5 -5
  35. sglang/srt/managers/tp_worker_overlap_thread.py +58 -21
  36. sglang/srt/mem_cache/flush_cache.py +1 -1
  37. sglang/srt/mem_cache/memory_pool.py +10 -3
  38. sglang/srt/model_executor/cuda_graph_runner.py +34 -23
  39. sglang/srt/model_executor/forward_batch_info.py +6 -9
  40. sglang/srt/model_executor/model_runner.py +10 -19
  41. sglang/srt/models/baichuan.py +4 -4
  42. sglang/srt/models/chatglm.py +4 -4
  43. sglang/srt/models/commandr.py +1 -1
  44. sglang/srt/models/dbrx.py +5 -5
  45. sglang/srt/models/deepseek.py +4 -4
  46. sglang/srt/models/deepseek_v2.py +4 -4
  47. sglang/srt/models/exaone.py +4 -4
  48. sglang/srt/models/gemma.py +1 -1
  49. sglang/srt/models/gemma2.py +1 -1
  50. sglang/srt/models/gpt2.py +287 -0
  51. sglang/srt/models/gpt_bigcode.py +1 -1
  52. sglang/srt/models/grok.py +4 -4
  53. sglang/srt/models/internlm2.py +4 -4
  54. sglang/srt/models/llama.py +15 -7
  55. sglang/srt/models/llama_embedding.py +2 -10
  56. sglang/srt/models/llama_reward.py +5 -0
  57. sglang/srt/models/minicpm.py +4 -4
  58. sglang/srt/models/minicpm3.py +4 -4
  59. sglang/srt/models/mixtral.py +7 -5
  60. sglang/srt/models/mixtral_quant.py +4 -4
  61. sglang/srt/models/mllama.py +5 -5
  62. sglang/srt/models/olmo.py +4 -4
  63. sglang/srt/models/olmoe.py +4 -4
  64. sglang/srt/models/qwen.py +4 -4
  65. sglang/srt/models/qwen2.py +4 -4
  66. sglang/srt/models/qwen2_moe.py +4 -4
  67. sglang/srt/models/qwen2_vl.py +4 -8
  68. sglang/srt/models/stablelm.py +4 -4
  69. sglang/srt/models/torch_native_llama.py +4 -4
  70. sglang/srt/models/xverse.py +4 -4
  71. sglang/srt/models/xverse_moe.py +4 -4
  72. sglang/srt/openai_api/adapter.py +52 -66
  73. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
  74. sglang/srt/sampling/sampling_batch_info.py +7 -13
  75. sglang/srt/sampling/sampling_params.py +5 -7
  76. sglang/srt/server.py +41 -33
  77. sglang/srt/server_args.py +34 -5
  78. sglang/srt/utils.py +40 -56
  79. sglang/test/run_eval.py +2 -0
  80. sglang/test/runners.py +2 -1
  81. sglang/test/srt/sampling/penaltylib/utils.py +1 -0
  82. sglang/test/test_utils.py +151 -6
  83. sglang/utils.py +62 -1
  84. sglang/version.py +1 -1
  85. sglang-0.3.5.dist-info/METADATA +344 -0
  86. sglang-0.3.5.dist-info/RECORD +152 -0
  87. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
  88. sglang-0.3.4.post1.dist-info/METADATA +0 -900
  89. sglang-0.3.4.post1.dist-info/RECORD +0 -148
  90. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
  91. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
@@ -56,49 +56,47 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
-    # Whether it is a single request or a batch request
-    is_single: bool = True
-
-    def post_init(self):
+    def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")
 
-        self.is_single = False
+        # Derive the batch size
         if self.text is not None:
             if isinstance(self.text, str):
                 self.is_single = True
                 self.batch_size = 1
             else:
+                self.is_single = False
                 self.batch_size = len(self.text)
         else:
             if isinstance(self.input_ids[0], int):
                 self.is_single = True
                 self.batch_size = 1
             else:
+                self.is_single = False
                 self.batch_size = len(self.input_ids)
 
+        # Handle parallel sampling
+        # When parallel sampling is used, we always treat the input as a batch.
         if self.sampling_params is None:
            self.parallel_sample_num = 1
         elif isinstance(self.sampling_params, dict):
             self.parallel_sample_num = self.sampling_params.get("n", 1)
         else:  # isinstance(self.sampling_params, list):
             self.parallel_sample_num = self.sampling_params[0].get("n", 1)
-            for sp in self.sampling_params:
-                # TODO cope with the case that the parallel_sample_num is different for different samples
-                assert self.parallel_sample_num == sp.get(
-                    "n", 1
-                ), "The parallel_sample_num should be the same for all samples in sample params."
-
-        if self.parallel_sample_num > 1:
-            if self.is_single:
-                self.is_single = False
-                if self.text is not None:
-                    self.text = [self.text]
-                if self.input_ids is not None:
-                    self.input_ids = [self.input_ids]
+            assert all(self.parallel_sample_num == sampling_params.get("n", 1) for sampling_params in self.sampling_params), (
+                "The parallel_sample_num should be the same for all samples in sample params.")
 
+        if self.parallel_sample_num > 1 and self.is_single:
+            self.is_single = False
+            if self.text is not None:
+                self.text = [self.text]
+            if self.input_ids is not None:
+                self.input_ids = [self.input_ids]
+
+        # Fill in default arguments
         if self.is_single:
             if self.sampling_params is None:
                 self.sampling_params = {}
@@ -114,9 +112,8 @@ class GenerateReqInput:
             if self.parallel_sample_num == 1:
                 num = self.batch_size
             else:
-                # FIXME support cascade inference
-                # first bs samples are used for caching the prefix for parallel sampling
-                num = self.batch_size + self.parallel_sample_num * self.batch_size
+                # Expand parallel_sample_num
+                num = self.batch_size * self.parallel_sample_num
 
             if self.image_data is None:
                 self.image_data = [None] * num
@@ -129,14 +126,11 @@ class GenerateReqInput:
                 self.sampling_params = [{}] * num
             elif not isinstance(self.sampling_params, list):
                 self.sampling_params = [self.sampling_params] * num
-            else:
-                assert self.parallel_sample_num == 1
 
             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(num)]
             else:
                 assert isinstance(self.rid, list), "The rid should be a list."
-                assert self.parallel_sample_num == 1
 
             if self.return_logprob is None:
                 self.return_logprob = [False] * num
@@ -159,6 +153,26 @@ class GenerateReqInput:
             else:
                 assert self.parallel_sample_num == 1
 
+    def regenerate_rid(self):
+        self.rid = uuid.uuid4().hex
+        return self.rid
+
+    def __getitem__(self, i):
+        return GenerateReqInput(
+            text=self.text[i] if self.text is not None else None,
+            input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            image_data=self.image_data[i],
+            sampling_params=self.sampling_params[i],
+            rid=self.rid[i],
+            return_logprob=self.return_logprob[i],
+            logprob_start_len=self.logprob_start_len[i],
+            top_logprobs_num=self.top_logprobs_num[i],
+            return_text_in_logprobs=self.return_text_in_logprobs,
+            stream=self.stream,
+            modalities=self.modalities[i] if self.modalities else None,
+            lora_path=self.lora_path[i] if self.lora_path is not None else None,
+        )
+
 
 @dataclass
 class TokenizedGenerateReqInput:
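The hunks above replace `post_init` with `normalize_batch_and_arguments`, derive `is_single`/`batch_size` instead of storing a default field, and change how parallel sampling expands a request (previously `batch_size + batch_size * n` slots to cache the prefix for cascade inference, now simply `batch_size * n`). Below is a minimal standalone sketch of that expansion rule, distilled from the diff; the helper names are illustrative and are not part of sglang.

```python
# Distilled sketch of the 0.3.5 batching rule in GenerateReqInput.normalize_batch_and_arguments.
# Standalone illustration only, not the sglang implementation.

def expanded_request_count(batch_size: int, parallel_sample_num: int) -> int:
    # 0.3.4.post1: batch_size + batch_size * parallel_sample_num
    #              (extra per-batch copies were reserved for prefix caching / cascade inference)
    # 0.3.5:       batch_size * parallel_sample_num
    if parallel_sample_num == 1:
        return batch_size
    return batch_size * parallel_sample_num


def normalize(text, sampling_params: dict):
    # A single prompt with n > 1 is always promoted to a batch.
    is_single = isinstance(text, str)
    batch_size = 1 if is_single else len(text)
    n = sampling_params.get("n", 1)
    if n > 1 and is_single:
        is_single, text = False, [text]
    num = expanded_request_count(batch_size, n)
    # Per-request defaults (rid, image_data, logprob flags, ...) are then filled to length `num`.
    return text, is_single, num


if __name__ == "__main__":
    print(normalize("Hello", {"n": 2}))     # (['Hello'], False, 2)
    print(normalize(["a", "b"], {"n": 3}))  # (['a', 'b'], False, 6)
```

The new `__getitem__` then lets the tokenizer manager slice one of those `num` expanded slots back out as a standalone single request.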
@@ -196,85 +210,61 @@ class EmbeddingReqInput:
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
 
-    def post_init(self):
+    def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")
 
+        # Derive the batch size
         if self.text is not None:
-            self.is_single = isinstance(self.text, str)
+            if isinstance(self.text, str):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.is_single = False
+                self.batch_size = len(self.text)
         else:
-            self.is_single = isinstance(self.input_ids[0], int)
+            if isinstance(self.input_ids[0], int):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.is_single = False
+                self.batch_size = len(self.input_ids)
 
+        # Fill in default arguments
         if self.is_single:
             if self.rid is None:
                 self.rid = uuid.uuid4().hex
             if self.sampling_params is None:
                 self.sampling_params = {}
-            self.sampling_params["max_new_tokens"] = 1
+            self.sampling_params["max_new_tokens"] = 0
         else:
-            # support select operation
-            self.batch_size = (
-                len(self.text) if self.text is not None else len(self.input_ids)
-            )
             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
             else:
-                if not isinstance(self.rid, list):
-                    raise ValueError("The rid should be a list.")
+                assert isinstance(self.rid, list), "The rid should be a list."
+
             if self.sampling_params is None:
                 self.sampling_params = [{}] * self.batch_size
             for i in range(self.batch_size):
-                self.sampling_params[i]["max_new_tokens"] = 1
-
-
-@dataclass
-class TokenizedEmbeddingReqInput:
-    # The request id
-    rid: str
-    # The input text
-    input_text: str
-    # The input token ids
-    input_ids: List[int]
-    # Dummy sampling params for compatibility
-    sampling_params: SamplingParams
+                self.sampling_params[i]["max_new_tokens"] = 0
 
+    def regenerate_rid(self):
+        self.rid = uuid.uuid4().hex
+        return self.rid
 
-@dataclass
-class RewardReqInput:
-    # The input prompt in the chat format. It can be a single prompt or a batch of prompts.
-    conv: Union[List[List[Dict]], List[Dict]]
-    # The request id.
-    rid: Optional[Union[List[str], str]] = None
-    # Dummy sampling params for compatibility
-    sampling_params: Union[List[Dict], Dict] = None
-
-    def post_init(self):
-        self.is_single = isinstance(self.conv[0], dict)
-
-        if self.is_single:
-            if self.rid is None:
-                self.rid = uuid.uuid4().hex
-            if self.sampling_params is None:
-                self.sampling_params = {}
-            self.sampling_params["max_new_tokens"] = 1
-        else:
-            # support select operation
-            self.batch_size = len(self.conv)
-            if self.rid is None:
-                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
-            else:
-                if not isinstance(self.rid, list):
-                    raise ValueError("The rid should be a list.")
-            if self.sampling_params is None:
-                self.sampling_params = [{}] * self.batch_size
-            for i in range(self.batch_size):
-                self.sampling_params[i]["max_new_tokens"] = 1
+    def __getitem__(self, i):
+        return EmbeddingReqInput(
+            text=self.text[i] if self.text is not None else None,
+            input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            sampling_params=self.sampling_params[i],
+            rid=self.rid[i],
+        )
 
 
 @dataclass
-class TokenizedRewardReqInput:
+class TokenizedEmbeddingReqInput:
     # The request id
     rid: str
     # The input text
@@ -294,6 +284,8 @@ class BatchTokenIDOut:
     decoded_texts: List[str]
     decode_ids: List[int]
     read_offsets: List[int]
+    # Only used when `--skip-tokenizer-init`
+    output_ids: Optional[List[int]]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
@@ -353,3 +345,13 @@ class AbortReq:
 class ProfileReq(Enum):
     START_PROFILE = 1
     STOP_PROFILE = 2
+
+
+@dataclass
+class GetMemPoolSizeReq:
+    pass
+
+
+@dataclass
+class GetMemPoolSizeReqOutput:
+    size: int
@@ -37,8 +37,7 @@ import torch
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained import RegexGuide
-from sglang.srt.constrained.jump_forward import JumpForwardMap
+from sglang.srt.constrained.grammar import Grammar
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -212,9 +211,6 @@ class Req:
         # this does not include the jump forward tokens.
         self.completion_tokens_wo_jump_forward = 0
 
-        # The number of cached tokens, that were already cached in the KV store
-        self.cached_tokens = 0
-
         # For vision inputs
         self.image_inputs: Optional[ImageInputs] = None
 
@@ -222,7 +218,10 @@ class Req:
         self.prefix_indices = []
         self.extend_input_len = 0
         self.last_node = None
-        self.is_inflight_req = 0
+        self.is_being_chunked = 0
+
+        # For retraction
+        self.is_retracted = False
 
         # Logprobs (arguments)
         self.return_logprob = False
@@ -243,13 +242,14 @@ class Req:
         # The relative logprob_start_len in an extend batch
         self.extend_logprob_start_len = 0
 
-        # Embedding
+        # Embedding (return values)
         self.embedding = None
 
         # Constrained decoding
-        self.regex_fsm: RegexGuide = None
-        self.regex_fsm_state: int = 0
-        self.jump_forward_map: JumpForwardMap = None
+        self.grammar: Optional[Grammar] = None
+
+        # The number of cached tokens, that were already cached in the KV cache
+        self.cached_tokens = 0
 
         # For Qwen2-VL
         self.mrope_position_delta = []  # use mutable object
@@ -334,15 +334,20 @@ class Req:
 
         last_token_id = self.output_ids[-1]
 
-        matched_eos = last_token_id in self.sampling_params.stop_token_ids
+        matched_eos = False
 
+        # Check stop token ids
+        if self.sampling_params.stop_token_ids:
+            matched_eos = last_token_id in self.sampling_params.stop_token_ids
         if self.tokenizer is not None:
             matched_eos |= last_token_id == self.tokenizer.eos_token_id
-
+            if self.tokenizer.additional_stop_token_ids:
+                matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
         if matched_eos and not self.sampling_params.ignore_eos:
             self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
             return
 
+        # Check stop strings
         if len(self.sampling_params.stop_strs) > 0:
             tail_str = self.tokenizer.decode(
                 self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
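The reworked stop check above guards `stop_token_ids` before membership testing and also consults the tokenizer's `additional_stop_token_ids`. Below is a standalone sketch of the combined check, distilled from the hunk; the `_Params` and `_Tokenizer` classes are hypothetical stand-ins, not sglang types.

```python
# Distilled sketch of the reordered stop-token check in Req.check_finished (0.3.5).
from dataclasses import dataclass, field
from typing import List, Optional, Set


@dataclass
class _Params:  # stand-in for SamplingParams
    stop_token_ids: Set[int] = field(default_factory=set)
    ignore_eos: bool = False


@dataclass
class _Tokenizer:  # stand-in for the HF tokenizer wrapper
    eos_token_id: int = 2
    additional_stop_token_ids: Optional[List[int]] = None


def matched_eos(last_token_id: int, params: _Params, tokenizer: Optional[_Tokenizer]) -> bool:
    matched = False
    # 1) explicit stop_token_ids from the request (now guarded: may be empty)
    if params.stop_token_ids:
        matched = last_token_id in params.stop_token_ids
    # 2) the tokenizer's eos token, plus any additional stop tokens it declares
    if tokenizer is not None:
        matched |= last_token_id == tokenizer.eos_token_id
        if tokenizer.additional_stop_token_ids:
            matched |= last_token_id in tokenizer.additional_stop_token_ids
    return matched and not params.ignore_eos


if __name__ == "__main__":
    tok = _Tokenizer(eos_token_id=2, additional_stop_token_ids=[128009])
    print(matched_eos(128009, _Params(), tok))            # True
    print(matched_eos(2, _Params(ignore_eos=True), tok))  # False
```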
@@ -354,6 +359,8 @@ class Req:
                 return
 
     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
+        assert self.grammar is not None and self.tokenizer is not None
+
         if self.origin_input_text is None:
             # Recovering text can only use unpadded ids
             self.origin_input_text = self.tokenizer.decode(
@@ -393,7 +400,8 @@ class Req:
                 self.surr_offset = self.read_offset - i
                 break
 
-        self.regex_fsm_state = next_state
+        # update the inner state of the grammar
+        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)
 
         if self.return_logprob:
             # For fast-forward part's logprobs
@@ -463,8 +471,8 @@ class ScheduleBatch:
     # Stream
     has_stream: bool = False
 
-    # Has regex
-    has_regex: bool = False
+    # Has grammar
+    has_grammar: bool = False
 
     # device
     device: str = "cuda"
@@ -472,7 +480,7 @@ class ScheduleBatch:
     @classmethod
     def init_new(
         cls,
-        reqs,
+        reqs: List[Req],
         req_to_token_pool,
         token_to_kv_pool,
         tree_cache,
@@ -486,7 +494,7 @@ class ScheduleBatch:
             model_config=model_config,
             return_logprob=any(req.return_logprob for req in reqs),
             has_stream=any(req.stream for req in reqs),
-            has_regex=any(req.regex_fsm for req in reqs),
+            has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
         )
 
@@ -514,7 +522,12 @@ class ScheduleBatch:
         out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
 
         if out_cache_loc is None:
-            logger.error("Prefill out of memory. Try to lower your batch size.")
+            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
+            logger.error(
+                f"{phase_str} out of memory. Try to lower your batch size.\n"
+                f"Try to allocate {num_tokens} tokens.\n"
+                f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
+            )
             if self.tree_cache is not None:
                 self.tree_cache.pretty_print()
             exit(1)
@@ -551,7 +564,7 @@ class ScheduleBatch:
                 seq_lens[i] -= encoder_len
 
                 if len(req.prefix_indices) < encoder_len:
-                    # NOTE: the encoder part should considered as a whole
+                    # NOTE: the encoder part should be considered as a whole
                     assert len(req.prefix_indices) == 0
                     input_ids[i] = input_ids[i][encoder_len:]
                     encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
@@ -638,6 +651,7 @@ class ScheduleBatch:
 
             req.extend_logprob_start_len = extend_logprob_start_len
             pt += req.extend_input_len
+            req.is_retracted = False
 
         # Set fields
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
@@ -770,6 +784,7 @@ class ScheduleBatch:
             req.prefix_indices = []
             req.last_node = None
             req.extend_input_len = 0
+            req.is_retracted = True
 
             # For incremental logprobs
             req.last_update_decode_tokens = 0
@@ -793,26 +808,10 @@ class ScheduleBatch:
         keep_indices = set(i for i in range(len(self.reqs)))
 
         for i, req in enumerate(self.reqs):
-            if req.jump_forward_map is not None:
-                jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
-                    req.regex_fsm_state
-                )
-                if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
-                    suffix_bytes = []
-                    continuation_range = range(0x80, 0xC0)
-                    cur_state = req.regex_fsm_state
-                    while (
-                        len(jump_forward_bytes)
-                        and jump_forward_bytes[0][0] in continuation_range
-                    ):
-                        # continuation bytes
-                        byte_edge = jump_forward_bytes.pop(0)
-                        suffix_bytes.append(byte_edge[0])
-                        cur_state = byte_edge[1]
-
-                    suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
-                    suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)
-
+            if req.grammar is not None:
+                jump_helper = req.grammar.try_jump(req.tokenizer)
+                if jump_helper.can_jump():
+                    suffix_ids = jump_helper.suffix_ids
                     # Current ids, for cache and revert
                     cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                     cur_output_ids = req.output_ids
@@ -826,10 +825,8 @@ class ScheduleBatch:
                     (
                         jump_forward_str,
                         next_state,
-                    ) = req.jump_forward_map.jump_forward_symbol(cur_state)
+                    ) = req.grammar.jump_forward_str_state(jump_helper)
 
-                    # Make the incrementally decoded text part of jump_forward_str
-                    # so that the UTF-8 will not corrupt
                     jump_forward_str = new_text + jump_forward_str
                     if not req.jump_forward_and_retokenize(
                         jump_forward_str, next_state
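The jump-forward path now goes through the new `Grammar` abstraction instead of manipulating `regex_fsm_state` and `jump_forward_map` directly. Inferring only from the call sites in these hunks, the interface looks roughly like the Protocol sketch below; the actual definitions live in the new `sglang/srt/constrained/grammar.py` and may differ in detail.

```python
# Interface sketch inferred from the call sites above; illustrative, not the sglang source.
from typing import List, Protocol, Tuple


class JumpHelper(Protocol):
    suffix_ids: List[int]

    def can_jump(self) -> bool: ...


class Grammar(Protocol):
    def try_jump(self, tokenizer) -> "JumpHelper":
        """Probe whether the constrained-decoding state allows a deterministic jump."""
        ...

    def jump_forward_str_state(self, helper: "JumpHelper") -> Tuple[str, int]:
        """Return the string to append and the grammar state after the jump."""
        ...

    def jump_and_retokenize(
        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
    ) -> None:
        """Sync the grammar's internal state after the output ids were re-tokenized."""
        ...
```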
@@ -896,7 +893,7 @@ class ScheduleBatch:
 
     def filter_batch(
         self,
-        current_inflight_req: Optional[Req] = None,
+        being_chunked_req: Optional[Req] = None,
         keep_indices: Optional[List[int]] = None,
     ):
         if keep_indices is None:
@@ -904,7 +901,7 @@ class ScheduleBatch:
                 i
                 for i in range(len(self.reqs))
                 if not self.reqs[i].finished()
-                and self.reqs[i] is not current_inflight_req
+                and self.reqs[i] is not being_chunked_req
             ]
 
         if keep_indices is None or len(keep_indices) == 0:
@@ -936,7 +933,7 @@ class ScheduleBatch:
             self.top_logprobs_nums = None
 
         self.has_stream = any(req.stream for req in self.reqs)
-        self.has_regex = any(req.regex_fsm for req in self.reqs)
+        self.has_grammar = any(req.grammar for req in self.reqs)
 
         self.sampling_info.filter_batch(keep_indices, new_indices)
 
@@ -969,7 +966,7 @@ class ScheduleBatch:
 
         self.return_logprob = self.return_logprob or other.return_logprob
         self.has_stream = self.has_stream or other.has_stream
-        self.has_regex = self.has_regex or other.has_regex
+        self.has_grammar = self.has_grammar or other.has_grammar
 
     def get_model_worker_batch(self):
         if self.forward_mode.is_decode():
@@ -979,13 +976,10 @@ class ScheduleBatch:
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens
 
-        if self.has_regex:
-            self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
-            self.sampling_info.regex_fsm_states = [
-                req.regex_fsm_state for req in self.reqs
-            ]
+        if self.has_grammar:
+            self.sampling_info.grammars = [req.grammar for req in self.reqs]
         else:
-            self.sampling_info.regex_fsms = None
+            self.sampling_info.grammars = None
 
         global bid
         bid += 1
@@ -30,7 +30,9 @@ from sglang.srt.mem_cache.radix_cache import TreeNode
 # This can prevent the server from being too conservative.
 # Note that this only clips the estimation in the scheduler but does not change the stop
 # condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
-CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))
+CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
+    os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
+)
 
 
 class SchedulePolicy:
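Note that both the module constant and its environment variable were renamed here, so the scheduler in 0.3.5 no longer reads the old `SGLANG_CLIP_MAX_NEW_TOKENS` name. A small illustrative snippet, assuming the variable is set before the scheduler module is imported:

```python
import os

# 0.3.4.post1 (no longer read after this change):
# os.environ["SGLANG_CLIP_MAX_NEW_TOKENS"] = "2048"

# 0.3.5: the renamed variable clips the scheduler's max_new_tokens *estimation* only;
# it does not change the actual stop condition of a request.
os.environ["SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION"] = "2048"
```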
@@ -43,9 +45,15 @@ class SchedulePolicy:
         self.tree_cache = tree_cache
 
     def calc_priority(self, waiting_queue: List[Req]):
+        if len(waiting_queue) > 128 and self.policy == "lpm":
+            # Turn off the expensive prefix matching and sorting when the #queue is large.
+            policy = "fcfs"
+        else:
+            policy = self.policy
+
         # Compute matched prefix length
         prefix_computed = False
-        if self.policy == "lpm" or self.policy == "dfs-weight":
+        if policy == "lpm" or policy == "dfs-weight":
             for r in waiting_queue:
                 # NOTE: the prefix_indices must always be aligned with last_node
                 r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
@@ -54,18 +62,18 @@ class SchedulePolicy:
 
             prefix_computed = True
 
-        if self.policy == "lpm":
+        if policy == "lpm":
             # Longest Prefix Match
             waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
-        elif self.policy == "fcfs":
+        elif policy == "fcfs":
             # first come first serve
             pass
-        elif self.policy == "lof":
+        elif policy == "lof":
             # longest output first
             waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-        elif self.policy == "random":
+        elif policy == "random":
             random.shuffle(waiting_queue)
-        elif self.policy == "dfs-weight":
+        elif policy == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in waiting_queue:
                 last_node_to_reqs[req.last_node].append(req)
@@ -83,7 +91,7 @@ class SchedulePolicy:
                 waiting_queue,
             )
         else:
-            raise ValueError(f"Unknown schedule_policy: {self.policy}")
+            raise ValueError(f"Unknown schedule_policy: {policy=}")
 
         return prefix_computed
 
@@ -146,7 +154,7 @@ class PrefillAdder:
                 [
                     min(
                         (r.sampling_params.max_new_tokens - len(r.output_ids)),
-                        CLIP_MAX_NEW_TOKENS,
+                        CLIP_MAX_NEW_TOKENS_ESTIMATION,
                     )
                     * self.new_token_ratio
                     for r in running_batch.reqs
@@ -186,7 +194,7 @@ class PrefillAdder:
             len(req.prefix_indices),
             req.extend_input_len,
             (
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION)
                 if not truncated
                 else 0
             ),
@@ -258,7 +266,7 @@ class PrefillAdder:
             self._prefill_one_req(
                 0,
                 req.extend_input_len,
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
             )
         else:
             # Chunked prefill
@@ -276,7 +284,7 @@ class PrefillAdder:
             return self.add_one_req_ignore_eos(req)
 
         total_tokens = req.extend_input_len + min(
-            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
+            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
        )
        input_tokens = req.extend_input_len
        prefix_len = len(req.prefix_indices)
@@ -302,7 +310,10 @@ class PrefillAdder:
             self._prefill_one_req(
                 prefix_len,
                 input_tokens,
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+                min(
+                    req.sampling_params.max_new_tokens,
+                    CLIP_MAX_NEW_TOKENS_ESTIMATION,
+                ),
             )
         else:
             # Chunked prefill
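The `calc_priority` hunks earlier in this file add a per-call fallback: when the waiting queue grows beyond 128 requests and the configured policy is `lpm`, the expensive prefix matching and sorting are skipped for that call by falling back to `fcfs` locally, without mutating `self.policy`. Below is a distilled standalone sketch of that behavior; the `Req` stand-in and function names are illustrative, not the sglang classes, and the `dfs-weight` branch is omitted.

```python
# Distilled sketch of the 0.3.5 per-call policy fallback in SchedulePolicy.calc_priority.
import random
from dataclasses import dataclass, field
from typing import List


@dataclass
class Req:  # hypothetical stand-in with only the fields used for ordering
    prefix_indices: List[int] = field(default_factory=list)  # matched prefix tokens
    max_new_tokens: int = 0


def effective_policy(configured: str, queue_len: int) -> str:
    # Prefix matching plus an O(n log n) sort with an expensive key; skip it for big queues.
    if queue_len > 128 and configured == "lpm":
        return "fcfs"
    return configured


def order_queue(configured_policy: str, waiting_queue: List[Req]) -> None:
    policy = effective_policy(configured_policy, len(waiting_queue))
    if policy == "lpm":        # longest prefix match first
        waiting_queue.sort(key=lambda r: -len(r.prefix_indices))
    elif policy == "fcfs":     # keep arrival order
        pass
    elif policy == "lof":      # longest output first
        waiting_queue.sort(key=lambda r: -r.max_new_tokens)
    elif policy == "random":
        random.shuffle(waiting_queue)
    else:
        raise ValueError(f"Unknown schedule_policy: {policy=}")
```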