sglang 0.3.4.post2__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. sglang/api.py +1 -1
  2. sglang/bench_latency.py +3 -3
  3. sglang/bench_server_latency.py +2 -3
  4. sglang/bench_serving.py +92 -0
  5. sglang/global_config.py +9 -3
  6. sglang/lang/chat_template.py +50 -25
  7. sglang/lang/interpreter.py +9 -1
  8. sglang/lang/ir.py +11 -2
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/configs/model_config.py +51 -13
  11. sglang/srt/constrained/__init__.py +18 -0
  12. sglang/srt/constrained/bnf_cache.py +61 -0
  13. sglang/srt/constrained/grammar.py +190 -0
  14. sglang/srt/hf_transformers_utils.py +6 -5
  15. sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
  16. sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
  17. sglang/srt/layers/fused_moe/fused_moe.py +4 -3
  18. sglang/srt/layers/fused_moe/layer.py +28 -0
  19. sglang/srt/layers/quantization/base_config.py +16 -1
  20. sglang/srt/layers/vocab_parallel_embedding.py +486 -0
  21. sglang/srt/managers/data_parallel_controller.py +7 -6
  22. sglang/srt/managers/detokenizer_manager.py +9 -11
  23. sglang/srt/managers/image_processor.py +4 -3
  24. sglang/srt/managers/io_struct.py +70 -78
  25. sglang/srt/managers/schedule_batch.py +33 -49
  26. sglang/srt/managers/schedule_policy.py +24 -13
  27. sglang/srt/managers/scheduler.py +137 -80
  28. sglang/srt/managers/tokenizer_manager.py +224 -336
  29. sglang/srt/managers/tp_worker.py +5 -5
  30. sglang/srt/mem_cache/flush_cache.py +1 -1
  31. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  32. sglang/srt/model_executor/model_runner.py +8 -17
  33. sglang/srt/models/baichuan.py +4 -4
  34. sglang/srt/models/chatglm.py +4 -4
  35. sglang/srt/models/commandr.py +1 -1
  36. sglang/srt/models/dbrx.py +5 -5
  37. sglang/srt/models/deepseek.py +4 -4
  38. sglang/srt/models/deepseek_v2.py +4 -4
  39. sglang/srt/models/exaone.py +4 -4
  40. sglang/srt/models/gemma.py +1 -1
  41. sglang/srt/models/gemma2.py +1 -1
  42. sglang/srt/models/gpt2.py +287 -0
  43. sglang/srt/models/gpt_bigcode.py +1 -1
  44. sglang/srt/models/grok.py +4 -4
  45. sglang/srt/models/internlm2.py +4 -4
  46. sglang/srt/models/llama.py +15 -7
  47. sglang/srt/models/llama_embedding.py +2 -10
  48. sglang/srt/models/llama_reward.py +5 -0
  49. sglang/srt/models/minicpm.py +4 -4
  50. sglang/srt/models/minicpm3.py +4 -4
  51. sglang/srt/models/mixtral.py +7 -5
  52. sglang/srt/models/mixtral_quant.py +4 -4
  53. sglang/srt/models/mllama.py +5 -5
  54. sglang/srt/models/olmo.py +4 -4
  55. sglang/srt/models/olmoe.py +4 -4
  56. sglang/srt/models/qwen.py +4 -4
  57. sglang/srt/models/qwen2.py +4 -4
  58. sglang/srt/models/qwen2_moe.py +4 -4
  59. sglang/srt/models/qwen2_vl.py +4 -8
  60. sglang/srt/models/stablelm.py +4 -4
  61. sglang/srt/models/torch_native_llama.py +4 -4
  62. sglang/srt/models/xverse.py +4 -4
  63. sglang/srt/models/xverse_moe.py +4 -4
  64. sglang/srt/openai_api/adapter.py +52 -66
  65. sglang/srt/sampling/sampling_batch_info.py +7 -13
  66. sglang/srt/server.py +31 -35
  67. sglang/srt/server_args.py +34 -5
  68. sglang/srt/utils.py +40 -56
  69. sglang/test/runners.py +2 -1
  70. sglang/test/test_utils.py +73 -25
  71. sglang/utils.py +62 -1
  72. sglang/version.py +1 -1
  73. sglang-0.3.5.dist-info/METADATA +344 -0
  74. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/RECORD +77 -73
  75. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
  76. sglang-0.3.4.post2.dist-info/METADATA +0 -899
  77. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
  78. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
@@ -15,22 +15,21 @@ limitations under the License.
 
 """A scheduler that manages a tensor parallel GPU worker."""
 
-import json
 import logging
 import os
+import threading
 import time
 import warnings
 from collections import deque
 from types import SimpleNamespace
-from typing import List, Optional, Union
+from typing import List, Optional
 
 import torch
 import zmq
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained.fsm_cache import FSMCache
-from sglang.srt.constrained.jump_forward import JumpForwardCache
+from sglang.srt.constrained.grammar import GrammarCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -43,7 +42,6 @@ from sglang.srt.managers.io_struct import (
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    TokenizedRewardReqInput,
     UpdateWeightReqInput,
     UpdateWeightReqOutput,
 )
@@ -68,8 +66,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
-    is_generation_model,
-    is_multimodal_model,
+    get_zmq_socket,
     kill_parent_process,
     set_random_seed,
     suppress_other_loggers,
@@ -78,6 +75,7 @@ from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
+
 # Crash on warning if we are running CI tests
 crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
 
@@ -105,16 +103,26 @@ class Scheduler:
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = server_args.enable_overlap_schedule
+        self.skip_tokenizer_init = server_args.skip_tokenizer_init
 
         # Init inter-process communication
         context = zmq.Context(2)
 
         if self.tp_rank == 0:
-            self.recv_from_tokenizer = context.socket(zmq.PULL)
-            self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")
+            self.recv_from_tokenizer = get_zmq_socket(
+                context, zmq.PULL, port_args.scheduler_input_ipc_name
+            )
 
-            self.send_to_detokenizer = context.socket(zmq.PUSH)
-            self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")
+            if server_args.skip_tokenizer_init:
+                # Directly send to the tokenizer/api
+                self.send_to_detokenizer = get_zmq_socket(
+                    context, zmq.PUSH, port_args.tokenizer_ipc_name
+                )
+            else:
+                # Send to the detokenizer
+                self.send_to_detokenizer = get_zmq_socket(
+                    context, zmq.PUSH, port_args.detokenizer_ipc_name
+                )
         else:
             self.recv_from_tokenizer = None
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
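The scheduler now creates its IPC sockets through a shared get_zmq_socket helper (imported above from sglang.srt.utils) instead of calling context.socket/bind/connect inline, and routes its PUSH socket to the tokenizer/API endpoint when skip_tokenizer_init is set. A minimal sketch of what such a helper could look like, assuming the same convention as the code it replaces (the PULL side binds the ipc:// endpoint, the PUSH side connects); the actual helper may set additional socket options:

    import zmq

    def get_zmq_socket(context: zmq.Context, socket_type: int, endpoint: str) -> zmq.Socket:
        # Hypothetical sketch only, not the real sglang.srt.utils implementation.
        sock = context.socket(socket_type)
        if socket_type == zmq.PULL:
            sock.bind(f"ipc://{endpoint}")      # receiver side owns the endpoint
        elif socket_type == zmq.PUSH:
            sock.connect(f"ipc://{endpoint}")   # sender side attaches to it
        else:
            raise ValueError(f"Unsupported socket type: {socket_type}")
        return sock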
@@ -122,15 +130,17 @@ class Scheduler:
         # Init tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
-            server_args.trust_remote_code,
+            trust_remote_code=server_args.trust_remote_code,
             context_length=server_args.context_length,
-            model_override_args=json.loads(server_args.json_model_override_args),
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
         )
+        self.is_generation = self.model_config.is_generation
 
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if is_multimodal_model(self.model_config.hf_config.architectures):
+            if self.model_config.is_multimodal:
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
@@ -143,9 +153,6 @@ class Scheduler:
                     tokenizer_mode=server_args.tokenizer_mode,
                     trust_remote_code=server_args.trust_remote_code,
                 )
-        self.is_generation = is_generation_model(
-            self.model_config.hf_config.architectures, self.server_args.is_embedding
-        )
 
         # Launch a tensor parallel worker
         if self.enable_overlap:
@@ -212,44 +219,62 @@ class Scheduler:
         self.waiting_queue: List[Req] = []
         self.running_batch: Optional[ScheduleBatch] = None
         self.cur_batch: Optional[ScheduleBatch] = None
-        self.decode_forward_ct = 0
-        self.stream_interval = server_args.stream_interval
+        self.forward_ct = 0
+        self.forward_ct_decode = 0
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()
+        self.stream_interval = server_args.stream_interval
 
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
-        self.current_inflight_req = None
+        self.being_chunked_req = None
         self.is_mixed_chunk = (
             self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
         )
 
         # Init the FSM cache for constrained generation
+        self.grammar_cache = None
+
         if not server_args.skip_tokenizer_init:
-            self.regex_fsm_cache = FSMCache(
+            self.grammar_cache = GrammarCache(
                 server_args.tokenizer_path,
                 {
                     "tokenizer_mode": server_args.tokenizer_mode,
                     "trust_remote_code": server_args.trust_remote_code,
                 },
                 skip_tokenizer_init=server_args.skip_tokenizer_init,
-                constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
+                whitespace_patterns=server_args.constrained_json_whitespace_pattern,
+                backend=server_args.grammar_backend,
+                allow_jump=not server_args.disable_regex_jump_forward,
             )
-            self.jump_forward_cache = JumpForwardCache()
 
         # Init new token estimation
         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
-        self.min_new_token_ratio = min(
-            global_config.base_min_new_token_ratio
+
+        self.init_new_token_ratio = min(
+            global_config.default_init_new_token_ratio
             * server_args.schedule_conservativeness,
             1.0,
         )
-        self.new_token_ratio = self.min_new_token_ratio
-        self.new_token_ratio_decay = global_config.new_token_ratio_decay
+        self.min_new_token_ratio = min(
+            self.init_new_token_ratio
+            * global_config.default_min_new_token_ratio_factor,
+            1.0,
+        )
+        self.new_token_ratio_decay = (
+            self.init_new_token_ratio - self.min_new_token_ratio
+        ) / global_config.default_new_token_ratio_decay_steps
+        self.new_token_ratio = self.init_new_token_ratio
+
         self.batch_is_full = False
 
+        # Init watchdog thread
+        self.watchdog_timeout = server_args.watchdog_timeout
+        t = threading.Thread(target=self.watchdog_thread, daemon=True)
+        t.start()
+
         # Init profiler
         if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
             self.profiler = None
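The new-token-ratio bookkeeping above now derives a linearly decaying schedule from three global_config defaults: an initial ratio scaled by schedule_conservativeness, a floor obtained by multiplying that initial ratio by a factor, and a per-step decay from one to the other. A small sketch of the arithmetic, using illustrative numbers rather than the actual defaults in sglang.global_config (the decay itself is applied elsewhere in the scheduler and not shown in this diff, while the reset to the initial ratio appears in the event-loop hunks below):

    # Illustrative values only; not the actual sglang.global_config defaults.
    default_init_new_token_ratio = 0.7
    default_min_new_token_ratio_factor = 0.14
    default_new_token_ratio_decay_steps = 600
    schedule_conservativeness = 1.0

    init_ratio = min(default_init_new_token_ratio * schedule_conservativeness, 1.0)  # 0.7
    min_ratio = min(init_ratio * default_min_new_token_ratio_factor, 1.0)            # 0.098
    decay_per_step = (init_ratio - min_ratio) / default_new_token_ratio_decay_steps  # ~0.001

    # new_token_ratio starts at init_ratio and is stepped toward min_ratio,
    # one decay_per_step at a time, never dropping below the floor.
    ratio_after_100_steps = max(init_ratio - 100 * decay_per_step, min_ratio)        # ~0.6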
@@ -267,6 +292,23 @@ class Scheduler:
             with_stack=True,
         )
 
+    def watchdog_thread(self):
+        self.watchdog_last_forward_ct = 0
+        self.watchdog_last_time = time.time()
+
+        while True:
+            if self.cur_batch is not None:
+                if self.watchdog_last_forward_ct == self.forward_ct:
+                    if time.time() > self.watchdog_last_time + self.watchdog_timeout:
+                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
+                        break
+                else:
+                    self.watchdog_last_forward_ct = self.forward_ct
+                    self.watchdog_last_time = time.time()
+            time.sleep(self.watchdog_timeout / 2)
+
+        kill_parent_process()
+
     @torch.inference_mode()
     def event_loop_normal(self):
         """A normal blocking scheduler loop."""
@@ -277,6 +319,7 @@ class Scheduler:
             self.process_input_requests(recv_reqs)
 
             batch = self.get_next_batch_to_run()
+            self.cur_batch = batch
 
             if batch:
                 result = self.run_batch(batch)
@@ -294,7 +337,7 @@ class Scheduler:
                 self.process_batch_result(batch, result)
             else:
                 self.check_memory()
-                self.new_token_ratio = global_config.init_new_token_ratio
+                self.new_token_ratio = self.init_new_token_ratio
 
             self.last_batch = batch
 
@@ -321,7 +364,7 @@ class Scheduler:
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
                 self.check_memory()
-                self.new_token_ratio = global_config.init_new_token_ratio
+                self.new_token_ratio = self.init_new_token_ratio
 
             self.last_batch = batch
 
@@ -346,9 +389,7 @@ class Scheduler:
         for recv_req in recv_reqs:
             if isinstance(recv_req, TokenizedGenerateReqInput):
                 self.handle_generate_request(recv_req)
-            elif isinstance(
-                recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
-            ):
+            elif isinstance(recv_req, TokenizedEmbeddingReqInput):
                 self.handle_embedding_request(recv_req)
             elif isinstance(recv_req, FlushCacheReq):
                 self.flush_cache()
@@ -402,22 +443,20 @@ class Scheduler:
         # By default, only return the logprobs for output tokens
         req.logprob_start_len = len(recv_req.input_ids) - 1
 
-        # Init regex FSM
+        # Init regex FSM or BNF
         if (
             req.sampling_params.json_schema is not None
             or req.sampling_params.regex is not None
         ):
+            assert self.grammar_cache is not None
             if req.sampling_params.json_schema is not None:
-                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
-                    ("json", req.sampling_params.json_schema)
+                req.grammar = self.grammar_cache.query(
+                    ("json", req.sampling_params.json_schema),
+                    self.model_config.vocab_size,
                 )
             elif req.sampling_params.regex is not None:
-                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
-                    ("regex", req.sampling_params.regex)
-                )
-                if not self.disable_regex_jump_forward:
-                    req.jump_forward_map = self.jump_forward_cache.query(
-                        computed_regex_string
+                req.grammar = self.grammar_cache.query(
+                    ("regex", req.sampling_params.regex), self.model_config.vocab_size
                 )
 
         # Truncate prompts that are too long
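Together with the grammar_cache construction above and the accept_token calls in the process_batch_result_* hunks below, the diff implies a small per-request interface for the new GrammarCache/Grammar objects. A rough sketch of that flow, using only calls visible in this diff; the helper names below are placeholders, not real sglang functions, and the real classes live in the new sglang/srt/constrained/grammar.py module:

    def attach_grammar(req, grammar_cache, vocab_size):
        # One cached grammar per ("json" | "regex", spec) key, specialized to the vocab.
        if req.sampling_params.json_schema is not None:
            req.grammar = grammar_cache.query(
                ("json", req.sampling_params.json_schema), vocab_size
            )
        elif req.sampling_params.regex is not None:
            req.grammar = grammar_cache.query(
                ("regex", req.sampling_params.regex), vocab_size
            )

    def advance_grammar(req, next_token_id):
        # Called once per sampled token while processing prefill and decode results.
        if req.grammar is not None:
            req.grammar.accept_token(next_token_id)

flush_cache() additionally calls grammar_cache.reset() when the scheduler state is flushed, as the last hunk shows.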
@@ -441,7 +480,7 @@ class Scheduler:
 
     def handle_embedding_request(
         self,
-        recv_req: Union[TokenizedEmbeddingReqInput, TokenizedRewardReqInput],
+        recv_req: TokenizedEmbeddingReqInput,
     ):
         req = Req(
             recv_req.rid,
@@ -506,13 +545,13 @@ class Scheduler:
             and not self.last_batch.forward_mode.is_decode()
             and not self.last_batch.is_empty()
         ):
-            if self.current_inflight_req:
+            if self.being_chunked_req:
                 self.last_batch.filter_batch(
-                    current_inflight_req=self.current_inflight_req
+                    being_chunked_req=self.being_chunked_req
                 )
-                self.tree_cache.cache_unfinished_req(self.current_inflight_req)
+                self.tree_cache.cache_unfinished_req(self.being_chunked_req)
                 # Inflight request keeps its rid but will get a new req_pool_idx.
-                self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
+                self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                 self.batch_is_full = False
             if not self.last_batch.is_empty():
                 if self.running_batch is None:
@@ -543,7 +582,7 @@ class Scheduler:
         # Handle the cases where prefill is not allowed
         if (
             self.batch_is_full or len(self.waiting_queue) == 0
-        ) and self.current_inflight_req is None:
+        ) and self.being_chunked_req is None:
             return None
 
         running_bs = len(self.running_batch.reqs) if self.running_batch else 0
@@ -566,13 +605,11 @@ class Scheduler:
             num_mixed_running,
         )
 
-        has_inflight = self.current_inflight_req is not None
+        has_inflight = self.being_chunked_req is not None
         if has_inflight:
-            self.current_inflight_req.init_next_round_input(
-                None if prefix_computed else self.tree_cache
-            )
-            self.current_inflight_req = adder.add_inflight_req(
-                self.current_inflight_req
+            self.being_chunked_req.init_next_round_input()
+            self.being_chunked_req = adder.add_inflight_req(
+                self.being_chunked_req
             )
 
         if self.lora_paths:
@@ -616,11 +653,11 @@ class Scheduler:
         ]
 
         if adder.new_inflight_req is not None:
-            assert self.current_inflight_req is None
-            self.current_inflight_req = adder.new_inflight_req
+            assert self.being_chunked_req is None
+            self.being_chunked_req = adder.new_inflight_req
 
-        if self.current_inflight_req:
-            self.current_inflight_req.is_inflight_req += 1
+        if self.being_chunked_req:
+            self.being_chunked_req.is_being_chunked += 1
 
         # Print stats
         if self.tp_rank == 0:
@@ -675,9 +712,11 @@ class Scheduler:
 
         # Mixed-style chunked prefill
         if self.is_mixed_chunk and self.running_batch is not None:
-            self.running_batch.prepare_for_decode(self.enable_overlap)
-            new_batch.mix_with_running(self.running_batch)
-            new_batch.decoding_reqs = self.running_batch.reqs
+            self.running_batch.filter_batch()
+            if not self.running_batch.is_empty():
+                self.running_batch.prepare_for_decode(self.enable_overlap)
+                new_batch.mix_with_running(self.running_batch)
+                new_batch.decoding_reqs = self.running_batch.reqs
             self.running_batch = None
         else:
             new_batch.decoding_reqs = None
@@ -726,6 +765,8 @@ class Scheduler:
 
     def run_batch(self, batch: ScheduleBatch):
         """Run a batch."""
+        self.forward_ct += 1
+
         if self.is_generation:
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                 model_worker_batch = batch.get_model_worker_batch()
@@ -734,7 +775,7 @@ class Scheduler:
                 )
             else:
                 logits_output = None
-                if self.tokenizer is not None:
+                if self.skip_tokenizer_init:
                     next_token_ids = torch.full(
                         (batch.batch_size(),), self.tokenizer.eos_token_id
                     )
@@ -758,6 +799,7 @@ class Scheduler:
             self.process_batch_result_prefill(batch, result)
 
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
+
         if self.is_generation:
             logits_output, next_token_ids, bid = result
 
@@ -783,9 +825,10 @@ class Scheduler:
             # Check finish conditions
             logprob_pt = 0
             for i, req in enumerate(batch.reqs):
-                if req.is_inflight_req > 0:
-                    req.is_inflight_req -= 1
-                else:
+                if req.is_retracted:
+                    continue
+
+                if req.is_being_chunked <= 0:
                     # Inflight reqs' prefill is not finished
                     req.completion_tokens_wo_jump_forward += 1
                     req.output_ids.append(next_token_ids[i])
@@ -796,24 +839,28 @@ class Scheduler:
                     elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                         self.tree_cache.cache_unfinished_req(req)
 
-                    if req.regex_fsm is not None:
-                        req.regex_fsm_state = req.regex_fsm.get_next_state(
-                            req.regex_fsm_state, next_token_ids[i]
-                        )
+                    if req.grammar is not None:
+                        req.grammar.accept_token(next_token_ids[i])
 
                     if req.return_logprob:
                         logprob_pt += self.add_logprob_return_values(
                             i, req, logprob_pt, next_token_ids, logits_output
                         )
+                else:
+                    req.is_being_chunked -= 1
+
         else:  # embedding or reward model
             embeddings, bid = result
             embeddings = embeddings.tolist()
 
             # Check finish conditions
             for i, req in enumerate(batch.reqs):
+                if req.is_retracted:
+                    continue
+
                 req.embedding = embeddings[i]
-                if req.is_inflight_req > 0:
-                    req.is_inflight_req -= 1
+                if req.is_being_chunked > 0:
+                    req.is_being_chunked -= 1
                 else:
                     # Inflight reqs' prefill is not finished
                     # dummy output token for embedding models
@@ -847,7 +894,12 @@ class Scheduler:
 
         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
-            if self.server_args.enable_overlap_schedule and req.finished():
+            if req.is_retracted:
+                continue
+
+            if self.server_args.enable_overlap_schedule and (
+                req.finished()
+            ):
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue
 
@@ -855,10 +907,8 @@ class Scheduler:
             req.output_ids.append(next_token_id)
             req.check_finished()
 
-            if req.regex_fsm is not None:
-                req.regex_fsm_state = req.regex_fsm.get_next_state(
-                    req.regex_fsm_state, next_token_id
-                )
+            if req.grammar is not None:
+                req.grammar.accept_token(next_token_id)
 
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
@@ -874,8 +924,8 @@ class Scheduler:
 
         self.token_to_kv_pool.free_group_end()
 
-        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
-        if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
+        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
+        if self.tp_rank == 0 and self.forward_ct_decode % self.server_args.decode_log_interval == 0:
             self.print_decode_stats()
 
     def add_logprob_return_values(
@@ -954,22 +1004,24 @@ class Scheduler:
     def stream_output(self, reqs: List[Req]):
         """Stream the output to detokenizer."""
         output_rids = []
-        output_meta_info = []
+        output_meta_info: List[dict] = []
         output_finished_reason: List[BaseFinishReason] = []
         if self.is_generation:
             output_vids = []
             decoded_texts = []
             output_read_ids = []
             output_read_offsets = []
+            output_ids = []
             output_skip_special_tokens = []
             output_spaces_between_special_tokens = []
             output_no_stop_trim = []
         else:  # embedding or reward model
             output_embeddings = []
 
-        is_stream_iter = self.decode_forward_ct % self.stream_interval == 0
+        is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
 
         for req in reqs:
+            # TODO(lianmin): revisit this for overlap + retract + stream
             if req.finished() or (
                 req.stream and (is_stream_iter or len(req.output_ids) == 1)
             ):
@@ -981,6 +1033,8 @@ class Scheduler:
                     read_ids, read_offset = req.init_incremental_detokenize()
                     output_read_ids.append(read_ids)
                     output_read_offsets.append(read_offset)
+                    if self.skip_tokenizer_init:
+                        output_ids.append(req.output_ids)
                     output_skip_special_tokens.append(
                         req.sampling_params.skip_special_tokens
                     )
@@ -1032,6 +1086,7 @@ class Scheduler:
                     decoded_texts,
                     output_read_ids,
                     output_read_offsets,
+                    output_ids,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
                     output_meta_info,
@@ -1056,7 +1111,9 @@ class Scheduler:
         ):
             self.tree_cache.reset()
             self.tree_cache_metrics = {"total": 0, "hit": 0}
-            self.regex_fsm_cache.reset()
+            if self.grammar_cache is not None:
+                self.grammar_cache.reset()
+            # TODO(dark): reset the bnf cache
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()