sglang 0.3.5__py3-none-any.whl → 0.3.5.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. sglang/bench_serving.py +113 -3
  2. sglang/srt/configs/model_config.py +5 -2
  3. sglang/srt/constrained/__init__.py +2 -66
  4. sglang/srt/constrained/base_grammar_backend.py +72 -0
  5. sglang/srt/constrained/outlines_backend.py +165 -0
  6. sglang/srt/constrained/outlines_jump_forward.py +182 -0
  7. sglang/srt/constrained/xgrammar_backend.py +114 -0
  8. sglang/srt/layers/attention/triton_ops/decode_attention.py +7 -0
  9. sglang/srt/layers/attention/triton_ops/extend_attention.py +6 -0
  10. sglang/srt/layers/fused_moe/fused_moe.py +23 -7
  11. sglang/srt/layers/quantization/base_config.py +4 -6
  12. sglang/srt/layers/vocab_parallel_embedding.py +216 -150
  13. sglang/srt/managers/io_struct.py +5 -3
  14. sglang/srt/managers/schedule_batch.py +14 -20
  15. sglang/srt/managers/scheduler.py +153 -94
  16. sglang/srt/managers/tokenizer_manager.py +81 -17
  17. sglang/srt/metrics/collector.py +211 -0
  18. sglang/srt/metrics/func_timer.py +108 -0
  19. sglang/srt/mm_utils.py +1 -1
  20. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  21. sglang/srt/model_executor/forward_batch_info.py +7 -3
  22. sglang/srt/model_executor/model_runner.py +2 -1
  23. sglang/srt/models/gemma2_reward.py +69 -0
  24. sglang/srt/models/gpt2.py +31 -37
  25. sglang/srt/models/internlm2_reward.py +62 -0
  26. sglang/srt/models/llama.py +11 -6
  27. sglang/srt/models/llama_reward.py +5 -26
  28. sglang/srt/models/qwen2_vl.py +5 -7
  29. sglang/srt/openai_api/adapter.py +6 -2
  30. sglang/srt/sampling/sampling_batch_info.py +2 -3
  31. sglang/srt/sampling/sampling_params.py +0 -14
  32. sglang/srt/server.py +58 -16
  33. sglang/srt/server_args.py +42 -22
  34. sglang/srt/utils.py +87 -0
  35. sglang/test/simple_eval_common.py +1 -1
  36. sglang/test/simple_eval_humaneval.py +2 -2
  37. sglang/test/simple_eval_mgsm.py +2 -2
  38. sglang/test/test_utils.py +18 -4
  39. sglang/utils.py +1 -0
  40. sglang/version.py +1 -1
  41. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/METADATA +11 -7
  42. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/RECORD +45 -42
  43. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/WHEEL +1 -1
  44. sglang/srt/constrained/base_tool_cache.py +0 -65
  45. sglang/srt/constrained/bnf_cache.py +0 -61
  46. sglang/srt/constrained/fsm_cache.py +0 -95
  47. sglang/srt/constrained/grammar.py +0 -190
  48. sglang/srt/constrained/jump_forward.py +0 -203
  49. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/LICENSE +0 -0
  50. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/top_level.txt +0 -0
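The headline change in this release is the constrained-decoding refactor: the old one-off caches (base_tool_cache.py, fsm_cache.py, bnf_cache.py, grammar.py, jump_forward.py) are removed in favor of a pluggable grammar backend (base_grammar_backend.py plus outlines_backend.py and xgrammar_backend.py), and the scheduler gains a metrics collector via the new sglang/srt/metrics package. The sketch below shows roughly what the new backend interface looks like; it is inferred only from the call sites in the schedule_batch.py and scheduler.py hunks that follow, so the method bodies and any helper names not visible in this diff are assumptions rather than the actual sglang implementation.

# Rough sketch of the grammar-backend interface implied by the hunks below.
# Only the names that appear at call sites in this diff (BaseGrammarObject,
# get_cached_value, get_future_value, try_jump_forward, jump_forward_str_state,
# reset) come from the source; every body and helper here is an illustrative
# assumption, not the real sglang implementation.
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional, Tuple


class BaseGrammarObject:
    def try_jump_forward(self, tokenizer) -> Optional[Tuple[list, str]]:
        """Return (suffix_ids, helper_state) if a jump-forward is possible, else None."""
        raise NotImplementedError

    def jump_forward_str_state(self, helper) -> Tuple[str, int]:
        """Return (jump_forward_str, next_state) for a proposed jump."""
        raise NotImplementedError


class BaseGrammarBackend:
    """Compiles grammars off the scheduler loop and caches them by key,
    e.g. ("json", schema_str) or ("regex", pattern)."""

    def __init__(self):
        self.executor = ThreadPoolExecutor()
        self.cache = {}

    def init_value(self, key) -> BaseGrammarObject:
        raise NotImplementedError  # backend-specific compilation (outlines / xgrammar)

    def get_cached_value(self, key) -> Optional[BaseGrammarObject]:
        return self.cache.get(key)

    def get_future_value(self, key) -> Future:
        # The scheduler parks the request in grammar_queue and later polls
        # future.result(timeout=0.05) in move_ready_grammar_requests().
        def _compile():
            value = self.init_value(key)
            self.cache[key] = value
            return value

        return self.executor.submit(_compile)

    def reset(self):
        self.cache.clear()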
sglang/srt/managers/schedule_batch.py
@@ -37,7 +37,7 @@ import torch
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained.grammar import Grammar
+from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -107,12 +107,14 @@ class FINISH_LENGTH(BaseFinishReason):
 
 
 class FINISH_ABORT(BaseFinishReason):
-    def __init__(self):
+    def __init__(self, message="Unknown error"):
         super().__init__(is_error=True)
+        self.message = message
 
     def to_json(self):
         return {
             "type": "abort",
+            "message": self.message,
         }
 
 
@@ -133,6 +135,7 @@ class ImageInputs:
     aspect_ratio_mask: Optional[List[torch.Tensor]] = None
     # QWen2-VL related
     image_grid_thws: List[Tuple[int, int, int]] = None
+    mrope_position_delta: Optional[torch.Tensor] = None
 
     @staticmethod
     def from_dict(obj, vocab_size):
@@ -211,7 +214,7 @@ class Req:
         # this does not include the jump forward tokens.
         self.completion_tokens_wo_jump_forward = 0
 
-        # For vision inputs
+        # For multimodal inputs
         self.image_inputs: Optional[ImageInputs] = None
 
         # Prefix info
@@ -246,14 +249,11 @@ class Req:
         self.embedding = None
 
         # Constrained decoding
-        self.grammar: Optional[Grammar] = None
+        self.grammar: Optional[BaseGrammarObject] = None
 
         # The number of cached tokens, that were already cached in the KV cache
         self.cached_tokens = 0
 
-        # For Qwen2-VL
-        self.mrope_position_delta = []  # use mutable object
-
     # whether request reached finished condition
     def finished(self) -> bool:
         return self.finished_reason is not None
@@ -359,8 +359,6 @@ class Req:
                     return
 
     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        assert self.grammar is not None and self.tokenizer is not None
-
         if self.origin_input_text is None:
             # Recovering text can only use unpadded ids
             self.origin_input_text = self.tokenizer.decode(
@@ -809,9 +807,10 @@ class ScheduleBatch:
 
         for i, req in enumerate(self.reqs):
             if req.grammar is not None:
-                jump_helper = req.grammar.try_jump(req.tokenizer)
-                if jump_helper.can_jump():
-                    suffix_ids = jump_helper.suffix_ids
+                jump_helper = req.grammar.try_jump_forward(req.tokenizer)
+                if jump_helper:
+                    suffix_ids, _ = jump_helper
+
                     # Current ids, for cache and revert
                     cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                     cur_output_ids = req.output_ids
@@ -827,6 +826,8 @@ class ScheduleBatch:
                         next_state,
                     ) = req.grammar.jump_forward_str_state(jump_helper)
 
+                    # Make the incrementally decoded text part of jump_forward_str
+                    # so that the UTF-8 will not corrupt
                     jump_forward_str = new_text + jump_forward_str
                     if not req.jump_forward_and_retokenize(
                         jump_forward_str, next_state
@@ -900,8 +901,7 @@ class ScheduleBatch:
         keep_indices = [
             i
             for i in range(len(self.reqs))
-            if not self.reqs[i].finished()
-            and self.reqs[i] is not being_chunked_req
+            if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
         ]
 
         if keep_indices is None or len(keep_indices) == 0:
@@ -984,8 +984,6 @@ class ScheduleBatch:
         global bid
         bid += 1
 
-        mrope_positions_delta = [req.mrope_position_delta for req in self.reqs]
-
         return ModelWorkerBatch(
             bid=bid,
             forward_mode=self.forward_mode,
@@ -1008,7 +1006,6 @@ class ScheduleBatch:
             encoder_out_cache_loc=self.encoder_out_cache_loc,
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
-            mrope_positions_delta=mrope_positions_delta,
         )
 
     def copy(self):
@@ -1075,9 +1072,6 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo
 
-    # For Qwen2-VL
-    mrope_positions_delta: List[List[int]]
-
     def copy(self):
         return dataclasses.replace(self, sampling_info=self.sampling_info.copy())
 
sglang/srt/managers/scheduler.py
@@ -21,6 +21,7 @@ import threading
 import time
 import warnings
 from collections import deque
+from concurrent import futures
 from types import SimpleNamespace
 from typing import List, Optional
 
@@ -29,7 +30,6 @@ import zmq
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained.grammar import GrammarCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -62,6 +62,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
+from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     broadcast_pyobj,
@@ -99,11 +100,12 @@ class Scheduler:
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.schedule_policy = server_args.schedule_policy
-        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
+        self.disable_jump_forward = server_args.disable_jump_forward
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = server_args.enable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
+        self.enable_metrics = server_args.enable_metrics
 
         # Init inter-process communication
         context = zmq.Context(2)
@@ -222,7 +224,7 @@ class Scheduler:
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
-        self.last_stats_tic = time.time()
+        self.last_decode_stats_tic = time.time()
         self.stream_interval = server_args.stream_interval
 
         # Init chunked prefill
@@ -232,21 +234,33 @@ class Scheduler:
             self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
         )
 
-        # Init the FSM cache for constrained generation
-        self.grammar_cache = None
-
+        # Init the grammar backend for constrained generation
+        self.grammar_queue: List[Req] = []
         if not server_args.skip_tokenizer_init:
-            self.grammar_cache = GrammarCache(
-                server_args.tokenizer_path,
-                {
-                    "tokenizer_mode": server_args.tokenizer_mode,
-                    "trust_remote_code": server_args.trust_remote_code,
-                },
-                skip_tokenizer_init=server_args.skip_tokenizer_init,
-                whitespace_patterns=server_args.constrained_json_whitespace_pattern,
-                backend=server_args.grammar_backend,
-                allow_jump=not server_args.disable_regex_jump_forward,
-            )
+            if server_args.grammar_backend == "outlines":
+                from sglang.srt.constrained.outlines_backend import (
+                    OutlinesGrammarBackend,
+                )
+
+                self.grammar_backend = OutlinesGrammarBackend(
+                    self.tokenizer,
+                    whitespace_pattern=server_args.constrained_json_whitespace_pattern,
+                    allow_jump_forward=not server_args.disable_jump_forward,
+                )
+            elif server_args.grammar_backend == "xgrammar":
+                from sglang.srt.constrained.xgrammar_backend import (
+                    XGrammarGrammarBackend,
+                )
+
+                self.grammar_backend = XGrammarGrammarBackend(
+                    self.tokenizer, vocab_size=self.model_config.vocab_size
+                )
+            else:
+                raise ValueError(
+                    f"Invalid grammar backend: {server_args.grammar_backend}"
+                )
+        else:
+            self.grammar_backend = None
 
         # Init new token estimation
         assert (
@@ -292,6 +306,16 @@ class Scheduler:
                 with_stack=True,
             )
 
+        # Init metrics stats
+        self.stats = SchedulerStats()
+        if self.enable_metrics:
+            self.metrics_collector = SchedulerMetricsCollector(
+                labels={
+                    "model_name": self.server_args.served_model_name,
+                    # TODO: Add lora name/path in the future,
+                },
+            )
+
     def watchdog_thread(self):
         self.watchdog_last_forward_ct = 0
         self.watchdog_last_time = time.time()
@@ -443,22 +467,6 @@ class Scheduler:
             # By default, only return the logprobs for output tokens
             req.logprob_start_len = len(recv_req.input_ids) - 1
 
-        # Init regex FSM or BNF
-        if (
-            req.sampling_params.json_schema is not None
-            or req.sampling_params.regex is not None
-        ):
-            assert self.grammar_cache is not None
-            if req.sampling_params.json_schema is not None:
-                req.grammar = self.grammar_cache.query(
-                    ("json", req.sampling_params.json_schema),
-                    self.model_config.vocab_size,
-                )
-            elif req.sampling_params.regex is not None:
-                req.grammar = self.grammar_cache.query(
-                    ("regex", req.sampling_params.regex), self.model_config.vocab_size
-                )
-
         # Truncate prompts that are too long
         if len(req.origin_input_ids) > self.max_req_input_len:
             logger.warning(
@@ -476,7 +484,27 @@ class Scheduler:
             self.max_req_len - len(req.origin_input_ids) - 1,
         )
 
-        self.waiting_queue.append(req)
+        # Init grammar cache for this request
+        add_to_grammar_queue = False
+        if (
+            req.sampling_params.json_schema is not None
+            or req.sampling_params.regex is not None
+        ):
+            assert self.grammar_backend is not None
+            if req.sampling_params.json_schema is not None:
+                key = ("json", req.sampling_params.json_schema)
+            elif req.sampling_params.regex is not None:
+                key = ("regex", req.sampling_params.regex)
+
+            req.grammar = self.grammar_backend.get_cached_value(key)
+            if not req.grammar:
+                req.grammar = self.grammar_backend.get_future_value(key)
+                add_to_grammar_queue = True
+
+        if add_to_grammar_queue:
+            self.grammar_queue.append(req)
+        else:
+            self.waiting_queue.append(req)
 
     def handle_embedding_request(
         self,
@@ -500,23 +528,68 @@ class Scheduler:
 
         self.waiting_queue.append(req)
 
-    def print_decode_stats(self):
+    def log_prefill_stats(self, adder, can_run_list, running_bs, has_inflight):
+        if isinstance(self.tree_cache, RadixCache):
+            self.tree_cache_metrics["total"] += (
+                adder.log_input_tokens + adder.log_hit_tokens
+            ) / 10**9
+            self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
+            tree_cache_hit_rate = (
+                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
+            )
+        else:
+            tree_cache_hit_rate = 0.0
+
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
-        throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
+
+        logger.info(
+            f"Prefill batch. "
+            f"#new-seq: {len(can_run_list)}, "
+            f"#new-token: {adder.log_input_tokens}, "
+            f"#cached-token: {adder.log_hit_tokens}, "
+            f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+            f"#running-req: {running_bs}, "
+            f"#queue-req: {len(self.waiting_queue) + has_inflight}"
+        )
+
+        if self.enable_metrics:
+            self.stats.num_running_reqs = running_bs
+            self.stats.num_used_tokens = num_used
+            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
+            self.stats.num_queue_reqs = len(self.waiting_queue) + has_inflight
+            self.stats.cache_hit_rate = tree_cache_hit_rate
+            self.metrics_collector.log_stats(self.stats)
+
+    def log_decode_stats(self):
+        num_used = self.max_total_num_tokens - (
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        )
+        gen_throughput = self.num_generated_tokens / (
+            time.time() - self.last_decode_stats_tic
+        )
         self.num_generated_tokens = 0
-        self.last_stats_tic = time.time()
+        self.last_decode_stats_tic = time.time()
         num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
         logger.info(
             f"Decode batch. "
             f"#running-req: {num_running_reqs}, "
             f"#token: {num_used}, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-            f"gen throughput (token/s): {throughput:.2f}, "
+            f"gen throughput (token/s): {gen_throughput:.2f}, "
             f"#queue-req: {len(self.waiting_queue)}"
         )
 
+        if self.enable_metrics:
+            self.stats.num_running_reqs = num_running_reqs
+            self.stats.num_used_tokens = num_used
+            self.stats.token_usage = num_used / self.max_total_num_tokens
+            self.stats.gen_throughput = gen_throughput
+            self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.metrics_collector.log_stats(self.stats)
+
     def check_memory(self):
         available_size = (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
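The two logging methods above feed a new SchedulerStats object into a SchedulerMetricsCollector, both from the new sglang/srt/metrics/collector.py, whose body is not part of this excerpt. A rough sketch of what that pair could look like; the field names are taken from the assignments above, while the gauge wiring and metric names are purely assumptions for illustration:

# Sketch only: mirrors the stats fields used in log_prefill_stats / log_decode_stats.
# The Prometheus-style gauges and the "sglang_" metric names are assumptions,
# not the actual contents of sglang/srt/metrics/collector.py.
from dataclasses import dataclass


@dataclass
class SchedulerStats:
    num_running_reqs: int = 0
    num_used_tokens: int = 0
    token_usage: float = 0.0
    gen_throughput: float = 0.0
    num_queue_reqs: int = 0
    cache_hit_rate: float = 0.0


class SchedulerMetricsCollector:
    def __init__(self, labels: dict) -> None:
        # e.g. labels={"model_name": served_model_name}, as passed in the diff above
        from prometheus_client import Gauge

        self.labels = labels
        self.gauges = {
            field: Gauge(f"sglang_{field}", field, labelnames=list(labels.keys()))
            for field in SchedulerStats.__dataclass_fields__
        }

    def log_stats(self, stats: SchedulerStats) -> None:
        # Push the latest snapshot of every stat field to its gauge.
        for field, gauge in self.gauges.items():
            gauge.labels(**self.labels).set(getattr(stats, field))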
@@ -546,9 +619,7 @@ class Scheduler:
             and not self.last_batch.is_empty()
         ):
             if self.being_chunked_req:
-                self.last_batch.filter_batch(
-                    being_chunked_req=self.being_chunked_req
-                )
+                self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
                 self.tree_cache.cache_unfinished_req(self.being_chunked_req)
                 # Inflight request keeps its rid but will get a new req_pool_idx.
                 self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
@@ -579,6 +650,10 @@ class Scheduler:
         return self.running_batch
 
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
+        # Check if the grammar is ready in the grammar queue
+        if self.grammar_queue:
+            self.move_ready_grammar_requests()
+
         # Handle the cases where prefill is not allowed
         if (
             self.batch_is_full or len(self.waiting_queue) == 0
@@ -594,7 +669,6 @@ class Scheduler:
         prefix_computed = self.policy.calc_priority(self.waiting_queue)
 
         # Prefill policy
-        num_mixed_running = running_bs if self.is_mixed_chunk else 0
         adder = PrefillAdder(
             self.tree_cache,
             self.running_batch,
602
676
  self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
603
677
  self.max_prefill_tokens,
604
678
  self.chunked_prefill_size,
605
- num_mixed_running,
679
+ running_bs if self.is_mixed_chunk else 0,
606
680
  )
607
681
 
608
682
  has_inflight = self.being_chunked_req is not None
609
683
  if has_inflight:
610
684
  self.being_chunked_req.init_next_round_input()
611
- self.being_chunked_req = adder.add_inflight_req(
612
- self.being_chunked_req
613
- )
685
+ self.being_chunked_req = adder.add_inflight_req(self.being_chunked_req)
614
686
 
615
687
  if self.lora_paths:
616
688
  lora_set = (
@@ -661,44 +733,7 @@ class Scheduler:
 
         # Print stats
         if self.tp_rank == 0:
-            if isinstance(self.tree_cache, RadixCache):
-                self.tree_cache_metrics["total"] += (
-                    adder.log_input_tokens + adder.log_hit_tokens
-                ) / 10**9
-                self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
-                tree_cache_hit_rate = (
-                    self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
-                )
-            else:
-                tree_cache_hit_rate = 0.0
-
-            num_used = self.max_total_num_tokens - (
-                self.token_to_kv_pool.available_size()
-                + self.tree_cache.evictable_size()
-            )
-
-            if num_mixed_running > 0:
-                logger.info(
-                    f"Prefill batch"
-                    f"(mixed #running-req: {num_mixed_running}). "
-                    f"#new-seq: {len(can_run_list)}, "
-                    f"#new-token: {adder.log_input_tokens}, "
-                    f"#cached-token: {adder.log_hit_tokens}, "
-                    f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
-                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
-                )
-            else:
-                logger.info(
-                    f"Prefill batch. "
-                    f"#new-seq: {len(can_run_list)}, "
-                    f"#new-token: {adder.log_input_tokens}, "
-                    f"#cached-token: {adder.log_hit_tokens}, "
-                    f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
-                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                    f"#running-req: {running_bs}, "
-                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
-                )
+            self.log_prefill_stats(adder, can_run_list, running_bs, has_inflight)
 
         # Create a new batch
         new_batch = ScheduleBatch.init_new(
@@ -753,7 +788,7 @@ class Scheduler:
             )
 
         # Check for jump-forward
-        if not self.disable_regex_jump_forward:
+        if not self.disable_jump_forward:
             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
@@ -768,8 +803,8 @@ class Scheduler:
         self.forward_ct += 1
 
         if self.is_generation:
+            model_worker_batch = batch.get_model_worker_batch()
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
-                model_worker_batch = batch.get_model_worker_batch()
                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                     model_worker_batch
                 )
@@ -897,9 +932,7 @@ class Scheduler:
             if req.is_retracted:
                 continue
 
-            if self.server_args.enable_overlap_schedule and (
-                req.finished()
-            ):
+            if self.server_args.enable_overlap_schedule and (req.finished()):
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue
 
@@ -925,8 +958,11 @@ class Scheduler:
         self.token_to_kv_pool.free_group_end()
 
         self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
-        if self.tp_rank == 0 and self.forward_ct_decode % self.server_args.decode_log_interval == 0:
-            self.print_decode_stats()
+        if (
+            self.tp_rank == 0
+            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
+        ):
+            self.log_decode_stats()
 
     def add_logprob_return_values(
         self,
@@ -1104,6 +1140,30 @@ class Scheduler:
                 )
             )
 
+    def move_ready_grammar_requests(self):
+        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
+        num_ready_reqs = 0
+        for req in self.grammar_queue:
+            try:
+                req.grammar = req.grammar.result(timeout=0.05)
+                num_ready_reqs += 1
+            except futures._base.TimeoutError:
+                break
+
+        if self.tp_size > 1:
+            # Sync across TP ranks to make sure they have the same number of ready requests
+            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
+            torch.distributed.all_reduce(
+                tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
+            )
+            num_ready_reqs_max = tensor.item()
+            for i in range(num_ready_reqs, num_ready_reqs_max):
+                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
+            num_ready_reqs = num_ready_reqs_max
+
+        self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
+        self.grammar_queue = self.grammar_queue[num_ready_reqs:]
+
     def flush_cache(self):
         """Flush the memory pool and cache."""
         if len(self.waiting_queue) == 0 and (
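The new move_ready_grammar_requests above is a plain concurrent.futures polling loop: each queued request holds a Future for its still-compiling grammar, and result(timeout=0.05) either hands back the finished object or raises TimeoutError so the scheduler can keep running other batches. A standalone illustration of that polling pattern follows; the compile_grammar function and the specs are made up for the example and are not sglang code:

# Standalone illustration of the Future-polling pattern used above.
import time
from concurrent import futures


def compile_grammar(spec: str) -> str:
    time.sleep(0.2)  # pretend grammar compilation is slow
    return f"compiled<{spec}>"


executor = futures.ThreadPoolExecutor()
jobs = [executor.submit(compile_grammar, spec) for spec in ("regex:a+", "json:{}")]

ready = []
while jobs:
    still_pending = []
    for job in jobs:
        try:
            # Same idea as req.grammar.result(timeout=0.05) in the scheduler:
            # take the result if it is done, otherwise move on and retry later.
            ready.append(job.result(timeout=0.05))
        except futures.TimeoutError:
            still_pending.append(job)
    jobs = still_pending

print(ready)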
@@ -1111,9 +1171,8 @@ class Scheduler:
         ):
             self.tree_cache.reset()
             self.tree_cache_metrics = {"total": 0, "hit": 0}
-            if self.grammar_cache is not None:
-                self.grammar_cache.reset()
-                # TODO(dark): reset the bnf cache
+            if self.grammar_backend:
+                self.grammar_backend.reset()
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()