sglang 0.3.4.post1__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. sglang/api.py +1 -1
  2. sglang/bench_latency.py +3 -3
  3. sglang/bench_server_latency.py +2 -3
  4. sglang/bench_serving.py +92 -0
  5. sglang/global_config.py +9 -3
  6. sglang/lang/chat_template.py +50 -25
  7. sglang/lang/interpreter.py +9 -1
  8. sglang/lang/ir.py +11 -2
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/configs/model_config.py +76 -15
  11. sglang/srt/constrained/__init__.py +18 -0
  12. sglang/srt/constrained/bnf_cache.py +61 -0
  13. sglang/srt/constrained/fsm_cache.py +10 -3
  14. sglang/srt/constrained/grammar.py +190 -0
  15. sglang/srt/hf_transformers_utils.py +20 -5
  16. sglang/srt/layers/attention/flashinfer_backend.py +5 -5
  17. sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
  18. sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
  19. sglang/srt/layers/fused_moe/fused_moe.py +4 -3
  20. sglang/srt/layers/fused_moe/layer.py +28 -0
  21. sglang/srt/layers/logits_processor.py +5 -5
  22. sglang/srt/layers/quantization/base_config.py +16 -1
  23. sglang/srt/layers/rotary_embedding.py +15 -48
  24. sglang/srt/layers/sampler.py +51 -39
  25. sglang/srt/layers/vocab_parallel_embedding.py +486 -0
  26. sglang/srt/managers/data_parallel_controller.py +8 -7
  27. sglang/srt/managers/detokenizer_manager.py +11 -9
  28. sglang/srt/managers/image_processor.py +4 -3
  29. sglang/srt/managers/io_struct.py +80 -78
  30. sglang/srt/managers/schedule_batch.py +46 -52
  31. sglang/srt/managers/schedule_policy.py +24 -13
  32. sglang/srt/managers/scheduler.py +145 -82
  33. sglang/srt/managers/tokenizer_manager.py +236 -334
  34. sglang/srt/managers/tp_worker.py +5 -5
  35. sglang/srt/managers/tp_worker_overlap_thread.py +58 -21
  36. sglang/srt/mem_cache/flush_cache.py +1 -1
  37. sglang/srt/mem_cache/memory_pool.py +10 -3
  38. sglang/srt/model_executor/cuda_graph_runner.py +34 -23
  39. sglang/srt/model_executor/forward_batch_info.py +6 -9
  40. sglang/srt/model_executor/model_runner.py +10 -19
  41. sglang/srt/models/baichuan.py +4 -4
  42. sglang/srt/models/chatglm.py +4 -4
  43. sglang/srt/models/commandr.py +1 -1
  44. sglang/srt/models/dbrx.py +5 -5
  45. sglang/srt/models/deepseek.py +4 -4
  46. sglang/srt/models/deepseek_v2.py +4 -4
  47. sglang/srt/models/exaone.py +4 -4
  48. sglang/srt/models/gemma.py +1 -1
  49. sglang/srt/models/gemma2.py +1 -1
  50. sglang/srt/models/gpt2.py +287 -0
  51. sglang/srt/models/gpt_bigcode.py +1 -1
  52. sglang/srt/models/grok.py +4 -4
  53. sglang/srt/models/internlm2.py +4 -4
  54. sglang/srt/models/llama.py +15 -7
  55. sglang/srt/models/llama_embedding.py +2 -10
  56. sglang/srt/models/llama_reward.py +5 -0
  57. sglang/srt/models/minicpm.py +4 -4
  58. sglang/srt/models/minicpm3.py +4 -4
  59. sglang/srt/models/mixtral.py +7 -5
  60. sglang/srt/models/mixtral_quant.py +4 -4
  61. sglang/srt/models/mllama.py +5 -5
  62. sglang/srt/models/olmo.py +4 -4
  63. sglang/srt/models/olmoe.py +4 -4
  64. sglang/srt/models/qwen.py +4 -4
  65. sglang/srt/models/qwen2.py +4 -4
  66. sglang/srt/models/qwen2_moe.py +4 -4
  67. sglang/srt/models/qwen2_vl.py +4 -8
  68. sglang/srt/models/stablelm.py +4 -4
  69. sglang/srt/models/torch_native_llama.py +4 -4
  70. sglang/srt/models/xverse.py +4 -4
  71. sglang/srt/models/xverse_moe.py +4 -4
  72. sglang/srt/openai_api/adapter.py +52 -66
  73. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
  74. sglang/srt/sampling/sampling_batch_info.py +7 -13
  75. sglang/srt/sampling/sampling_params.py +5 -7
  76. sglang/srt/server.py +41 -33
  77. sglang/srt/server_args.py +34 -5
  78. sglang/srt/utils.py +40 -56
  79. sglang/test/run_eval.py +2 -0
  80. sglang/test/runners.py +2 -1
  81. sglang/test/srt/sampling/penaltylib/utils.py +1 -0
  82. sglang/test/test_utils.py +151 -6
  83. sglang/utils.py +62 -1
  84. sglang/version.py +1 -1
  85. sglang-0.3.5.dist-info/METADATA +344 -0
  86. sglang-0.3.5.dist-info/RECORD +152 -0
  87. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
  88. sglang-0.3.4.post1.dist-info/METADATA +0 -900
  89. sglang-0.3.4.post1.dist-info/RECORD +0 -148
  90. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
  91. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
@@ -15,22 +15,21 @@ limitations under the License.

  """A scheduler that manages a tensor parallel GPU worker."""

- import json
  import logging
  import os
+ import threading
  import time
  import warnings
  from collections import deque
  from types import SimpleNamespace
- from typing import List, Optional, Union
+ from typing import List, Optional

  import torch
  import zmq

  from sglang.global_config import global_config
  from sglang.srt.configs.model_config import ModelConfig
- from sglang.srt.constrained.fsm_cache import FSMCache
- from sglang.srt.constrained.jump_forward import JumpForwardCache
+ from sglang.srt.constrained.grammar import GrammarCache
  from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.managers.io_struct import (
@@ -38,10 +37,11 @@ from sglang.srt.managers.io_struct import (
  BatchEmbeddingOut,
  BatchTokenIDOut,
  FlushCacheReq,
+ GetMemPoolSizeReq,
+ GetMemPoolSizeReqOutput,
  ProfileReq,
  TokenizedEmbeddingReqInput,
  TokenizedGenerateReqInput,
- TokenizedRewardReqInput,
  UpdateWeightReqInput,
  UpdateWeightReqOutput,
  )
@@ -66,10 +66,8 @@ from sglang.srt.server_args import PortArgs, ServerArgs
  from sglang.srt.utils import (
  broadcast_pyobj,
  configure_logger,
- is_generation_model,
- is_multimodal_model,
+ get_zmq_socket,
  kill_parent_process,
- pytorch_profile,
  set_random_seed,
  suppress_other_loggers,
  )
@@ -77,6 +75,7 @@ from sglang.utils import get_exception_traceback

  logger = logging.getLogger(__name__)

+
  # Crash on warning if we are running CI tests
  crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"

@@ -104,16 +103,26 @@ class Scheduler:
  self.lora_paths = server_args.lora_paths
  self.max_loras_per_batch = server_args.max_loras_per_batch
  self.enable_overlap = server_args.enable_overlap_schedule
+ self.skip_tokenizer_init = server_args.skip_tokenizer_init

  # Init inter-process communication
  context = zmq.Context(2)

  if self.tp_rank == 0:
- self.recv_from_tokenizer = context.socket(zmq.PULL)
- self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")
+ self.recv_from_tokenizer = get_zmq_socket(
+ context, zmq.PULL, port_args.scheduler_input_ipc_name
+ )

- self.send_to_detokenizer = context.socket(zmq.PUSH)
- self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")
+ if server_args.skip_tokenizer_init:
+ # Directly send to the tokenizer/api
+ self.send_to_detokenizer = get_zmq_socket(
+ context, zmq.PUSH, port_args.tokenizer_ipc_name
+ )
+ else:
+ # Send to the detokenizer
+ self.send_to_detokenizer = get_zmq_socket(
+ context, zmq.PUSH, port_args.detokenizer_ipc_name
+ )
  else:
  self.recv_from_tokenizer = None
  self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
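The raw zmq socket construction on rank 0 is replaced by a get_zmq_socket helper imported from sglang.srt.utils. The helper's implementation is not part of this hunk; the sketch below only illustrates the bind-for-PULL / connect-for-PUSH pattern visible in the removed lines, under a hypothetical name so it is not mistaken for the real function.

# Hypothetical sketch of a get_zmq_socket-style helper; the real implementation
# lives in sglang/srt/utils.py and may set additional socket options.
import zmq

def make_ipc_socket(context: zmq.Context, socket_type: int, endpoint: str) -> zmq.Socket:
    sock = context.socket(socket_type)
    if socket_type == zmq.PULL:
        sock.bind(f"ipc://{endpoint}")      # the receiving side owns the endpoint
    elif socket_type == zmq.PUSH:
        sock.connect(f"ipc://{endpoint}")   # the sending side attaches to it
    else:
        raise ValueError(f"unsupported socket type: {socket_type}")
    return sock

When skip_tokenizer_init is set, the scheduler now pushes results straight to the tokenizer/API endpoint instead of the detokenizer, which is why the PUSH target is chosen per configuration in the hunk above.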
@@ -121,15 +130,17 @@ class Scheduler:
  # Init tokenizer
  self.model_config = ModelConfig(
  server_args.model_path,
- server_args.trust_remote_code,
+ trust_remote_code=server_args.trust_remote_code,
  context_length=server_args.context_length,
- model_override_args=json.loads(server_args.json_model_override_args),
+ model_override_args=server_args.json_model_override_args,
+ is_embedding=server_args.is_embedding,
  )
+ self.is_generation = self.model_config.is_generation

  if server_args.skip_tokenizer_init:
  self.tokenizer = self.processor = None
  else:
- if is_multimodal_model(self.model_config.hf_config.architectures):
+ if self.model_config.is_multimodal:
  self.processor = get_processor(
  server_args.tokenizer_path,
  tokenizer_mode=server_args.tokenizer_mode,
@@ -142,9 +153,6 @@ class Scheduler:
  tokenizer_mode=server_args.tokenizer_mode,
  trust_remote_code=server_args.trust_remote_code,
  )
- self.is_generation = is_generation_model(
- self.model_config.hf_config.architectures, self.server_args.is_embedding
- )

  # Launch a tensor parallel worker
  if self.enable_overlap:
@@ -211,44 +219,62 @@ class Scheduler:
  self.waiting_queue: List[Req] = []
  self.running_batch: Optional[ScheduleBatch] = None
  self.cur_batch: Optional[ScheduleBatch] = None
- self.decode_forward_ct = 0
- self.stream_interval = server_args.stream_interval
+ self.forward_ct = 0
+ self.forward_ct_decode = 0
  self.num_generated_tokens = 0
  self.last_stats_tic = time.time()
+ self.stream_interval = server_args.stream_interval

  # Init chunked prefill
  self.chunked_prefill_size = server_args.chunked_prefill_size
- self.current_inflight_req = None
+ self.being_chunked_req = None
  self.is_mixed_chunk = (
  self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
  )

  # Init the FSM cache for constrained generation
+ self.grammar_cache = None
+
  if not server_args.skip_tokenizer_init:
- self.regex_fsm_cache = FSMCache(
+ self.grammar_cache = GrammarCache(
  server_args.tokenizer_path,
  {
  "tokenizer_mode": server_args.tokenizer_mode,
  "trust_remote_code": server_args.trust_remote_code,
  },
  skip_tokenizer_init=server_args.skip_tokenizer_init,
- constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
+ whitespace_patterns=server_args.constrained_json_whitespace_pattern,
+ backend=server_args.grammar_backend,
+ allow_jump=not server_args.disable_regex_jump_forward,
  )
- self.jump_forward_cache = JumpForwardCache()

  # Init new token estimation
  assert (
  server_args.schedule_conservativeness >= 0
  ), "Invalid schedule_conservativeness"
- self.min_new_token_ratio = min(
- global_config.base_min_new_token_ratio
+
+ self.init_new_token_ratio = min(
+ global_config.default_init_new_token_ratio
  * server_args.schedule_conservativeness,
  1.0,
  )
- self.new_token_ratio = self.min_new_token_ratio
- self.new_token_ratio_decay = global_config.new_token_ratio_decay
+ self.min_new_token_ratio = min(
+ self.init_new_token_ratio
+ * global_config.default_min_new_token_ratio_factor,
+ 1.0,
+ )
+ self.new_token_ratio_decay = (
+ self.init_new_token_ratio - self.min_new_token_ratio
+ ) / global_config.default_new_token_ratio_decay_steps
+ self.new_token_ratio = self.init_new_token_ratio
+
  self.batch_is_full = False

+ # Init watchdog thread
+ self.watchdog_timeout = server_args.watchdog_timeout
+ t = threading.Thread(target=self.watchdog_thread, daemon=True)
+ t.start()
+
  # Init profiler
  if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
  self.profiler = None
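The retraction heuristic now derives three values instead of two: an initial ratio, a floor computed from a factor, and a per-step decrement, so new_token_ratio decays linearly from the initial value to the floor over a fixed number of steps. The arithmetic below mirrors the assignments in this hunk; the numeric constants are made-up stand-ins for the global_config defaults, which live in sglang/global_config.py.

# Stand-in values; see sglang/global_config.py for the real defaults.
default_init_new_token_ratio = 0.7
default_min_new_token_ratio_factor = 0.14
default_new_token_ratio_decay_steps = 600
schedule_conservativeness = 1.0  # ServerArgs.schedule_conservativeness

init_new_token_ratio = min(default_init_new_token_ratio * schedule_conservativeness, 1.0)
min_new_token_ratio = min(init_new_token_ratio * default_min_new_token_ratio_factor, 1.0)
new_token_ratio_decay = (init_new_token_ratio - min_new_token_ratio) / default_new_token_ratio_decay_steps

def ratio_after(steps: int) -> float:
    # Linear schedule implied by the constants above: subtract the decay once
    # per scheduling step, never dropping below the floor.
    return max(init_new_token_ratio - steps * new_token_ratio_decay, min_new_token_ratio)

print(ratio_after(0), ratio_after(300), ratio_after(600))  # 0.7 -> ~0.4 -> ~0.098

When the scheduler goes idle, the event loops further down reset new_token_ratio back to init_new_token_ratio (the two hunks that replace global_config.init_new_token_ratio).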
@@ -266,6 +292,23 @@ class Scheduler:
  with_stack=True,
  )

+ def watchdog_thread(self):
+ self.watchdog_last_forward_ct = 0
+ self.watchdog_last_time = time.time()
+
+ while True:
+ if self.cur_batch is not None:
+ if self.watchdog_last_forward_ct == self.forward_ct:
+ if time.time() > self.watchdog_last_time + self.watchdog_timeout:
+ logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
+ break
+ else:
+ self.watchdog_last_forward_ct = self.forward_ct
+ self.watchdog_last_time = time.time()
+ time.sleep(self.watchdog_timeout / 2)
+
+ kill_parent_process()
+
  @torch.inference_mode()
  def event_loop_normal(self):
  """A normal blocking scheduler loop."""
@@ -276,6 +319,7 @@ class Scheduler:
  self.process_input_requests(recv_reqs)

  batch = self.get_next_batch_to_run()
+ self.cur_batch = batch

  if batch:
  result = self.run_batch(batch)
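self.cur_batch is recorded here so the new watchdog thread can tell whether a batch is actually in flight; together with forward_ct (incremented in run_batch, see the hunk further down) it drives the stall check. Because the diff view flattens indentation, the watchdog loop is easier to read as the re-indented standalone sketch below; the class scaffolding and the print are illustrative stand-ins, while the real code calls kill_parent_process() from sglang.srt.utils.

# Re-indented illustration of the watchdog_thread loop shown in this diff.
import threading
import time

class WatchdogSketch:
    def __init__(self, watchdog_timeout: float = 300.0):
        self.watchdog_timeout = watchdog_timeout
        self.forward_ct = 0     # bumped by run_batch() on every forward pass
        self.cur_batch = None   # set by the event loop when a batch is scheduled
        threading.Thread(target=self.watchdog_thread, daemon=True).start()

    def watchdog_thread(self):
        last_forward_ct = 0
        last_time = time.time()
        while True:
            if self.cur_batch is not None:
                if last_forward_ct == self.forward_ct:
                    # No forward progress since the last check; give up after the timeout.
                    if time.time() > last_time + self.watchdog_timeout:
                        print(f"Watchdog timeout ({self.watchdog_timeout=})")
                        break
                else:
                    last_forward_ct = self.forward_ct
                    last_time = time.time()
            time.sleep(self.watchdog_timeout / 2)
        # The real scheduler terminates the whole server here (kill_parent_process()).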
@@ -293,7 +337,7 @@ class Scheduler:
  self.process_batch_result(batch, result)
  else:
  self.check_memory()
- self.new_token_ratio = global_config.init_new_token_ratio
+ self.new_token_ratio = self.init_new_token_ratio

  self.last_batch = batch

@@ -320,7 +364,7 @@ class Scheduler:
  self.process_batch_result(tmp_batch, tmp_result)
  elif batch is None:
  self.check_memory()
- self.new_token_ratio = global_config.init_new_token_ratio
+ self.new_token_ratio = self.init_new_token_ratio

  self.last_batch = batch

@@ -345,9 +389,7 @@ class Scheduler:
  for recv_req in recv_reqs:
  if isinstance(recv_req, TokenizedGenerateReqInput):
  self.handle_generate_request(recv_req)
- elif isinstance(
- recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
- ):
+ elif isinstance(recv_req, TokenizedEmbeddingReqInput):
  self.handle_embedding_request(recv_req)
  elif isinstance(recv_req, FlushCacheReq):
  self.flush_cache()
@@ -363,6 +405,10 @@ class Scheduler:
  self.start_profile()
  else:
  self.stop_profile()
+ elif isinstance(recv_req, GetMemPoolSizeReq):
+ self.send_to_detokenizer.send_pyobj(
+ GetMemPoolSizeReqOutput(self.max_total_num_tokens)
+ )
  else:
  raise ValueError(f"Invalid request: {recv_req}")

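GetMemPoolSizeReq lets the tokenizer/API side ask the scheduler for max_total_num_tokens over the same socket used for detokenizer traffic. The shapes below are assumptions for illustration, with a single size field inferred from the reply constructed above; the real definitions are in sglang/srt/managers/io_struct.py.

# Assumed request/response shapes; see sglang/srt/managers/io_struct.py for the
# real GetMemPoolSizeReq / GetMemPoolSizeReqOutput definitions.
from dataclasses import dataclass

@dataclass
class GetMemPoolSizeReqSketch:
    pass  # no payload is needed to ask for the pool size

@dataclass
class GetMemPoolSizeReqOutputSketch:
    size: int  # filled with scheduler.max_total_num_tokens, as in the branch above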
@@ -397,26 +443,24 @@ class Scheduler:
  # By default, only return the logprobs for output tokens
  req.logprob_start_len = len(recv_req.input_ids) - 1

- # Init regex FSM
+ # Init regex FSM or BNF
  if (
  req.sampling_params.json_schema is not None
  or req.sampling_params.regex is not None
  ):
+ assert self.grammar_cache is not None
  if req.sampling_params.json_schema is not None:
- req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
- ("json", req.sampling_params.json_schema)
+ req.grammar = self.grammar_cache.query(
+ ("json", req.sampling_params.json_schema),
+ self.model_config.vocab_size,
  )
  elif req.sampling_params.regex is not None:
- req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
- ("regex", req.sampling_params.regex)
- )
- if not self.disable_regex_jump_forward:
- req.jump_forward_map = self.jump_forward_cache.query(
- computed_regex_string
+ req.grammar = self.grammar_cache.query(
+ ("regex", req.sampling_params.regex), self.model_config.vocab_size
  )

  # Truncate prompts that are too long
- if len(req.origin_input_ids) >= self.max_req_input_len:
+ if len(req.origin_input_ids) > self.max_req_input_len:
  logger.warning(
  "Request length is longer than the KV cache pool size or "
  "the max context length. Truncated!!!"
@@ -436,7 +480,7 @@ class Scheduler:

  def handle_embedding_request(
  self,
- recv_req: Union[TokenizedEmbeddingReqInput, TokenizedRewardReqInput],
+ recv_req: TokenizedEmbeddingReqInput,
  ):
  req = Req(
  recv_req.rid,
@@ -501,13 +545,13 @@ class Scheduler:
  and not self.last_batch.forward_mode.is_decode()
  and not self.last_batch.is_empty()
  ):
- if self.current_inflight_req:
+ if self.being_chunked_req:
  self.last_batch.filter_batch(
- current_inflight_req=self.current_inflight_req
+ being_chunked_req=self.being_chunked_req
  )
- self.tree_cache.cache_unfinished_req(self.current_inflight_req)
+ self.tree_cache.cache_unfinished_req(self.being_chunked_req)
  # Inflight request keeps its rid but will get a new req_pool_idx.
- self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
+ self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
  self.batch_is_full = False
  if not self.last_batch.is_empty():
  if self.running_batch is None:
@@ -538,7 +582,7 @@ class Scheduler:
  # Handle the cases where prefill is not allowed
  if (
  self.batch_is_full or len(self.waiting_queue) == 0
- ) and self.current_inflight_req is None:
+ ) and self.being_chunked_req is None:
  return None

  running_bs = len(self.running_batch.reqs) if self.running_batch else 0
@@ -561,13 +605,11 @@ class Scheduler:
  num_mixed_running,
  )

- has_inflight = self.current_inflight_req is not None
+ has_inflight = self.being_chunked_req is not None
  if has_inflight:
- self.current_inflight_req.init_next_round_input(
- None if prefix_computed else self.tree_cache
- )
- self.current_inflight_req = adder.add_inflight_req(
- self.current_inflight_req
+ self.being_chunked_req.init_next_round_input()
+ self.being_chunked_req = adder.add_inflight_req(
+ self.being_chunked_req
  )

  if self.lora_paths:
@@ -611,11 +653,11 @@ class Scheduler:
  ]

  if adder.new_inflight_req is not None:
- assert self.current_inflight_req is None
- self.current_inflight_req = adder.new_inflight_req
+ assert self.being_chunked_req is None
+ self.being_chunked_req = adder.new_inflight_req

- if self.current_inflight_req:
- self.current_inflight_req.is_inflight_req += 1
+ if self.being_chunked_req:
+ self.being_chunked_req.is_being_chunked += 1

  # Print stats
  if self.tp_rank == 0:
@@ -670,9 +712,11 @@ class Scheduler:

  # Mixed-style chunked prefill
  if self.is_mixed_chunk and self.running_batch is not None:
- self.running_batch.prepare_for_decode(self.enable_overlap)
- new_batch.mix_with_running(self.running_batch)
- new_batch.decoding_reqs = self.running_batch.reqs
+ self.running_batch.filter_batch()
+ if not self.running_batch.is_empty():
+ self.running_batch.prepare_for_decode(self.enable_overlap)
+ new_batch.mix_with_running(self.running_batch)
+ new_batch.decoding_reqs = self.running_batch.reqs
  self.running_batch = None
  else:
  new_batch.decoding_reqs = None
@@ -721,6 +765,8 @@ class Scheduler:

  def run_batch(self, batch: ScheduleBatch):
  """Run a batch."""
+ self.forward_ct += 1
+
  if self.is_generation:
  if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
  model_worker_batch = batch.get_model_worker_batch()
@@ -729,7 +775,7 @@ class Scheduler:
  )
  else:
  logits_output = None
- if self.tokenizer is not None:
+ if self.skip_tokenizer_init:
  next_token_ids = torch.full(
  (batch.batch_size(),), self.tokenizer.eos_token_id
  )
@@ -753,6 +799,7 @@ class Scheduler:
  self.process_batch_result_prefill(batch, result)

  def process_batch_result_prefill(self, batch: ScheduleBatch, result):
+
  if self.is_generation:
  logits_output, next_token_ids, bid = result

@@ -778,9 +825,10 @@ class Scheduler:
  # Check finish conditions
  logprob_pt = 0
  for i, req in enumerate(batch.reqs):
- if req.is_inflight_req > 0:
- req.is_inflight_req -= 1
- else:
+ if req.is_retracted:
+ continue
+
+ if req.is_being_chunked <= 0:
  # Inflight reqs' prefill is not finished
  req.completion_tokens_wo_jump_forward += 1
  req.output_ids.append(next_token_ids[i])
@@ -791,24 +839,28 @@ class Scheduler:
  elif not batch.decoding_reqs or req not in batch.decoding_reqs:
  self.tree_cache.cache_unfinished_req(req)

- if req.regex_fsm is not None:
- req.regex_fsm_state = req.regex_fsm.get_next_state(
- req.regex_fsm_state, next_token_ids[i]
- )
+ if req.grammar is not None:
+ req.grammar.accept_token(next_token_ids[i])

  if req.return_logprob:
  logprob_pt += self.add_logprob_return_values(
  i, req, logprob_pt, next_token_ids, logits_output
  )
+ else:
+ req.is_being_chunked -= 1
+
  else: # embedding or reward model
  embeddings, bid = result
  embeddings = embeddings.tolist()

  # Check finish conditions
  for i, req in enumerate(batch.reqs):
+ if req.is_retracted:
+ continue
+
  req.embedding = embeddings[i]
- if req.is_inflight_req > 0:
- req.is_inflight_req -= 1
+ if req.is_being_chunked > 0:
+ req.is_being_chunked -= 1
  else:
  # Inflight reqs' prefill is not finished
  # dummy output token for embedding models
@@ -828,6 +880,7 @@ class Scheduler:

  if self.enable_overlap:
  logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+ next_token_logprobs = logits_output.next_token_logprobs
  else:
  # Move next_token_ids and logprobs to cpu
  if batch.return_logprob:
@@ -841,7 +894,12 @@ class Scheduler:

  # Check finish condition
  for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
- if self.server_args.enable_overlap_schedule and req.finished():
+ if req.is_retracted:
+ continue
+
+ if self.server_args.enable_overlap_schedule and (
+ req.finished()
+ ):
  self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
  continue

@@ -849,10 +907,8 @@ class Scheduler:
  req.output_ids.append(next_token_id)
  req.check_finished()

- if req.regex_fsm is not None:
- req.regex_fsm_state = req.regex_fsm.get_next_state(
- req.regex_fsm_state, next_token_id
- )
+ if req.grammar is not None:
+ req.grammar.accept_token(next_token_id)

  if req.finished():
  self.tree_cache.cache_finished_req(req)
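The per-request regex FSM state and jump-forward map are folded into a single req.grammar object returned by GrammarCache.query(). From the call sites in this file, the cache is keyed by a ("json", schema) or ("regex", pattern) tuple plus the vocabulary size, the returned object must at least support accept_token() (used in the prefill and decode handlers above), and the cache itself supports reset() (used by flush_cache at the end of this diff). A sketch of that implied contract, not the actual sglang/srt/constrained/grammar.py code:

# Interface implied by the scheduler's call sites; illustrative only.
from typing import Tuple

class GrammarSketch:
    def accept_token(self, token_id: int) -> None:
        """Advance the constrained-decoding state after a token is sampled."""
        raise NotImplementedError

class GrammarCacheSketch:
    def query(self, key: Tuple[str, str], vocab_size: int) -> GrammarSketch:
        """key is ("json", schema_str) or ("regex", pattern_str)."""
        raise NotImplementedError

    def reset(self) -> None:
        """Drop all cached compiled grammars (used by flush_cache)."""
        raise NotImplementedError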
@@ -868,8 +924,8 @@ class Scheduler:

  self.token_to_kv_pool.free_group_end()

- self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
- if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
+ self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
+ if self.tp_rank == 0 and self.forward_ct_decode % self.server_args.decode_log_interval == 0:
  self.print_decode_stats()

  def add_logprob_return_values(
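The decode-stats logging cadence is no longer hardcoded to every 40 decode steps; it now comes from server_args.decode_log_interval, one of several ServerArgs fields this file reads (watchdog_timeout, grammar_backend, and is_embedding are the others). A hedged configuration sketch using only the field names visible in this diff; the values shown and the set of valid grammar_backend choices are assumptions, not taken from server_args.py.

# Field names mirror the server_args attributes read in scheduler.py; the values
# are illustrative, not the package defaults.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # any model path works here
    decode_log_interval=40,     # decode steps between print_decode_stats() calls
    watchdog_timeout=300,       # seconds of stalled forward_ct before the watchdog fires
    grammar_backend="outlines", # backend handed to GrammarCache (assumed value)
)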
@@ -948,22 +1004,24 @@ class Scheduler:
  def stream_output(self, reqs: List[Req]):
  """Stream the output to detokenizer."""
  output_rids = []
- output_meta_info = []
+ output_meta_info: List[dict] = []
  output_finished_reason: List[BaseFinishReason] = []
  if self.is_generation:
  output_vids = []
  decoded_texts = []
  output_read_ids = []
  output_read_offsets = []
+ output_ids = []
  output_skip_special_tokens = []
  output_spaces_between_special_tokens = []
  output_no_stop_trim = []
  else: # embedding or reward model
  output_embeddings = []

- is_stream_iter = self.decode_forward_ct % self.stream_interval == 0
+ is_stream_iter = self.forward_ct_decode % self.stream_interval == 0

  for req in reqs:
+ # TODO(lianmin): revisit this for overlap + retract + stream
  if req.finished() or (
  req.stream and (is_stream_iter or len(req.output_ids) == 1)
  ):
@@ -975,6 +1033,8 @@ class Scheduler:
  read_ids, read_offset = req.init_incremental_detokenize()
  output_read_ids.append(read_ids)
  output_read_offsets.append(read_offset)
+ if self.skip_tokenizer_init:
+ output_ids.append(req.output_ids)
  output_skip_special_tokens.append(
  req.sampling_params.skip_special_tokens
  )
@@ -1026,6 +1086,7 @@ class Scheduler:
  decoded_texts,
  output_read_ids,
  output_read_offsets,
+ output_ids,
  output_skip_special_tokens,
  output_spaces_between_special_tokens,
  output_meta_info,
@@ -1050,7 +1111,9 @@ class Scheduler:
  ):
  self.tree_cache.reset()
  self.tree_cache_metrics = {"total": 0, "hit": 0}
- self.regex_fsm_cache.reset()
+ if self.grammar_cache is not None:
+ self.grammar_cache.reset()
+ # TODO(dark): reset the bnf cache
  self.req_to_token_pool.clear()
  self.token_to_kv_pool.clear()
  torch.cuda.empty_cache()