sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff shows the changes between two package versions that were publicly released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
Files changed (107)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +9 -7
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +1 -0
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +48 -43
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +7 -2
  20. sglang/srt/disaggregation/fake/conn.py +1 -1
  21. sglang/srt/disaggregation/mooncake/conn.py +227 -120
  22. sglang/srt/disaggregation/nixl/conn.py +1 -0
  23. sglang/srt/disaggregation/prefill.py +7 -4
  24. sglang/srt/disaggregation/utils.py +7 -1
  25. sglang/srt/entrypoints/engine.py +17 -2
  26. sglang/srt/entrypoints/http_server.py +17 -2
  27. sglang/srt/function_call_parser.py +2 -2
  28. sglang/srt/layers/attention/flashattention_backend.py +1 -1
  29. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  30. sglang/srt/layers/attention/utils.py +4 -2
  31. sglang/srt/layers/dp_attention.py +71 -21
  32. sglang/srt/layers/layernorm.py +1 -1
  33. sglang/srt/layers/logits_processor.py +46 -11
  34. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  35. sglang/srt/layers/moe/ep_moe/layer.py +1 -1
  36. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  37. sglang/srt/layers/moe/topk.py +1 -1
  38. sglang/srt/layers/quantization/__init__.py +1 -1
  39. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  40. sglang/srt/layers/quantization/deep_gemm.py +72 -71
  41. sglang/srt/layers/quantization/fp8.py +2 -2
  42. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  43. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  44. sglang/srt/layers/sampler.py +0 -4
  45. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  46. sglang/srt/lora/lora_manager.py +1 -1
  47. sglang/srt/lora/mem_pool.py +4 -4
  48. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  49. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  50. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  51. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  52. sglang/srt/lora/utils.py +1 -1
  53. sglang/srt/managers/data_parallel_controller.py +3 -3
  54. sglang/srt/managers/detokenizer_manager.py +21 -8
  55. sglang/srt/managers/io_struct.py +3 -1
  56. sglang/srt/managers/mm_utils.py +1 -1
  57. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  58. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  59. sglang/srt/managers/schedule_batch.py +76 -24
  60. sglang/srt/managers/schedule_policy.py +0 -3
  61. sglang/srt/managers/scheduler.py +113 -88
  62. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  63. sglang/srt/managers/tokenizer_manager.py +133 -34
  64. sglang/srt/managers/tp_worker.py +12 -9
  65. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  66. sglang/srt/mem_cache/memory_pool.py +2 -0
  67. sglang/srt/metrics/collector.py +312 -37
  68. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  69. sglang/srt/model_executor/forward_batch_info.py +1 -1
  70. sglang/srt/model_executor/model_runner.py +19 -14
  71. sglang/srt/models/deepseek_janus_pro.py +2 -2
  72. sglang/srt/models/deepseek_v2.py +23 -20
  73. sglang/srt/models/llama.py +2 -0
  74. sglang/srt/models/llama4.py +5 -6
  75. sglang/srt/models/llava.py +248 -5
  76. sglang/srt/models/mixtral.py +98 -34
  77. sglang/srt/models/pixtral.py +467 -0
  78. sglang/srt/models/roberta.py +1 -1
  79. sglang/srt/models/torch_native_llama.py +1 -1
  80. sglang/srt/openai_api/adapter.py +30 -4
  81. sglang/srt/openai_api/protocol.py +0 -8
  82. sglang/srt/reasoning_parser.py +3 -3
  83. sglang/srt/sampling/custom_logit_processor.py +18 -3
  84. sglang/srt/sampling/sampling_batch_info.py +4 -56
  85. sglang/srt/sampling/sampling_params.py +2 -2
  86. sglang/srt/server_args.py +34 -4
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  88. sglang/srt/speculative/eagle_utils.py +7 -7
  89. sglang/srt/speculative/eagle_worker.py +22 -19
  90. sglang/srt/utils.py +6 -5
  91. sglang/test/few_shot_gsm8k.py +2 -2
  92. sglang/test/few_shot_gsm8k_engine.py +2 -2
  93. sglang/test/run_eval.py +2 -2
  94. sglang/test/runners.py +8 -1
  95. sglang/test/send_one.py +13 -3
  96. sglang/test/simple_eval_common.py +1 -1
  97. sglang/test/simple_eval_humaneval.py +1 -1
  98. sglang/test/test_programs.py +5 -5
  99. sglang/test/test_utils.py +89 -14
  100. sglang/utils.py +1 -1
  101. sglang/version.py +1 -1
  102. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +6 -5
  103. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +107 -104
  104. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  105. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +0 -0
  106. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  107. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
@@ -20,7 +20,6 @@ import signal
  import sys
  import threading
  import time
- import warnings
  from collections import defaultdict, deque
  from concurrent import futures
  from dataclasses import dataclass
@@ -121,11 +120,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
  from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
- from sglang.srt.model_executor.forward_batch_info import (
- ForwardBatch,
- ForwardMode,
- PPProxyTensors,
- )
+ from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
  from sglang.srt.reasoning_parser import ReasoningParser
  from sglang.srt.server_args import PortArgs, ServerArgs
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -135,6 +130,7 @@ from sglang.srt.utils import (
  broadcast_pyobj,
  configure_logger,
  crash_on_warnings,
+ disable_request_logging,
  get_bool_env_var,
  get_zmq_socket,
  kill_itself_when_parent_died,
@@ -153,6 +149,7 @@ logger = logging.getLogger(__name__)
  # Test retract decode for debugging purposes
  TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
  RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
+ GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))


  @dataclass
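The new GRAMMAR_TIMEOUT constant caps how long a request may wait in the grammar queue before it is aborted (see the grammar-queue hunks further below). A minimal sketch of overriding it, assuming the value is read once at import time and interpreted as seconds:

```python
# Illustrative only: raise the grammar-compilation timeout before sglang's
# scheduler module is imported. The default of 300 comes from the hunk above.
import os

os.environ["SGLANG_GRAMMAR_TIMEOUT"] = "600"
```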
@@ -163,6 +160,7 @@ class GenerationBatchResult:
  extend_input_len_per_req: List[int]
  extend_logprob_start_len_per_req: List[int]
  bid: int
+ can_run_cuda_graph: bool


  @dataclass
@@ -209,7 +207,8 @@ class Scheduler(
  self.page_size = server_args.page_size

  # Distributed rank info
- self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
+ self.dp_size = server_args.dp_size
+ self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
  compute_dp_attention_world_info(
  server_args.enable_dp_attention,
  self.tp_rank,
@@ -326,13 +325,14 @@ class Scheduler(
  set_random_seed(self.random_seed)

  # Print debug info
- logger.info(
- f"max_total_num_tokens={self.max_total_num_tokens}, "
- f"chunked_prefill_size={server_args.chunked_prefill_size}, "
- f"max_prefill_tokens={self.max_prefill_tokens}, "
- f"max_running_requests={self.max_running_requests}, "
- f"context_len={self.model_config.context_len}"
- )
+ if tp_rank == 0:
+ logger.info(
+ f"max_total_num_tokens={self.max_total_num_tokens}, "
+ f"chunked_prefill_size={server_args.chunked_prefill_size}, "
+ f"max_prefill_tokens={self.max_prefill_tokens}, "
+ f"max_running_requests={self.max_running_requests}, "
+ f"context_len={self.model_config.context_len}"
+ )

  # Init memory pool and cache
  self.init_memory_pool_and_cache()
@@ -531,10 +531,6 @@ class Scheduler(
  )

  def init_metrics(self):
- # The largest prefill length of a single request
- self._largest_prefill_len: int = 0
- # The largest context length (prefill + generation) of a single request
- self._largest_prefill_decode_len: int = 0
  self.last_gen_throughput: float = 0.0
  self.last_input_throughput: float = 0.0
  self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
@@ -720,7 +716,7 @@ class Scheduler(
  server_is_idle = False
  result = self.run_batch(self.cur_batch)

- # send the outputs to the next step
+ # (last rank) send the outputs to the next step
  if self.pp_group.is_last_rank:
  if self.cur_batch:
  next_token_ids, bids[mb_id] = (
@@ -755,24 +751,25 @@ class Scheduler(
  extend_input_len_per_req=None,
  extend_logprob_start_len_per_req=None,
  bid=bids[next_mb_id],
+ can_run_cuda_graph=result.can_run_cuda_graph,
  )
  self.process_batch_result(mbs[next_mb_id], output_result)
  last_mbs[next_mb_id] = mbs[next_mb_id]

- # carry the outputs to the next stage
+ # (not last rank)
  if not self.pp_group.is_last_rank:
  if self.cur_batch:
  bids[mb_id] = result.bid
+ # carry the outputs to the next stage
+ # send the outputs from the last round to let the next stage worker run post processing
  if pp_outputs:
- # send the outputs from the last round to let the next stage worker run post processing
  self.pp_group.send_tensor_dict(
  pp_outputs.tensors,
  all_gather_group=self.attn_tp_group,
  )

- if not self.pp_group.is_last_rank:
  # send out reqs to the next stage
- dp_offset = self.dp_rank * self.attn_tp_size
+ dp_offset = self.attn_dp_rank * self.attn_tp_size
  if self.attn_tp_rank == 0:
  point_to_point_pyobj(
  recv_reqs,
@@ -819,7 +816,7 @@ class Scheduler(
  recv_reqs = None
  else:
  if self.attn_tp_rank == 0:
- dp_offset = self.dp_rank * self.attn_tp_size
+ dp_offset = self.attn_dp_rank * self.attn_tp_size
  recv_reqs = point_to_point_pyobj(
  [],
  self.pp_rank * self.tp_size + dp_offset,
@@ -907,19 +904,6 @@ class Scheduler(
  fake_input_ids = [1] * seq_length
  recv_req.input_ids = fake_input_ids

- # Handle custom logit processor passed to the request
- custom_logit_processor = recv_req.custom_logit_processor
- if (
- not self.server_args.enable_custom_logit_processor
- and custom_logit_processor is not None
- ):
- logger.warning(
- "The SGLang server is not configured to enable custom logit processor."
- "The custom logit processor passed in will be ignored."
- "Please set --enable-custom-logits-processor to enable this feature."
- )
- custom_logit_processor = None
-
  if recv_req.bootstrap_port is None:
  # Use default bootstrap port
  recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port
@@ -935,7 +919,7 @@ class Scheduler(
  stream=recv_req.stream,
  lora_path=recv_req.lora_path,
  input_embeds=recv_req.input_embeds,
- custom_logit_processor=custom_logit_processor,
+ custom_logit_processor=recv_req.custom_logit_processor,
  return_hidden_states=recv_req.return_hidden_states,
  eos_token_ids=self.model_config.hf_eos_token_id,
  bootstrap_host=recv_req.bootstrap_host,
@@ -1041,9 +1025,11 @@ class Scheduler(
  elif req.sampling_params.structural_tag:
  key = ("structural_tag", req.sampling_params.structural_tag)

- req.grammar = self.grammar_backend.get_cached_value(key)
- if not req.grammar:
- req.grammar = self.grammar_backend.get_future_value(key)
+ value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
+ req.grammar = value
+
+ if not cache_hit:
+ req.grammar_key = key
  add_to_grammar_queue = True

  if add_to_grammar_queue:
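With this change the scheduler asks the grammar backend for a cached grammar and a compilation future in a single call, and remembers the key so the compiled result can be cached later. A minimal sketch of the cache-or-future pattern this implies; it is not the actual base_grammar_backend implementation, and compile_fn and the executor size are illustrative:

```python
from concurrent.futures import ThreadPoolExecutor


class GrammarCacheSketch:
    """Rough stand-in for the get_cached_or_future_value / set_cache calls above."""

    def __init__(self, compile_fn, max_workers: int = 2):
        self.cache = {}  # key -> compiled grammar
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.compile_fn = compile_fn  # assumed: key -> grammar object

    def get_cached_or_future_value(self, key):
        grammar = self.cache.get(key)
        if grammar is not None:
            return grammar, True  # cache hit: ready to use now
        # Cache miss: return a Future; the caller parks the request in the
        # grammar queue and resolves the Future there.
        return self.executor.submit(self.compile_fn, key), False

    def set_cache(self, key, grammar):
        # The scheduler stores a copy of the compiled grammar once the Future resolves.
        self.cache[key] = grammar
```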
@@ -1133,9 +1119,6 @@ class Scheduler(
  self.token_to_kv_pool_allocator.available_size()
  + self.tree_cache.evictable_size()
  )
- self._largest_prefill_len = max(
- self._largest_prefill_len, adder.log_input_tokens
- )

  num_new_seq = len(can_run_list)
  f = (
@@ -1173,7 +1156,9 @@ class Scheduler(

  self.metrics_collector.log_stats(self.stats)

- def log_decode_stats(self, running_batch=None):
+ def log_decode_stats(
+ self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
+ ):
  batch = running_batch or self.running_batch

  gap_latency = time.time() - self.last_decode_stats_tic
@@ -1213,6 +1198,7 @@ class Scheduler(
  msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "

  msg += (
+ f"cuda graph: {can_run_cuda_graph}, "
  f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
  f"#queue-req: {len(self.waiting_queue)}"
  )
@@ -1225,6 +1211,7 @@ class Scheduler(
  self.stats.cache_hit_rate = 0.0
  self.stats.gen_throughput = self.last_gen_throughput
  self.stats.num_queue_reqs = len(self.waiting_queue)
+ self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
  self.stats.spec_accept_length = spec_accept_length
  self.metrics_collector.log_stats(self.stats)

@@ -1246,9 +1233,7 @@ class Scheduler(
  f"{self.token_to_kv_pool_allocator.available_size()=}\n"
  f"{self.tree_cache.evictable_size()=}\n"
  )
- warnings.warn(msg)
- if crash_on_warnings():
- raise ValueError(msg)
+ raise ValueError(msg)

  if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
  msg = (
@@ -1256,9 +1241,7 @@ class Scheduler(
  f"available_size={len(self.req_to_token_pool.free_slots)}, "
  f"total_size={self.req_to_token_pool.size}\n"
  )
- warnings.warn(msg)
- if crash_on_warnings():
- raise ValueError(msg)
+ raise ValueError(msg)

  if (
  self.enable_metrics
@@ -1276,6 +1259,7 @@ class Scheduler(
  self.stats.token_usage = num_used / self.max_total_num_tokens
  self.stats.gen_throughput = 0
  self.stats.num_queue_reqs = len(self.waiting_queue)
+ self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
  self.metrics_collector.log_stats(self.stats)

  def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
@@ -1346,7 +1330,7 @@ class Scheduler(
  return None

  running_bs = len(self.running_batch.reqs)
- # Igore the check if self.chunked_req is not None.
+ # Ignore the check if self.chunked_req is not None.
  # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
  # as the space for the chunked request has just been released.
  # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
@@ -1540,11 +1524,11 @@ class Scheduler(
  if self.spec_algorithm.is_none():
  model_worker_batch = batch.get_model_worker_batch()
  if self.pp_group.is_last_rank:
- logits_output, next_token_ids = (
+ logits_output, next_token_ids, can_run_cuda_graph = (
  self.tp_worker.forward_batch_generation(model_worker_batch)
  )
  else:
- pp_hidden_states_proxy_tensors, _ = (
+ pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
  self.tp_worker.forward_batch_generation(model_worker_batch)
  )
  bid = model_worker_batch.bid
@@ -1554,6 +1538,7 @@ class Scheduler(
  next_token_ids,
  bid,
  num_accepted_tokens,
+ can_run_cuda_graph,
  ) = self.draft_worker.forward_batch_speculative_generation(batch)
  self.spec_num_total_accepted_tokens += (
  num_accepted_tokens + batch.batch_size()
@@ -1587,6 +1572,7 @@ class Scheduler(
  extend_input_len_per_req=extend_input_len_per_req,
  extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
  bid=bid,
+ can_run_cuda_graph=can_run_cuda_graph,
  )
  else: # embedding or reward model
  model_worker_batch = batch.get_model_worker_batch()
@@ -1609,14 +1595,9 @@ class Scheduler(
  elif batch.forward_mode.is_idle():
  if self.enable_overlap:
  self.tp_worker.resolve_last_batch_result(launch_done)
- if batch.next_batch_sampling_info:
- batch.next_batch_sampling_info.update_regex_vocab_mask()
- self.current_stream.synchronize()
- batch.next_batch_sampling_info.sampling_info_done.set()
+ self.set_next_batch_sampling_info_done(batch)
  elif batch.forward_mode.is_dummy_first():
- batch.next_batch_sampling_info.update_regex_vocab_mask()
- self.current_stream.synchronize()
- batch.next_batch_sampling_info.sampling_info_done.set()
+ self.set_next_batch_sampling_info_done(batch)

  if self.return_health_check_ct:
  # Return some signal for the health check.
@@ -1630,6 +1611,7 @@ class Scheduler(
  local_batch,
  dp_size=self.server_args.dp_size,
  attn_tp_size=self.attn_tp_size,
+ moe_dense_tp_size=self.server_args.moe_dense_tp_size,
  tp_cpu_group=self.tp_cpu_group,
  get_idle_batch=self.get_idle_batch,
  disable_cuda_graph=self.server_args.disable_cuda_graph,
@@ -1642,6 +1624,7 @@ class Scheduler(
  local_batch: ScheduleBatch,
  dp_size,
  attn_tp_size: int,
+ moe_dense_tp_size: Optional[int],
  tp_cpu_group,
  get_idle_batch,
  disable_cuda_graph: bool,
@@ -1651,15 +1634,15 @@ class Scheduler(
  # Check if other DP workers have running batches
  if local_batch is None:
  num_tokens = 0
- global_num_tokens_for_logprob = 0
+ num_tokens_for_logprob = 0
  elif local_batch.forward_mode.is_decode():
  num_tokens = local_batch.batch_size()
  if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
  num_tokens = num_tokens * speculative_num_draft_tokens
- global_num_tokens_for_logprob = num_tokens
+ num_tokens_for_logprob = num_tokens
  else:
  num_tokens = local_batch.extend_num_tokens
- global_num_tokens_for_logprob = sum(
+ num_tokens_for_logprob = sum(
  [
  # We should have at least 1 token for sample in every case.
  max(extend_len - logprob_start_len, 1)
@@ -1686,7 +1669,7 @@ class Scheduler(
  [
  num_tokens,
  can_cuda_graph,
- global_num_tokens_for_logprob,
+ num_tokens_for_logprob,
  is_extend_in_batch,
  ],
  dtype=torch.int64,
@@ -1709,8 +1692,15 @@ class Scheduler(
  local_batch = get_idle_batch()

  if local_batch is not None:
- local_batch.global_num_tokens = global_num_tokens
- local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
+ # TODO: handle the case when moe_dense_tp_size != 1
+ if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
+ local_batch.global_num_tokens = [num_tokens]
+ local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
+ else:
+ local_batch.global_num_tokens = global_num_tokens
+ local_batch.global_num_tokens_for_logprob = (
+ global_num_tokens_for_logprob
+ )

  # Check forward mode for cuda graph
  if not disable_cuda_graph:
@@ -1736,11 +1726,17 @@ class Scheduler(
  """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""

  num_ready_reqs = 0
+ num_abort_reqs = 0
  for req in self.grammar_queue:
  try:
- req.grammar = req.grammar.result(timeout=0.05)
+ req.grammar = req.grammar.result(timeout=0.03)
+ if req.grammar:
+ self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
  num_ready_reqs += 1
  except futures._base.TimeoutError:
+ req.grammar_wait_ct += 1
+ if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
+ num_abort_reqs = 1
  break

  if self.server_args.enable_dp_attention:
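Each pass over the grammar queue waits at most 0.03 s per unfinished future and bumps grammar_wait_ct on a timeout, so the abort threshold above turns the per-poll wait into a wall-clock budget of roughly GRAMMAR_TIMEOUT seconds (assuming one poll per scheduler iteration). A quick sanity check of that arithmetic:

```python
# With the default SGLANG_GRAMMAR_TIMEOUT of 300 and a 0.03 s poll timeout,
# a request is dropped after about 300 / 0.03 = 10,000 timed-out polls,
# i.e. roughly 300 seconds of cumulative waiting inside result(timeout=0.03).
GRAMMAR_TIMEOUT = 300.0
POLL_TIMEOUT = 0.03
assert GRAMMAR_TIMEOUT / POLL_TIMEOUT == 10_000
```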
@@ -1752,18 +1748,39 @@ class Scheduler(

  if tp_size > 1:
  # Sync across TP ranks to make sure they have the same number of ready requests
- tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
+ tensor = torch.tensor([num_ready_reqs, num_abort_reqs], dtype=torch.int32)
  torch.distributed.all_reduce(
  tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
  )
- num_ready_reqs_max = tensor.item()
+ num_ready_reqs_max, num_abort_reqs_max = tensor.tolist()
+
  for i in range(num_ready_reqs, num_ready_reqs_max):
- self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
- num_ready_reqs = num_ready_reqs_max
+ req = self.grammar_queue[i]
+ req.grammar = req.grammar.result()
+ if req.grammar:
+ self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
+
+ for i in range(num_ready_reqs, num_ready_reqs + num_abort_reqs_max):
+ req = self.grammar_queue[i]
+ req.grammar.cancel()
+ req.grammar = None
+ error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
+ logger.error(error_msg)
+ req.finished_reason = FINISH_ABORT(
+ error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
+ )
+ num_ready_reqs = num_ready_reqs_max + num_abort_reqs_max

  self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
  self.grammar_queue = self.grammar_queue[num_ready_reqs:]

+ def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
+ if batch.next_batch_sampling_info:
+ if batch.next_batch_sampling_info.grammars is not None:
+ batch.next_batch_sampling_info.update_regex_vocab_mask()
+ self.current_stream.synchronize()
+ batch.next_batch_sampling_info.sampling_info_done.set()
+
  def watchdog_thread(self):
  """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
  self.watchdog_last_forward_ct = 0
@@ -1774,24 +1791,27 @@ class Scheduler(
  if self.cur_batch is not None:
  if self.watchdog_last_forward_ct == self.forward_ct:
  if current > self.watchdog_last_time + self.watchdog_timeout:
- logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
  break
  else:
  self.watchdog_last_forward_ct = self.forward_ct
  self.watchdog_last_time = current
  time.sleep(self.watchdog_timeout // 2)

- # Print batch size and memory pool info to check whether there are de-sync issues.
- logger.error(
- f"{self.cur_batch.batch_size()=}, "
- f"{self.cur_batch.reqs=}, "
- f"{self.token_to_kv_pool_allocator.available_size()=}, "
- f"{self.tree_cache.evictable_size()=}, "
- )
- # Wait for some time so that the parent process can print the error.
+ if not disable_request_logging():
+ # Print batch size and memory pool info to check whether there are de-sync issues.
+ logger.error(
+ f"{self.cur_batch.batch_size()=}, "
+ f"{self.cur_batch.reqs=}, "
+ f"{self.token_to_kv_pool_allocator.available_size()=}, "
+ f"{self.tree_cache.evictable_size()=}, "
+ )
+
  pyspy_dump_schedulers()
+ logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
  print(file=sys.stderr, flush=True)
  print(file=sys.stdout, flush=True)
+
+ # Wait for some time so that the parent process can print the error.
  time.sleep(5)
  self.parent_process.send_signal(signal.SIGQUIT)

@@ -1923,25 +1943,30 @@ class Scheduler(
  )

  def abort_request(self, recv_req: AbortReq):
+ # TODO(lmzheng): abort the requests in the grammar queue.
+
  # Delete requests in the waiting queue
  to_del = []
  for i, req in enumerate(self.waiting_queue):
  if req.rid.startswith(recv_req.rid):
  to_del.append(i)
- break

  # Sort in reverse order to avoid index issues when deleting
- for i in sorted(to_del, reverse=True):
+ for i in reversed(to_del):
  req = self.waiting_queue.pop(i)
+ self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
  logger.debug(f"Abort queued request. {req.rid=}")
- return

  # Delete requests in the running batch
- for req in self.running_batch.reqs:
+ if self.cur_batch is self.running_batch or self.cur_batch is None:
+ reqs = self.running_batch.reqs
+ else:
+ reqs = self.running_batch.reqs + self.cur_batch.reqs
+
+ for req in reqs:
  if req.rid.startswith(recv_req.rid) and not req.finished():
  logger.debug(f"Abort running request. {req.rid=}")
  req.to_abort = True
- return

  def _pause_engine(self) -> Tuple[List[Req], int]:
  raise NotImplementedError()
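abort_request now removes every queued request whose rid matches the prefix (the old code stopped at the first match and returned early), tells the tokenizer about each removal, and also scans the current batch when it differs from the running batch. A rough standalone sketch of that matching logic, with hypothetical names (abort_matching, notify_tokenizer):

```python
# Hypothetical sketch of the revised prefix-matching abort. Queued requests are
# dropped outright (in reverse index order so earlier indices stay valid);
# running requests are only flagged and are finished on a later scheduler step.
def abort_matching(waiting_queue, running_reqs, rid_prefix, notify_tokenizer):
    to_del = [
        i for i, req in enumerate(waiting_queue) if req.rid.startswith(rid_prefix)
    ]
    for i in reversed(to_del):
        req = waiting_queue.pop(i)
        notify_tokenizer(req.rid)  # mirrors send_to_tokenizer.send_pyobj(AbortReq(...))
    for req in running_reqs:
        if req.rid.startswith(rid_prefix) and not req.finished():
            req.to_abort = True
```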
@@ -2162,8 +2187,8 @@ class Scheduler(

  def get_print_prefix(self):
  prefix = ""
- if self.dp_rank is not None:
- prefix += f" DP{self.dp_rank}"
+ if self.attn_dp_rank is not None:
+ prefix += f" DP{self.attn_dp_rank}"
  if self.server_args.tp_size > 1:
  prefix += f" TP{self.tp_rank}"
  if self.pp_size > 1: