sglang 0.3.4.post2__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. sglang/api.py +1 -1
  2. sglang/bench_latency.py +3 -3
  3. sglang/bench_server_latency.py +2 -3
  4. sglang/bench_serving.py +92 -0
  5. sglang/global_config.py +9 -3
  6. sglang/lang/chat_template.py +50 -25
  7. sglang/lang/interpreter.py +9 -1
  8. sglang/lang/ir.py +11 -2
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/configs/model_config.py +51 -13
  11. sglang/srt/constrained/__init__.py +18 -0
  12. sglang/srt/constrained/bnf_cache.py +61 -0
  13. sglang/srt/constrained/grammar.py +190 -0
  14. sglang/srt/hf_transformers_utils.py +6 -5
  15. sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
  16. sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
  17. sglang/srt/layers/fused_moe/fused_moe.py +4 -3
  18. sglang/srt/layers/fused_moe/layer.py +28 -0
  19. sglang/srt/layers/quantization/base_config.py +16 -1
  20. sglang/srt/layers/vocab_parallel_embedding.py +486 -0
  21. sglang/srt/managers/data_parallel_controller.py +7 -6
  22. sglang/srt/managers/detokenizer_manager.py +9 -11
  23. sglang/srt/managers/image_processor.py +4 -3
  24. sglang/srt/managers/io_struct.py +70 -78
  25. sglang/srt/managers/schedule_batch.py +33 -49
  26. sglang/srt/managers/schedule_policy.py +24 -13
  27. sglang/srt/managers/scheduler.py +137 -80
  28. sglang/srt/managers/tokenizer_manager.py +224 -336
  29. sglang/srt/managers/tp_worker.py +5 -5
  30. sglang/srt/mem_cache/flush_cache.py +1 -1
  31. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  32. sglang/srt/model_executor/model_runner.py +8 -17
  33. sglang/srt/models/baichuan.py +4 -4
  34. sglang/srt/models/chatglm.py +4 -4
  35. sglang/srt/models/commandr.py +1 -1
  36. sglang/srt/models/dbrx.py +5 -5
  37. sglang/srt/models/deepseek.py +4 -4
  38. sglang/srt/models/deepseek_v2.py +4 -4
  39. sglang/srt/models/exaone.py +4 -4
  40. sglang/srt/models/gemma.py +1 -1
  41. sglang/srt/models/gemma2.py +1 -1
  42. sglang/srt/models/gpt2.py +287 -0
  43. sglang/srt/models/gpt_bigcode.py +1 -1
  44. sglang/srt/models/grok.py +4 -4
  45. sglang/srt/models/internlm2.py +4 -4
  46. sglang/srt/models/llama.py +15 -7
  47. sglang/srt/models/llama_embedding.py +2 -10
  48. sglang/srt/models/llama_reward.py +5 -0
  49. sglang/srt/models/minicpm.py +4 -4
  50. sglang/srt/models/minicpm3.py +4 -4
  51. sglang/srt/models/mixtral.py +7 -5
  52. sglang/srt/models/mixtral_quant.py +4 -4
  53. sglang/srt/models/mllama.py +5 -5
  54. sglang/srt/models/olmo.py +4 -4
  55. sglang/srt/models/olmoe.py +4 -4
  56. sglang/srt/models/qwen.py +4 -4
  57. sglang/srt/models/qwen2.py +4 -4
  58. sglang/srt/models/qwen2_moe.py +4 -4
  59. sglang/srt/models/qwen2_vl.py +4 -8
  60. sglang/srt/models/stablelm.py +4 -4
  61. sglang/srt/models/torch_native_llama.py +4 -4
  62. sglang/srt/models/xverse.py +4 -4
  63. sglang/srt/models/xverse_moe.py +4 -4
  64. sglang/srt/openai_api/adapter.py +52 -66
  65. sglang/srt/sampling/sampling_batch_info.py +7 -13
  66. sglang/srt/server.py +31 -35
  67. sglang/srt/server_args.py +34 -5
  68. sglang/srt/utils.py +40 -56
  69. sglang/test/runners.py +2 -1
  70. sglang/test/test_utils.py +73 -25
  71. sglang/utils.py +62 -1
  72. sglang/version.py +1 -1
  73. sglang-0.3.5.dist-info/METADATA +344 -0
  74. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/RECORD +77 -73
  75. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
  76. sglang-0.3.4.post2.dist-info/METADATA +0 -899
  77. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
  78. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py
@@ -56,49 +56,47 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
-    # Whether it is a single request or a batch request
-    is_single: bool = True
-
-    def post_init(self):
+    def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")
 
-        self.is_single = False
+        # Derive the batch size
         if self.text is not None:
             if isinstance(self.text, str):
                 self.is_single = True
                 self.batch_size = 1
             else:
+                self.is_single = False
                 self.batch_size = len(self.text)
         else:
             if isinstance(self.input_ids[0], int):
                 self.is_single = True
                 self.batch_size = 1
             else:
+                self.is_single = False
                 self.batch_size = len(self.input_ids)
 
+        # Handle parallel sampling
+        # When parallel sampling is used, we always treat the input as a batch.
         if self.sampling_params is None:
             self.parallel_sample_num = 1
         elif isinstance(self.sampling_params, dict):
             self.parallel_sample_num = self.sampling_params.get("n", 1)
         else:  # isinstance(self.sampling_params, list):
             self.parallel_sample_num = self.sampling_params[0].get("n", 1)
-            for sp in self.sampling_params:
-                # TODO cope with the case that the parallel_sample_num is different for different samples
-                assert self.parallel_sample_num == sp.get(
-                    "n", 1
-                ), "The parallel_sample_num should be the same for all samples in sample params."
-
-        if self.parallel_sample_num > 1:
-            if self.is_single:
-                self.is_single = False
-                if self.text is not None:
-                    self.text = [self.text]
-                if self.input_ids is not None:
-                    self.input_ids = [self.input_ids]
+            assert all(self.parallel_sample_num == sampling_params.get("n", 1) for sampling_params in self.sampling_params), (
+                "The parallel_sample_num should be the same for all samples in sample params.")
 
+        if self.parallel_sample_num > 1 and self.is_single:
+            self.is_single = False
+            if self.text is not None:
+                self.text = [self.text]
+            if self.input_ids is not None:
+                self.input_ids = [self.input_ids]
+
+        # Fill in default arguments
         if self.is_single:
             if self.sampling_params is None:
                 self.sampling_params = {}
@@ -114,9 +112,8 @@ class GenerateReqInput:
            if self.parallel_sample_num == 1:
                num = self.batch_size
            else:
-                # FIXME support cascade inference
-                # first bs samples are used for caching the prefix for parallel sampling
-                num = self.batch_size + self.parallel_sample_num * self.batch_size
+                # Expand parallel_sample_num
+                num = self.batch_size * self.parallel_sample_num
 
            if self.image_data is None:
                self.image_data = [None] * num
@@ -129,14 +126,11 @@ class GenerateReqInput:
                self.sampling_params = [{}] * num
            elif not isinstance(self.sampling_params, list):
                self.sampling_params = [self.sampling_params] * num
-            else:
-                assert self.parallel_sample_num == 1
 
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(num)]
            else:
                assert isinstance(self.rid, list), "The rid should be a list."
-                assert self.parallel_sample_num == 1
 
            if self.return_logprob is None:
                self.return_logprob = [False] * num
@@ -159,6 +153,26 @@ class GenerateReqInput:
            else:
                assert self.parallel_sample_num == 1
 
+    def regenerate_rid(self):
+        self.rid = uuid.uuid4().hex
+        return self.rid
+
+    def __getitem__(self, i):
+        return GenerateReqInput(
+            text=self.text[i] if self.text is not None else None,
+            input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            image_data=self.image_data[i],
+            sampling_params=self.sampling_params[i],
+            rid=self.rid[i],
+            return_logprob=self.return_logprob[i],
+            logprob_start_len=self.logprob_start_len[i],
+            top_logprobs_num=self.top_logprobs_num[i],
+            return_text_in_logprobs=self.return_text_in_logprobs,
+            stream=self.stream,
+            modalities=self.modalities[i] if self.modalities else None,
+            lora_path=self.lora_path[i] if self.lora_path is not None else None,
+        )
+
 
 @dataclass
 class TokenizedGenerateReqInput:
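
A minimal usage sketch of the renamed normalization entry point and the new per-request indexing, assuming sglang 0.3.5 is installed; the prompts and sampling parameters below are illustrative, not taken from the package:

    # Sketch only: field and method names come from the diff above.
    from sglang.srt.managers.io_struct import GenerateReqInput

    req = GenerateReqInput(
        text=["What is the capital of France?", "Write a haiku about GPUs."],
        sampling_params={"max_new_tokens": 32},  # broadcast to every request
    )
    req.normalize_batch_and_arguments()  # replaces the old post_init()

    assert req.is_single is False and req.batch_size == 2
    # After normalization the per-request fields are lists, so the new
    # __getitem__ can slice out a single-request view.
    first = req[0]
    print(first.rid, first.sampling_params)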
@@ -196,85 +210,61 @@ class EmbeddingReqInput:
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
 
-    def post_init(self):
+    def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")
 
+        # Derive the batch size
         if self.text is not None:
-            self.is_single = isinstance(self.text, str)
+            if isinstance(self.text, str):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.is_single = False
+                self.batch_size = len(self.text)
         else:
-            self.is_single = isinstance(self.input_ids[0], int)
+            if isinstance(self.input_ids[0], int):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.is_single = False
+                self.batch_size = len(self.input_ids)
 
+        # Fill in default arguments
         if self.is_single:
             if self.rid is None:
                 self.rid = uuid.uuid4().hex
             if self.sampling_params is None:
                 self.sampling_params = {}
-            self.sampling_params["max_new_tokens"] = 1
+            self.sampling_params["max_new_tokens"] = 0
         else:
-            # support select operation
-            self.batch_size = (
-                len(self.text) if self.text is not None else len(self.input_ids)
-            )
             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
             else:
-                if not isinstance(self.rid, list):
-                    raise ValueError("The rid should be a list.")
+                assert isinstance(self.rid, list), "The rid should be a list."
+
             if self.sampling_params is None:
                 self.sampling_params = [{}] * self.batch_size
             for i in range(self.batch_size):
-                self.sampling_params[i]["max_new_tokens"] = 1
+                self.sampling_params[i]["max_new_tokens"] = 0
 
+    def regenerate_rid(self):
+        self.rid = uuid.uuid4().hex
+        return self.rid
 
-@dataclass
-class TokenizedEmbeddingReqInput:
-    # The request id
-    rid: str
-    # The input text
-    input_text: str
-    # The input token ids
-    input_ids: List[int]
-    # Dummy sampling params for compatibility
-    sampling_params: SamplingParams
+    def __getitem__(self, i):
+        return EmbeddingReqInput(
+            text=self.text[i] if self.text is not None else None,
+            input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            sampling_params=self.sampling_params[i],
+            rid=self.rid[i],
+        )
 
 
 @dataclass
-class RewardReqInput:
-    # The input prompt in the chat format. It can be a single prompt or a batch of prompts.
-    conv: Union[List[List[Dict]], List[Dict]]
-    # The request id.
-    rid: Optional[Union[List[str], str]] = None
-    # Dummy sampling params for compatibility
-    sampling_params: Union[List[Dict], Dict] = None
-
-    def post_init(self):
-        self.is_single = isinstance(self.conv[0], dict)
-
-        if self.is_single:
-            if self.rid is None:
-                self.rid = uuid.uuid4().hex
-            if self.sampling_params is None:
-                self.sampling_params = {}
-            self.sampling_params["max_new_tokens"] = 1
-        else:
-            # support select operation
-            self.batch_size = len(self.conv)
-            if self.rid is None:
-                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
-            else:
-                if not isinstance(self.rid, list):
-                    raise ValueError("The rid should be a list.")
-            if self.sampling_params is None:
-                self.sampling_params = [{}] * self.batch_size
-            for i in range(self.batch_size):
-                self.sampling_params[i]["max_new_tokens"] = 1
-
-
-@dataclass
-class TokenizedRewardReqInput:
+class TokenizedEmbeddingReqInput:
     # The request id
     rid: str
     # The input text
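
The embedding path gets the same normalization treatment; a small sketch under the same assumptions (sglang 0.3.5 installed, inputs illustrative):

    from sglang.srt.managers.io_struct import EmbeddingReqInput

    emb = EmbeddingReqInput(text=["hello world", "bonjour le monde"])
    emb.normalize_batch_and_arguments()

    assert emb.batch_size == 2
    # Dummy sampling params for embedding requests now ask for 0 new tokens instead of 1.
    print(emb.sampling_params[0]["max_new_tokens"])  # 0
    print(emb[1].text)                               # "bonjour le monde"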
@@ -294,6 +284,8 @@ class BatchTokenIDOut:
     decoded_texts: List[str]
     decode_ids: List[int]
     read_offsets: List[int]
+    # Only used when `--skip-tokenizer-init`
+    output_ids: Optional[List[int]]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
sglang/srt/managers/schedule_batch.py
@@ -37,8 +37,7 @@ import torch
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained import RegexGuide
-from sglang.srt.constrained.jump_forward import JumpForwardMap
+from sglang.srt.constrained.grammar import Grammar
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -212,9 +211,6 @@ class Req:
         # this does not include the jump forward tokens.
         self.completion_tokens_wo_jump_forward = 0
 
-        # The number of cached tokens, that were already cached in the KV store
-        self.cached_tokens = 0
-
         # For vision inputs
         self.image_inputs: Optional[ImageInputs] = None
 
@@ -222,7 +218,10 @@ class Req:
         self.prefix_indices = []
         self.extend_input_len = 0
         self.last_node = None
-        self.is_inflight_req = 0
+        self.is_being_chunked = 0
+
+        # For retraction
+        self.is_retracted = False
 
         # Logprobs (arguments)
         self.return_logprob = False
@@ -243,13 +242,14 @@ class Req:
         # The relative logprob_start_len in an extend batch
         self.extend_logprob_start_len = 0
 
-        # Embedding
+        # Embedding (return values)
         self.embedding = None
 
         # Constrained decoding
-        self.regex_fsm: RegexGuide = None
-        self.regex_fsm_state: int = 0
-        self.jump_forward_map: JumpForwardMap = None
+        self.grammar: Optional[Grammar] = None
+
+        # The number of cached tokens, that were already cached in the KV cache
+        self.cached_tokens = 0
 
         # For Qwen2-VL
         self.mrope_position_delta = []  # use mutable object
@@ -359,6 +359,8 @@ class Req:
             return
 
     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
+        assert self.grammar is not None and self.tokenizer is not None
+
         if self.origin_input_text is None:
             # Recovering text can only use unpadded ids
             self.origin_input_text = self.tokenizer.decode(
@@ -398,7 +400,8 @@ class Req:
                 self.surr_offset = self.read_offset - i
                 break
 
-        self.regex_fsm_state = next_state
+        # update the inner state of the grammar
+        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)
 
         if self.return_logprob:
             # For fast-forward part's logprobs
@@ -468,8 +471,8 @@ class ScheduleBatch:
     # Stream
     has_stream: bool = False
 
-    # Has regex
-    has_regex: bool = False
+    # Has grammar
+    has_grammar: bool = False
 
     # device
     device: str = "cuda"
@@ -477,7 +480,7 @@ class ScheduleBatch:
     @classmethod
     def init_new(
         cls,
-        reqs,
+        reqs: List[Req],
         req_to_token_pool,
         token_to_kv_pool,
         tree_cache,
@@ -491,7 +494,7 @@ class ScheduleBatch:
             model_config=model_config,
             return_logprob=any(req.return_logprob for req in reqs),
             has_stream=any(req.stream for req in reqs),
-            has_regex=any(req.regex_fsm for req in reqs),
+            has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
         )
 
@@ -561,7 +564,7 @@ class ScheduleBatch:
                 seq_lens[i] -= encoder_len
 
             if len(req.prefix_indices) < encoder_len:
-                # NOTE: the encoder part should considered as a whole
+                # NOTE: the encoder part should be considered as a whole
                 assert len(req.prefix_indices) == 0
                 input_ids[i] = input_ids[i][encoder_len:]
                 encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
@@ -648,6 +651,7 @@ class ScheduleBatch:
 
             req.extend_logprob_start_len = extend_logprob_start_len
            pt += req.extend_input_len
+            req.is_retracted = False
 
        # Set fields
        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
@@ -780,6 +784,7 @@ class ScheduleBatch:
             req.prefix_indices = []
             req.last_node = None
             req.extend_input_len = 0
+            req.is_retracted = True
 
             # For incremental logprobs
             req.last_update_decode_tokens = 0
@@ -803,26 +808,10 @@ class ScheduleBatch:
         keep_indices = set(i for i in range(len(self.reqs)))
 
         for i, req in enumerate(self.reqs):
-            if req.jump_forward_map is not None:
-                jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
-                    req.regex_fsm_state
-                )
-                if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
-                    suffix_bytes = []
-                    continuation_range = range(0x80, 0xC0)
-                    cur_state = req.regex_fsm_state
-                    while (
-                        len(jump_forward_bytes)
-                        and jump_forward_bytes[0][0] in continuation_range
-                    ):
-                        # continuation bytes
-                        byte_edge = jump_forward_bytes.pop(0)
-                        suffix_bytes.append(byte_edge[0])
-                        cur_state = byte_edge[1]
-
-                    suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
-                    suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)
-
+            if req.grammar is not None:
+                jump_helper = req.grammar.try_jump(req.tokenizer)
+                if jump_helper.can_jump():
+                    suffix_ids = jump_helper.suffix_ids
                     # Current ids, for cache and revert
                     cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                     cur_output_ids = req.output_ids
@@ -836,10 +825,8 @@ class ScheduleBatch:
                    (
                        jump_forward_str,
                        next_state,
-                    ) = req.jump_forward_map.jump_forward_symbol(cur_state)
+                    ) = req.grammar.jump_forward_str_state(jump_helper)
 
-                    # Make the incrementally decoded text part of jump_forward_str
-                    # so that the UTF-8 will not corrupt
                    jump_forward_str = new_text + jump_forward_str
                    if not req.jump_forward_and_retokenize(
                        jump_forward_str, next_state
@@ -906,7 +893,7 @@ class ScheduleBatch:
 
     def filter_batch(
         self,
-        current_inflight_req: Optional[Req] = None,
+        being_chunked_req: Optional[Req] = None,
         keep_indices: Optional[List[int]] = None,
     ):
         if keep_indices is None:
@@ -914,7 +901,7 @@ class ScheduleBatch:
                 i
                 for i in range(len(self.reqs))
                 if not self.reqs[i].finished()
-                and self.reqs[i] is not current_inflight_req
+                and self.reqs[i] is not being_chunked_req
             ]
 
         if keep_indices is None or len(keep_indices) == 0:
@@ -946,7 +933,7 @@ class ScheduleBatch:
            self.top_logprobs_nums = None
 
         self.has_stream = any(req.stream for req in self.reqs)
-        self.has_regex = any(req.regex_fsm for req in self.reqs)
+        self.has_grammar = any(req.grammar for req in self.reqs)
 
         self.sampling_info.filter_batch(keep_indices, new_indices)
 
@@ -979,7 +966,7 @@ class ScheduleBatch:
 
         self.return_logprob = self.return_logprob or other.return_logprob
         self.has_stream = self.has_stream or other.has_stream
-        self.has_regex = self.has_regex or other.has_regex
+        self.has_grammar = self.has_grammar or other.has_grammar
 
     def get_model_worker_batch(self):
         if self.forward_mode.is_decode():
@@ -989,13 +976,10 @@ class ScheduleBatch:
            extend_prefix_lens = self.prefix_lens
            extend_logprob_start_lens = self.extend_logprob_start_lens
 
-        if self.has_regex:
-            self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
-            self.sampling_info.regex_fsm_states = [
-                req.regex_fsm_state for req in self.reqs
-            ]
+        if self.has_grammar:
+            self.sampling_info.grammars = [req.grammar for req in self.reqs]
         else:
-            self.sampling_info.regex_fsms = None
+            self.sampling_info.grammars = None
 
         global bid
         bid += 1
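
The grammar machinery above is driven by constrained decoding requests. A hedged end-to-end sketch of what feeds this code path, using regex-constrained generation through the sglang frontend (assumes a server was already launched with python -m sglang.launch_server on the default port; the prompt and regex are illustrative):

    import sglang as sgl

    @sgl.function
    def year_only(s, question):
        s += question + " Answer with a year: "
        # The regex constraint attaches a grammar to the request, which the
        # jump-forward logic above can then fast-forward through.
        s += sgl.gen("year", regex=r"(19|20)\d{2}", max_tokens=8)

    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
    state = year_only.run(question="When was the transformer paper published?")
    print(state["year"])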
sglang/srt/managers/schedule_policy.py
@@ -30,7 +30,9 @@ from sglang.srt.mem_cache.radix_cache import TreeNode
 # This can prevent the server from being too conservative.
 # Note that this only clips the estimation in the scheduler but does not change the stop
 # condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
-CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))
+CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
+    os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
+)
 
 
 class SchedulePolicy:
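
The renamed clip value is still read from the environment at module import time, so it has to be set before the server and its scheduler processes start. A small sketch, with 2048 as an arbitrary example value:

    import os

    # Must be set before sglang's scheduler modules are imported or spawned.
    os.environ["SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION"] = "2048"  # default: 4096

    from sglang.srt.managers import schedule_policy
    print(schedule_policy.CLIP_MAX_NEW_TOKENS_ESTIMATION)  # 2048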
@@ -43,9 +45,15 @@ class SchedulePolicy:
         self.tree_cache = tree_cache
 
     def calc_priority(self, waiting_queue: List[Req]):
+        if len(waiting_queue) > 128 and self.policy == "lpm":
+            # Turn off the expensive prefix matching and sorting when the #queue is large.
+            policy = "fcfs"
+        else:
+            policy = self.policy
+
         # Compute matched prefix length
         prefix_computed = False
-        if self.policy == "lpm" or self.policy == "dfs-weight":
+        if policy == "lpm" or policy == "dfs-weight":
             for r in waiting_queue:
                 # NOTE: the prefix_indices must always be aligned with last_node
                 r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
@@ -54,18 +62,18 @@ class SchedulePolicy:
 
             prefix_computed = True
 
-        if self.policy == "lpm":
+        if policy == "lpm":
             # Longest Prefix Match
             waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
-        elif self.policy == "fcfs":
+        elif policy == "fcfs":
             # first come first serve
             pass
-        elif self.policy == "lof":
+        elif policy == "lof":
             # longest output first
             waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-        elif self.policy == "random":
+        elif policy == "random":
             random.shuffle(waiting_queue)
-        elif self.policy == "dfs-weight":
+        elif policy == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in waiting_queue:
                 last_node_to_reqs[req.last_node].append(req)
@@ -83,7 +91,7 @@ class SchedulePolicy:
                 waiting_queue,
             )
         else:
-            raise ValueError(f"Unknown schedule_policy: {self.policy}")
+            raise ValueError(f"Unknown schedule_policy: {policy=}")
 
         return prefix_computed
 
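A standalone sketch of the new queue-length fallback (the 128 threshold comes from the diff above; the helper function is illustrative, not part of sglang):

    def effective_policy(configured: str, queue_len: int) -> str:
        # Mirrors calc_priority: long queues skip the expensive prefix
        # matching and sorting that "lpm" would otherwise perform.
        if queue_len > 128 and configured == "lpm":
            return "fcfs"
        return configured

    assert effective_policy("lpm", 64) == "lpm"
    assert effective_policy("lpm", 500) == "fcfs"
    assert effective_policy("dfs-weight", 500) == "dfs-weight"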
@@ -146,7 +154,7 @@ class PrefillAdder:
            [
                min(
                    (r.sampling_params.max_new_tokens - len(r.output_ids)),
-                    CLIP_MAX_NEW_TOKENS,
+                    CLIP_MAX_NEW_TOKENS_ESTIMATION,
                )
                * self.new_token_ratio
                for r in running_batch.reqs
@@ -186,7 +194,7 @@ class PrefillAdder:
            len(req.prefix_indices),
            req.extend_input_len,
            (
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION)
                if not truncated
                else 0
            ),
@@ -258,7 +266,7 @@ class PrefillAdder:
            self._prefill_one_req(
                0,
                req.extend_input_len,
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
            )
        else:
            # Chunked prefill
@@ -276,7 +284,7 @@ class PrefillAdder:
            return self.add_one_req_ignore_eos(req)
 
        total_tokens = req.extend_input_len + min(
-            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
+            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
        )
        input_tokens = req.extend_input_len
        prefix_len = len(req.prefix_indices)
@@ -302,7 +310,10 @@ class PrefillAdder:
            self._prefill_one_req(
                prefix_len,
                input_tokens,
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+                min(
+                    req.sampling_params.max_new_tokens,
+                    CLIP_MAX_NEW_TOKENS_ESTIMATION,
+                ),
            )
        else:
            # Chunked prefill
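
A standalone sketch of how the renamed constant caps the scheduler's per-request token estimate (the numbers are illustrative; the real logic is the PrefillAdder code above):

    CLIP_MAX_NEW_TOKENS_ESTIMATION = 4096  # default from the environment variable

    def estimated_total_tokens(extend_input_len: int, max_new_tokens: int) -> int:
        # Mirrors: total_tokens = req.extend_input_len
        #          + min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION)
        return extend_input_len + min(max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION)

    print(estimated_total_tokens(1024, 32768))  # 5120, not 33792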