sglang 0.3.5__py3-none-any.whl → 0.3.5.post2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
Files changed (54)
  1. sglang/bench_offline_throughput.py +309 -0
  2. sglang/bench_serving.py +148 -24
  3. sglang/srt/configs/model_config.py +5 -2
  4. sglang/srt/constrained/__init__.py +2 -66
  5. sglang/srt/constrained/base_grammar_backend.py +73 -0
  6. sglang/srt/constrained/outlines_backend.py +165 -0
  7. sglang/srt/constrained/outlines_jump_forward.py +182 -0
  8. sglang/srt/constrained/xgrammar_backend.py +150 -0
  9. sglang/srt/layers/attention/triton_ops/decode_attention.py +7 -0
  10. sglang/srt/layers/attention/triton_ops/extend_attention.py +6 -0
  11. sglang/srt/layers/fused_moe/fused_moe.py +23 -7
  12. sglang/srt/layers/fused_moe/patch.py +4 -2
  13. sglang/srt/layers/quantization/base_config.py +4 -6
  14. sglang/srt/layers/vocab_parallel_embedding.py +216 -150
  15. sglang/srt/managers/detokenizer_manager.py +0 -14
  16. sglang/srt/managers/io_struct.py +5 -3
  17. sglang/srt/managers/schedule_batch.py +14 -20
  18. sglang/srt/managers/scheduler.py +159 -96
  19. sglang/srt/managers/tokenizer_manager.py +81 -17
  20. sglang/srt/metrics/collector.py +211 -0
  21. sglang/srt/metrics/func_timer.py +108 -0
  22. sglang/srt/mm_utils.py +1 -1
  23. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  24. sglang/srt/model_executor/forward_batch_info.py +7 -3
  25. sglang/srt/model_executor/model_runner.py +6 -2
  26. sglang/srt/models/gemma2_reward.py +69 -0
  27. sglang/srt/models/gpt2.py +31 -37
  28. sglang/srt/models/internlm2_reward.py +62 -0
  29. sglang/srt/models/llama.py +11 -6
  30. sglang/srt/models/llama_reward.py +5 -26
  31. sglang/srt/models/qwen2_vl.py +5 -7
  32. sglang/srt/openai_api/adapter.py +11 -4
  33. sglang/srt/openai_api/protocol.py +29 -26
  34. sglang/srt/sampling/sampling_batch_info.py +2 -3
  35. sglang/srt/sampling/sampling_params.py +2 -16
  36. sglang/srt/server.py +60 -17
  37. sglang/srt/server_args.py +66 -25
  38. sglang/srt/utils.py +120 -0
  39. sglang/test/simple_eval_common.py +1 -1
  40. sglang/test/simple_eval_humaneval.py +2 -2
  41. sglang/test/simple_eval_mgsm.py +2 -2
  42. sglang/test/test_utils.py +21 -7
  43. sglang/utils.py +1 -0
  44. sglang/version.py +1 -1
  45. {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/METADATA +12 -8
  46. {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/RECORD +49 -45
  47. {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/WHEEL +1 -1
  48. sglang/srt/constrained/base_tool_cache.py +0 -65
  49. sglang/srt/constrained/bnf_cache.py +0 -61
  50. sglang/srt/constrained/fsm_cache.py +0 -95
  51. sglang/srt/constrained/grammar.py +0 -190
  52. sglang/srt/constrained/jump_forward.py +0 -203
  53. {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/LICENSE +0 -0
  54. {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

@@ -37,7 +37,7 @@ import torch
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained.grammar import Grammar
+from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -107,12 +107,14 @@ class FINISH_LENGTH(BaseFinishReason):
 
 
 class FINISH_ABORT(BaseFinishReason):
-    def __init__(self):
+    def __init__(self, message="Unknown error"):
         super().__init__(is_error=True)
+        self.message = message
 
     def to_json(self):
         return {
             "type": "abort",
+            "message": self.message,
         }
 
 
@@ -133,6 +135,7 @@ class ImageInputs:
     aspect_ratio_mask: Optional[List[torch.Tensor]] = None
     # QWen2-VL related
     image_grid_thws: List[Tuple[int, int, int]] = None
+    mrope_position_delta: Optional[torch.Tensor] = None
 
     @staticmethod
     def from_dict(obj, vocab_size):
@@ -211,7 +214,7 @@ class Req:
         # this does not include the jump forward tokens.
         self.completion_tokens_wo_jump_forward = 0
 
-        # For vision inputs
+        # For multimodal inputs
         self.image_inputs: Optional[ImageInputs] = None
 
         # Prefix info
@@ -246,14 +249,11 @@ class Req:
         self.embedding = None
 
         # Constrained decoding
-        self.grammar: Optional[Grammar] = None
+        self.grammar: Optional[BaseGrammarObject] = None
 
         # The number of cached tokens, that were already cached in the KV cache
         self.cached_tokens = 0
 
-        # For Qwen2-VL
-        self.mrope_position_delta = []  # use mutable object
-
     # whether request reached finished condition
     def finished(self) -> bool:
         return self.finished_reason is not None
@@ -359,8 +359,6 @@ class Req:
             return
 
     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        assert self.grammar is not None and self.tokenizer is not None
-
         if self.origin_input_text is None:
             # Recovering text can only use unpadded ids
             self.origin_input_text = self.tokenizer.decode(
@@ -809,9 +807,10 @@ class ScheduleBatch:
 
         for i, req in enumerate(self.reqs):
             if req.grammar is not None:
-                jump_helper = req.grammar.try_jump(req.tokenizer)
-                if jump_helper.can_jump():
-                    suffix_ids = jump_helper.suffix_ids
+                jump_helper = req.grammar.try_jump_forward(req.tokenizer)
+                if jump_helper:
+                    suffix_ids, _ = jump_helper
+
                     # Current ids, for cache and revert
                     cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                     cur_output_ids = req.output_ids
@@ -827,6 +826,8 @@ class ScheduleBatch:
                         next_state,
                     ) = req.grammar.jump_forward_str_state(jump_helper)
 
+                    # Make the incrementally decoded text part of jump_forward_str
+                    # so that the UTF-8 will not corrupt
                     jump_forward_str = new_text + jump_forward_str
                     if not req.jump_forward_and_retokenize(
                         jump_forward_str, next_state
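The jump-forward API also changes shape in this release: the old helper object (try_jump plus can_jump()/suffix_ids) is replaced by try_jump_forward, which the hunk above treats as falsy when no jump is possible and unpacks as a 2-tuple whose first element is the forced suffix token ids. A minimal caller-side sketch with a hypothetical stand-in grammar (the real objects come from the new outlines/xgrammar backends):

from typing import List, Optional, Tuple

class ToyGrammar:
    # Hypothetical stand-in; it mirrors only the call shape used in the diff above.
    def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
        # None -> no deterministic continuation; otherwise (suffix token ids, helper state).
        return ([42, 43], "helper-state")

jump_helper = ToyGrammar().try_jump_forward(tokenizer=None)
if jump_helper:                  # truthiness check replaces can_jump()
    suffix_ids, _ = jump_helper  # first element is the suffix token ids
    print(suffix_ids)            # [42, 43]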
@@ -900,8 +901,7 @@ class ScheduleBatch:
         keep_indices = [
             i
             for i in range(len(self.reqs))
-            if not self.reqs[i].finished()
-            and self.reqs[i] is not being_chunked_req
+            if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
         ]
 
         if keep_indices is None or len(keep_indices) == 0:
@@ -984,8 +984,6 @@ class ScheduleBatch:
         global bid
         bid += 1
 
-        mrope_positions_delta = [req.mrope_position_delta for req in self.reqs]
-
         return ModelWorkerBatch(
             bid=bid,
             forward_mode=self.forward_mode,
@@ -1008,7 +1006,6 @@ class ScheduleBatch:
             encoder_out_cache_loc=self.encoder_out_cache_loc,
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
-            mrope_positions_delta=mrope_positions_delta,
         )
 
     def copy(self):
@@ -1075,9 +1072,6 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo
 
-    # For Qwen2-VL
-    mrope_positions_delta: List[List[int]]
-
     def copy(self):
         return dataclasses.replace(self, sampling_info=self.sampling_info.copy())
 
sglang/srt/managers/scheduler.py

@@ -21,6 +21,7 @@ import threading
 import time
 import warnings
 from collections import deque
+from concurrent import futures
 from types import SimpleNamespace
 from typing import List, Optional
 
@@ -29,7 +30,6 @@ import zmq
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained.grammar import GrammarCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -62,6 +62,7 @@ from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
+from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     broadcast_pyobj,
@@ -99,11 +100,12 @@ class Scheduler:
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.schedule_policy = server_args.schedule_policy
-        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
+        self.disable_jump_forward = server_args.disable_jump_forward
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = server_args.enable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
+        self.enable_metrics = server_args.enable_metrics
 
         # Init inter-process communication
         context = zmq.Context(2)
@@ -112,6 +114,9 @@ class Scheduler:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name
             )
+            self.send_to_tokenizer = get_zmq_socket(
+                context, zmq.PUSH, port_args.tokenizer_ipc_name
+            )
 
             if server_args.skip_tokenizer_init:
                 # Directly send to the tokenizer/api
@@ -125,6 +130,7 @@ class Scheduler:
                 )
         else:
             self.recv_from_tokenizer = None
+            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
 
         # Init tokenizer
@@ -222,7 +228,7 @@ class Scheduler:
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
-        self.last_stats_tic = time.time()
+        self.last_decode_stats_tic = time.time()
         self.stream_interval = server_args.stream_interval
 
         # Init chunked prefill
@@ -232,21 +238,33 @@ class Scheduler:
             self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
         )
 
-        # Init the FSM cache for constrained generation
-        self.grammar_cache = None
-
+        # Init the grammar backend for constrained generation
+        self.grammar_queue: List[Req] = []
         if not server_args.skip_tokenizer_init:
-            self.grammar_cache = GrammarCache(
-                server_args.tokenizer_path,
-                {
-                    "tokenizer_mode": server_args.tokenizer_mode,
-                    "trust_remote_code": server_args.trust_remote_code,
-                },
-                skip_tokenizer_init=server_args.skip_tokenizer_init,
-                whitespace_patterns=server_args.constrained_json_whitespace_pattern,
-                backend=server_args.grammar_backend,
-                allow_jump=not server_args.disable_regex_jump_forward,
-            )
+            if server_args.grammar_backend == "outlines":
+                from sglang.srt.constrained.outlines_backend import (
+                    OutlinesGrammarBackend,
+                )
+
+                self.grammar_backend = OutlinesGrammarBackend(
+                    self.tokenizer,
+                    whitespace_pattern=server_args.constrained_json_whitespace_pattern,
+                    allow_jump_forward=not server_args.disable_jump_forward,
+                )
+            elif server_args.grammar_backend == "xgrammar":
+                from sglang.srt.constrained.xgrammar_backend import (
+                    XGrammarGrammarBackend,
+                )
+
+                self.grammar_backend = XGrammarGrammarBackend(
+                    self.tokenizer, vocab_size=self.model_config.vocab_size
+                )
+            else:
+                raise ValueError(
+                    f"Invalid grammar backend: {server_args.grammar_backend}"
+                )
+        else:
+            self.grammar_backend = None
 
         # Init new token estimation
         assert (
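The single GrammarCache is replaced here by per-backend classes selected through server_args.grammar_backend: the outlines backend keeps the whitespace-pattern and jump-forward options, while the xgrammar backend takes the tokenizer and vocab size. A standalone sketch of that dispatch with placeholder classes, not the real backend implementations:

# Placeholder classes; the real ones are OutlinesGrammarBackend and
# XGrammarGrammarBackend under sglang/srt/constrained/ in this release.
class _Outlines:
    def __init__(self, tokenizer, whitespace_pattern=None, allow_jump_forward=True): ...

class _XGrammar:
    def __init__(self, tokenizer, vocab_size): ...

def make_grammar_backend(name, tokenizer, vocab_size,
                         whitespace_pattern=None, allow_jump_forward=True):
    # Same if/elif/else shape as the Scheduler init above.
    if name == "outlines":
        return _Outlines(tokenizer, whitespace_pattern=whitespace_pattern,
                         allow_jump_forward=allow_jump_forward)
    elif name == "xgrammar":
        return _XGrammar(tokenizer, vocab_size=vocab_size)
    raise ValueError(f"Invalid grammar backend: {name}")

backend = make_grammar_backend("xgrammar", tokenizer=None, vocab_size=32000)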
@@ -292,6 +310,16 @@ class Scheduler:
                 with_stack=True,
             )
 
+        # Init metrics stats
+        self.stats = SchedulerStats()
+        if self.enable_metrics:
+            self.metrics_collector = SchedulerMetricsCollector(
+                labels={
+                    "model_name": self.server_args.served_model_name,
+                    # TODO: Add lora name/path in the future,
+                },
+            )
+
     def watchdog_thread(self):
         self.watchdog_last_forward_ct = 0
         self.watchdog_last_time = time.time()
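SchedulerStats and SchedulerMetricsCollector come from the new sglang/srt/metrics/collector.py added in this release (+211 lines): the collector is constructed once with a labels dict (currently just model_name) and receives a SchedulerStats snapshot on every prefill/decode log tick. A rough sketch of what such a collector could look like on top of prometheus_client; the gauge names and internals are assumptions, not the actual module:

from dataclasses import dataclass
from prometheus_client import Gauge

@dataclass
class SchedulerStats:
    # Field names taken from the stats updates later in this diff.
    num_running_reqs: int = 0
    num_used_tokens: int = 0
    token_usage: float = 0.0
    gen_throughput: float = 0.0
    num_queue_reqs: int = 0
    cache_hit_rate: float = 0.0

class SchedulerMetricsCollector:
    def __init__(self, labels: dict):
        self.labels = labels
        # Gauge names below are illustrative only.
        self.num_running_reqs = Gauge(
            "sglang_num_running_reqs", "Number of running requests", labelnames=list(labels)
        )
        self.token_usage = Gauge(
            "sglang_token_usage", "Fraction of KV-cache tokens in use", labelnames=list(labels)
        )

    def log_stats(self, stats: SchedulerStats) -> None:
        self.num_running_reqs.labels(**self.labels).set(stats.num_running_reqs)
        self.token_usage.labels(**self.labels).set(stats.token_usage)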
@@ -397,7 +425,7 @@ class Scheduler:
                 self.abort_request(recv_req)
             elif isinstance(recv_req, UpdateWeightReqInput):
                 success, message = self.update_weights(recv_req)
-                self.send_to_detokenizer.send_pyobj(
+                self.send_to_tokenizer.send_pyobj(
                     UpdateWeightReqOutput(success, message)
                 )
             elif isinstance(recv_req, ProfileReq):
@@ -406,7 +434,7 @@ class Scheduler:
                 else:
                     self.stop_profile()
             elif isinstance(recv_req, GetMemPoolSizeReq):
-                self.send_to_detokenizer.send_pyobj(
+                self.send_to_tokenizer.send_pyobj(
                     GetMemPoolSizeReqOutput(self.max_total_num_tokens)
                 )
             else:
@@ -443,22 +471,6 @@ class Scheduler:
             # By default, only return the logprobs for output tokens
             req.logprob_start_len = len(recv_req.input_ids) - 1
 
-        # Init regex FSM or BNF
-        if (
-            req.sampling_params.json_schema is not None
-            or req.sampling_params.regex is not None
-        ):
-            assert self.grammar_cache is not None
-            if req.sampling_params.json_schema is not None:
-                req.grammar = self.grammar_cache.query(
-                    ("json", req.sampling_params.json_schema),
-                    self.model_config.vocab_size,
-                )
-            elif req.sampling_params.regex is not None:
-                req.grammar = self.grammar_cache.query(
-                    ("regex", req.sampling_params.regex), self.model_config.vocab_size
-                )
-
         # Truncate prompts that are too long
         if len(req.origin_input_ids) > self.max_req_input_len:
             logger.warning(
@@ -476,7 +488,27 @@ class Scheduler:
             self.max_req_len - len(req.origin_input_ids) - 1,
         )
 
-        self.waiting_queue.append(req)
+        # Init grammar cache for this request
+        add_to_grammar_queue = False
+        if (
+            req.sampling_params.json_schema is not None
+            or req.sampling_params.regex is not None
+        ):
+            assert self.grammar_backend is not None
+            if req.sampling_params.json_schema is not None:
+                key = ("json", req.sampling_params.json_schema)
+            elif req.sampling_params.regex is not None:
+                key = ("regex", req.sampling_params.regex)
+
+            req.grammar = self.grammar_backend.get_cached_value(key)
+            if not req.grammar:
+                req.grammar = self.grammar_backend.get_future_value(key)
+                add_to_grammar_queue = True
+
+        if add_to_grammar_queue:
+            self.grammar_queue.append(req)
+        else:
+            self.waiting_queue.append(req)
 
     def handle_embedding_request(
         self,
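handle_generate_request now resolves constrained-decoding grammars through the backend cache: get_cached_value returns a compiled grammar on a hit, while a miss hands back a future and the request waits in grammar_queue until compilation completes (see move_ready_grammar_requests later in this diff). A self-contained sketch of that cache-or-future pattern built on concurrent.futures; the method names mirror the diff, but the internals are an assumption, not the actual base_grammar_backend.py:

from concurrent.futures import Future, ThreadPoolExecutor

class ToyGrammarBackend:
    def __init__(self):
        self.executor = ThreadPoolExecutor(max_workers=2)
        self.cache = {}  # key -> compiled grammar

    def _compile(self, key):
        kind, spec = key  # e.g. ("json", schema) or ("regex", pattern)
        compiled = f"compiled-{kind}-grammar"  # placeholder for real compilation
        self.cache[key] = compiled
        return compiled

    def get_cached_value(self, key):
        return self.cache.get(key)

    def get_future_value(self, key) -> Future:
        return self.executor.submit(self._compile, key)

backend = ToyGrammarBackend()
key = ("regex", r"[0-9]+")
grammar = backend.get_cached_value(key)
if not grammar:
    # In the scheduler the request parks in grammar_queue and the future is
    # polled with result(timeout=0.05); here we just block until it is ready.
    grammar = backend.get_future_value(key).result()
print(grammar)  # compiled-regex-grammar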
@@ -500,23 +532,68 @@ class Scheduler:
 
         self.waiting_queue.append(req)
 
-    def print_decode_stats(self):
+    def log_prefill_stats(self, adder, can_run_list, running_bs, has_inflight):
+        if isinstance(self.tree_cache, RadixCache):
+            self.tree_cache_metrics["total"] += (
+                adder.log_input_tokens + adder.log_hit_tokens
+            ) / 10**9
+            self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
+            tree_cache_hit_rate = (
+                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
+            )
+        else:
+            tree_cache_hit_rate = 0.0
+
+        num_used = self.max_total_num_tokens - (
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        )
+
+        logger.info(
+            f"Prefill batch. "
+            f"#new-seq: {len(can_run_list)}, "
+            f"#new-token: {adder.log_input_tokens}, "
+            f"#cached-token: {adder.log_hit_tokens}, "
+            f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+            f"#running-req: {running_bs}, "
+            f"#queue-req: {len(self.waiting_queue) + has_inflight}"
+        )
+
+        if self.enable_metrics:
+            self.stats.num_running_reqs = running_bs
+            self.stats.num_used_tokens = num_used
+            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
+            self.stats.num_queue_reqs = len(self.waiting_queue) + has_inflight
+            self.stats.cache_hit_rate = tree_cache_hit_rate
+            self.metrics_collector.log_stats(self.stats)
+
+    def log_decode_stats(self):
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
-        throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
+        gen_throughput = self.num_generated_tokens / (
+            time.time() - self.last_decode_stats_tic
+        )
         self.num_generated_tokens = 0
-        self.last_stats_tic = time.time()
+        self.last_decode_stats_tic = time.time()
         num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
         logger.info(
             f"Decode batch. "
             f"#running-req: {num_running_reqs}, "
             f"#token: {num_used}, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-            f"gen throughput (token/s): {throughput:.2f}, "
+            f"gen throughput (token/s): {gen_throughput:.2f}, "
             f"#queue-req: {len(self.waiting_queue)}"
         )
 
+        if self.enable_metrics:
+            self.stats.num_running_reqs = num_running_reqs
+            self.stats.num_used_tokens = num_used
+            self.stats.token_usage = num_used / self.max_total_num_tokens
+            self.stats.gen_throughput = gen_throughput
+            self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.metrics_collector.log_stats(self.stats)
+
     def check_memory(self):
         available_size = (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
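The decode throughput reported above is simply the tokens generated since the last decode-stats tick divided by the elapsed wall time, with both the counter and the tick reset after each report; for example, 512 tokens generated over 0.25 s logs roughly 2048 token/s. A toy illustration of that bookkeeping:

import time

last_decode_stats_tic = time.time() - 0.25  # pretend the last report was 250 ms ago
num_generated_tokens = 512
gen_throughput = num_generated_tokens / (time.time() - last_decode_stats_tic)
print(f"gen throughput (token/s): {gen_throughput:.2f}")  # ~2048
num_generated_tokens = 0                    # counters reset after each report
last_decode_stats_tic = time.time()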
@@ -546,9 +623,7 @@ class Scheduler:
             and not self.last_batch.is_empty()
         ):
             if self.being_chunked_req:
-                self.last_batch.filter_batch(
-                    being_chunked_req=self.being_chunked_req
-                )
+                self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
                 self.tree_cache.cache_unfinished_req(self.being_chunked_req)
                 # Inflight request keeps its rid but will get a new req_pool_idx.
                 self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
579
654
  return self.running_batch
580
655
 
581
656
  def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
657
+ # Check if the grammar is ready in the grammar queue
658
+ if self.grammar_queue:
659
+ self.move_ready_grammar_requests()
660
+
582
661
  # Handle the cases where prefill is not allowed
583
662
  if (
584
663
  self.batch_is_full or len(self.waiting_queue) == 0
@@ -594,7 +673,6 @@ class Scheduler:
594
673
  prefix_computed = self.policy.calc_priority(self.waiting_queue)
595
674
 
596
675
  # Prefill policy
597
- num_mixed_running = running_bs if self.is_mixed_chunk else 0
598
676
  adder = PrefillAdder(
599
677
  self.tree_cache,
600
678
  self.running_batch,
@@ -602,15 +680,13 @@ class Scheduler:
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
             self.max_prefill_tokens,
             self.chunked_prefill_size,
-            num_mixed_running,
+            running_bs if self.is_mixed_chunk else 0,
         )
 
         has_inflight = self.being_chunked_req is not None
         if has_inflight:
             self.being_chunked_req.init_next_round_input()
-            self.being_chunked_req = adder.add_inflight_req(
-                self.being_chunked_req
-            )
+            self.being_chunked_req = adder.add_inflight_req(self.being_chunked_req)
 
         if self.lora_paths:
             lora_set = (
@@ -661,44 +737,7 @@ class Scheduler:
 
         # Print stats
         if self.tp_rank == 0:
-            if isinstance(self.tree_cache, RadixCache):
-                self.tree_cache_metrics["total"] += (
-                    adder.log_input_tokens + adder.log_hit_tokens
-                ) / 10**9
-                self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
-                tree_cache_hit_rate = (
-                    self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
-                )
-            else:
-                tree_cache_hit_rate = 0.0
-
-            num_used = self.max_total_num_tokens - (
-                self.token_to_kv_pool.available_size()
-                + self.tree_cache.evictable_size()
-            )
-
-            if num_mixed_running > 0:
-                logger.info(
-                    f"Prefill batch"
-                    f"(mixed #running-req: {num_mixed_running}). "
-                    f"#new-seq: {len(can_run_list)}, "
-                    f"#new-token: {adder.log_input_tokens}, "
-                    f"#cached-token: {adder.log_hit_tokens}, "
-                    f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
-                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
-                )
-            else:
-                logger.info(
-                    f"Prefill batch. "
-                    f"#new-seq: {len(can_run_list)}, "
-                    f"#new-token: {adder.log_input_tokens}, "
-                    f"#cached-token: {adder.log_hit_tokens}, "
-                    f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
-                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                    f"#running-req: {running_bs}, "
-                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
-                )
+            self.log_prefill_stats(adder, can_run_list, running_bs, has_inflight)
 
         # Create a new batch
         new_batch = ScheduleBatch.init_new(
@@ -753,7 +792,7 @@ class Scheduler:
             )
 
         # Check for jump-forward
-        if not self.disable_regex_jump_forward:
+        if not self.disable_jump_forward:
             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
@@ -768,8 +807,8 @@ class Scheduler:
         self.forward_ct += 1
 
         if self.is_generation:
+            model_worker_batch = batch.get_model_worker_batch()
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
-                model_worker_batch = batch.get_model_worker_batch()
                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                     model_worker_batch
                 )
@@ -897,9 +936,7 @@ class Scheduler:
             if req.is_retracted:
                 continue
 
-            if self.server_args.enable_overlap_schedule and (
-                req.finished()
-            ):
+            if self.server_args.enable_overlap_schedule and (req.finished()):
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue
 
@@ -925,8 +962,11 @@ class Scheduler:
             self.token_to_kv_pool.free_group_end()
 
         self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
-        if self.tp_rank == 0 and self.forward_ct_decode % self.server_args.decode_log_interval == 0:
-            self.print_decode_stats()
+        if (
+            self.tp_rank == 0
+            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
+        ):
+            self.log_decode_stats()
 
     def add_logprob_return_values(
         self,
@@ -1104,6 +1144,30 @@ class Scheduler:
                 )
             )
 
+    def move_ready_grammar_requests(self):
+        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
+        num_ready_reqs = 0
+        for req in self.grammar_queue:
+            try:
+                req.grammar = req.grammar.result(timeout=0.05)
+                num_ready_reqs += 1
+            except futures._base.TimeoutError:
+                break
+
+        if self.tp_size > 1:
+            # Sync across TP ranks to make sure they have the same number of ready requests
+            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
+            torch.distributed.all_reduce(
+                tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
+            )
+            num_ready_reqs_max = tensor.item()
+            for i in range(num_ready_reqs, num_ready_reqs_max):
+                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
+            num_ready_reqs = num_ready_reqs_max
+
+        self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
+        self.grammar_queue = self.grammar_queue[num_ready_reqs:]
+
     def flush_cache(self):
         """Flush the memory pool and cache."""
         if len(self.waiting_queue) == 0 and (
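The MAX all-reduce above is what keeps tensor-parallel ranks consistent: each rank counts how many grammar futures resolved within its 50 ms poll, the ranks agree on the maximum, and any rank that saw fewer blocks on result() for the stragglers, so every rank moves the identical prefix of grammar_queue into waiting_queue and schedules the same batches. A standalone sketch of just that synchronization step, assuming torch.distributed is already initialized with a CPU (gloo) group:

import torch
import torch.distributed as dist

def sync_num_ready(num_ready_local: int, cpu_group) -> int:
    t = torch.tensor(num_ready_local, dtype=torch.int32)
    # MAX: if any rank saw N futures finish, all ranks commit to N and block on
    # .result() for the remainder, keeping their batches in lockstep.
    dist.all_reduce(t, op=dist.ReduceOp.MAX, group=cpu_group)
    return int(t.item())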
@@ -1111,9 +1175,8 @@ class Scheduler:
         ):
             self.tree_cache.reset()
             self.tree_cache_metrics = {"total": 0, "hit": 0}
-            if self.grammar_cache is not None:
-                self.grammar_cache.reset()
-                # TODO(dark): reset the bnf cache
+            if self.grammar_backend:
+                self.grammar_backend.reset()
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()