sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +3 -13
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +158 -8
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +119 -75
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +5 -2
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/internvl.py +696 -0
  13. sglang/srt/configs/janus_pro.py +3 -0
  14. sglang/srt/configs/model_config.py +18 -0
  15. sglang/srt/constrained/base_grammar_backend.py +55 -72
  16. sglang/srt/constrained/llguidance_backend.py +25 -21
  17. sglang/srt/constrained/outlines_backend.py +27 -26
  18. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  19. sglang/srt/constrained/xgrammar_backend.py +71 -53
  20. sglang/srt/conversation.py +78 -46
  21. sglang/srt/disaggregation/base/conn.py +1 -0
  22. sglang/srt/disaggregation/decode.py +11 -3
  23. sglang/srt/disaggregation/fake/conn.py +1 -1
  24. sglang/srt/disaggregation/mini_lb.py +74 -23
  25. sglang/srt/disaggregation/mooncake/conn.py +236 -138
  26. sglang/srt/disaggregation/nixl/conn.py +242 -71
  27. sglang/srt/disaggregation/prefill.py +7 -4
  28. sglang/srt/disaggregation/utils.py +51 -2
  29. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  30. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  31. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  32. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  33. sglang/srt/distributed/parallel_state.py +22 -1
  34. sglang/srt/entrypoints/engine.py +31 -4
  35. sglang/srt/entrypoints/http_server.py +45 -3
  36. sglang/srt/entrypoints/verl_engine.py +3 -2
  37. sglang/srt/function_call_parser.py +2 -2
  38. sglang/srt/hf_transformers_utils.py +20 -1
  39. sglang/srt/layers/attention/flashattention_backend.py +147 -51
  40. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  41. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  42. sglang/srt/layers/attention/merge_state.py +46 -0
  43. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  44. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  45. sglang/srt/layers/attention/utils.py +4 -2
  46. sglang/srt/layers/attention/vision.py +290 -163
  47. sglang/srt/layers/dp_attention.py +71 -21
  48. sglang/srt/layers/layernorm.py +1 -1
  49. sglang/srt/layers/logits_processor.py +46 -11
  50. sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
  51. sglang/srt/layers/moe/ep_moe/layer.py +121 -2
  52. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  56. sglang/srt/layers/moe/topk.py +1 -1
  57. sglang/srt/layers/quantization/__init__.py +1 -1
  58. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  59. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  60. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  61. sglang/srt/layers/quantization/deep_gemm.py +77 -71
  62. sglang/srt/layers/quantization/fp8.py +110 -97
  63. sglang/srt/layers/quantization/fp8_kernel.py +81 -62
  64. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  65. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  66. sglang/srt/layers/quantization/kv_cache.py +3 -10
  67. sglang/srt/layers/quantization/utils.py +0 -5
  68. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  69. sglang/srt/layers/sampler.py +0 -4
  70. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  71. sglang/srt/lora/lora_manager.py +11 -14
  72. sglang/srt/lora/mem_pool.py +4 -4
  73. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  74. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  75. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  76. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  77. sglang/srt/lora/utils.py +1 -1
  78. sglang/srt/managers/cache_controller.py +115 -119
  79. sglang/srt/managers/data_parallel_controller.py +3 -3
  80. sglang/srt/managers/detokenizer_manager.py +21 -8
  81. sglang/srt/managers/io_struct.py +13 -1
  82. sglang/srt/managers/mm_utils.py +1 -1
  83. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  84. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  85. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  86. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  87. sglang/srt/managers/schedule_batch.py +93 -23
  88. sglang/srt/managers/schedule_policy.py +11 -8
  89. sglang/srt/managers/scheduler.py +140 -100
  90. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  91. sglang/srt/managers/tokenizer_manager.py +157 -47
  92. sglang/srt/managers/tp_worker.py +21 -21
  93. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  94. sglang/srt/mem_cache/chunk_cache.py +2 -0
  95. sglang/srt/mem_cache/memory_pool.py +4 -2
  96. sglang/srt/metrics/collector.py +312 -37
  97. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  98. sglang/srt/model_executor/forward_batch_info.py +1 -1
  99. sglang/srt/model_executor/model_runner.py +57 -41
  100. sglang/srt/model_loader/loader.py +18 -11
  101. sglang/srt/models/clip.py +4 -4
  102. sglang/srt/models/deepseek_janus_pro.py +3 -3
  103. sglang/srt/models/deepseek_nextn.py +1 -20
  104. sglang/srt/models/deepseek_v2.py +77 -39
  105. sglang/srt/models/gemma3_mm.py +1 -1
  106. sglang/srt/models/internlm2.py +3 -0
  107. sglang/srt/models/internvl.py +670 -0
  108. sglang/srt/models/llama.py +3 -1
  109. sglang/srt/models/llama4.py +58 -13
  110. sglang/srt/models/llava.py +248 -5
  111. sglang/srt/models/minicpmv.py +1 -1
  112. sglang/srt/models/mixtral.py +98 -34
  113. sglang/srt/models/mllama.py +1 -1
  114. sglang/srt/models/phi3_small.py +16 -2
  115. sglang/srt/models/pixtral.py +467 -0
  116. sglang/srt/models/qwen2_5_vl.py +8 -4
  117. sglang/srt/models/qwen2_vl.py +4 -4
  118. sglang/srt/models/roberta.py +1 -1
  119. sglang/srt/models/torch_native_llama.py +1 -1
  120. sglang/srt/models/xiaomi_mimo.py +171 -0
  121. sglang/srt/openai_api/adapter.py +52 -42
  122. sglang/srt/openai_api/protocol.py +20 -16
  123. sglang/srt/reasoning_parser.py +1 -1
  124. sglang/srt/sampling/custom_logit_processor.py +18 -3
  125. sglang/srt/sampling/sampling_batch_info.py +2 -2
  126. sglang/srt/sampling/sampling_params.py +2 -0
  127. sglang/srt/server_args.py +64 -10
  128. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  129. sglang/srt/speculative/eagle_utils.py +7 -7
  130. sglang/srt/speculative/eagle_worker.py +22 -19
  131. sglang/srt/utils.py +41 -6
  132. sglang/test/few_shot_gsm8k.py +2 -2
  133. sglang/test/few_shot_gsm8k_engine.py +2 -2
  134. sglang/test/run_eval.py +2 -2
  135. sglang/test/runners.py +8 -1
  136. sglang/test/send_one.py +13 -3
  137. sglang/test/simple_eval_common.py +1 -1
  138. sglang/test/simple_eval_humaneval.py +1 -1
  139. sglang/test/test_block_fp8.py +2 -2
  140. sglang/test/test_deepep_utils.py +219 -0
  141. sglang/test/test_programs.py +5 -5
  142. sglang/test/test_utils.py +92 -15
  143. sglang/utils.py +1 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
  146. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
  147. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
  148. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  149. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
The hunks below are from sglang/srt/managers/scheduler.py (entry 89 in the list above, +140 -100); the diffs of the other changed files are not reproduced here.

@@ -20,7 +20,6 @@ import signal
 import sys
 import threading
 import time
-import warnings
 from collections import defaultdict, deque
 from concurrent import futures
 from dataclasses import dataclass
@@ -52,7 +51,11 @@ from sglang.srt.disaggregation.utils import (
     TransferBackend,
 )
 from sglang.srt.distributed import get_pp_group, get_world_group
-from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
@@ -83,6 +86,8 @@ from sglang.srt.managers.io_struct import (
     RpcReqOutput,
     SetInternalStateReq,
     SetInternalStateReqOutput,
+    SlowDownReqInput,
+    SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
@@ -115,11 +120,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import (
-    ForwardBatch,
-    ForwardMode,
-    PPProxyTensors,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -129,6 +130,7 @@ from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
     crash_on_warnings,
+    disable_request_logging,
     get_bool_env_var,
     get_zmq_socket,
     kill_itself_when_parent_died,
@@ -147,6 +149,7 @@ logger = logging.getLogger(__name__)
 # Test retract decode for debugging purposes
 TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
 RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
+GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))


 @dataclass
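The new GRAMMAR_TIMEOUT constant is read from the environment at import time, so the grammar-compilation deadline can be changed without touching the code. A minimal sketch of how it could be overridden; the 600-second value is only an example:

```python
import os

# Must be set before sglang.srt.managers.scheduler is imported,
# e.g. at the top of a launch script; the default is 300 seconds.
os.environ["SGLANG_GRAMMAR_TIMEOUT"] = "600"
```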
@@ -157,6 +160,7 @@ class GenerationBatchResult:
     extend_input_len_per_req: List[int]
     extend_logprob_start_len_per_req: List[int]
     bid: int
+    can_run_cuda_graph: bool


 @dataclass
@@ -203,7 +207,8 @@ class Scheduler(
         self.page_size = server_args.page_size

         # Distributed rank info
-        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
+        self.dp_size = server_args.dp_size
+        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
             compute_dp_attention_world_info(
                 server_args.enable_dp_attention,
                 self.tp_rank,
@@ -320,13 +325,14 @@ class Scheduler(
         set_random_seed(self.random_seed)

         # Print debug info
-        logger.info(
-            f"max_total_num_tokens={self.max_total_num_tokens}, "
-            f"chunked_prefill_size={server_args.chunked_prefill_size}, "
-            f"max_prefill_tokens={self.max_prefill_tokens}, "
-            f"max_running_requests={self.max_running_requests}, "
-            f"context_len={self.model_config.context_len}"
-        )
+        if tp_rank == 0:
+            logger.info(
+                f"max_total_num_tokens={self.max_total_num_tokens}, "
+                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
+                f"max_prefill_tokens={self.max_prefill_tokens}, "
+                f"max_running_requests={self.max_running_requests}, "
+                f"context_len={self.model_config.context_len}"
+            )

         # Init memory pool and cache
         self.init_memory_pool_and_cache()
@@ -413,6 +419,8 @@ class Scheduler(
         self.profiler_id: Optional[str] = None
         self.profiler_target_forward_ct: Optional[int] = None

+        self.forward_sleep_time = None
+
         # Init metrics stats
         self.init_metrics()

@@ -435,6 +443,7 @@ class Scheduler(
                 (GetWeightsByNameReqInput, self.get_weights_by_name),
                 (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                 (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
+                (SlowDownReqInput, self.slow_down),
                 (ProfileReq, self.profile),
                 (GetInternalStateReq, self.get_internal_state),
                 (SetInternalStateReq, self.set_internal_state),
@@ -451,17 +460,7 @@ class Scheduler(
     def init_tokenizer(self):
         server_args = self.server_args

-        self.model_config = ModelConfig(
-            server_args.model_path,
-            trust_remote_code=server_args.trust_remote_code,
-            revision=server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            enable_multimodal=server_args.enable_multimodal,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
-        )
+        self.model_config = ModelConfig.from_server_args(server_args)
         self.is_generation = self.model_config.is_generation

         if server_args.skip_tokenizer_init:
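The hunk above replaces an inlined ModelConfig(...) call with a ModelConfig.from_server_args factory. A minimal sketch of what such a classmethod could look like, reusing exactly the keyword arguments the removed call passed; the simplified __init__ is a stand-in, not the real sglang implementation:

```python
class ModelConfig:
    """Simplified stand-in for sglang.srt.configs.model_config.ModelConfig."""

    def __init__(self, model_path, **kwargs):
        self.model_path = model_path
        # The real class validates these values and derives many more fields.
        self.__dict__.update(kwargs)

    @classmethod
    def from_server_args(cls, server_args, **kwargs):
        # One place to translate ServerArgs -> ModelConfig, so the scheduler
        # and other call sites cannot drift apart.
        return cls(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            enable_multimodal=server_args.enable_multimodal,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
            **kwargs,
        )
```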
@@ -475,7 +474,7 @@ class Scheduler(
                 revision=server_args.revision,
                 use_fast=not server_args.disable_fast_image_processor,
             )
-            self.tokenizer = self.processor.tokenizer
+            self.tokenizer = get_tokenizer_from_processor(self.processor)
         else:
             self.tokenizer = get_tokenizer(
                 server_args.tokenizer_path,
@@ -498,6 +497,7 @@ class Scheduler(
             self.tree_cache = ChunkCache(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                page_size=self.page_size,
             )
         else:
             if self.enable_hierarchical_cache:
@@ -531,10 +531,6 @@ class Scheduler(
         )

     def init_metrics(self):
-        # The largest prefill length of a single request
-        self._largest_prefill_len: int = 0
-        # The largest context length (prefill + generation) of a single request
-        self._largest_prefill_decode_len: int = 0
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
         self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
@@ -720,7 +716,7 @@ class Scheduler(
                     server_is_idle = False
                     result = self.run_batch(self.cur_batch)

-                # send the outputs to the next step
+                # (last rank) send the outputs to the next step
                 if self.pp_group.is_last_rank:
                     if self.cur_batch:
                         next_token_ids, bids[mb_id] = (
@@ -755,24 +751,25 @@ class Scheduler(
                             extend_input_len_per_req=None,
                             extend_logprob_start_len_per_req=None,
                             bid=bids[next_mb_id],
+                            can_run_cuda_graph=result.can_run_cuda_graph,
                         )
                         self.process_batch_result(mbs[next_mb_id], output_result)
                         last_mbs[next_mb_id] = mbs[next_mb_id]

-                # carry the outputs to the next stage
+                # (not last rank)
                 if not self.pp_group.is_last_rank:
                     if self.cur_batch:
                         bids[mb_id] = result.bid
+                    # carry the outputs to the next stage
+                    # send the outputs from the last round to let the next stage worker run post processing
                     if pp_outputs:
-                        # send the outputs from the last round to let the next stage worker run post processing
                         self.pp_group.send_tensor_dict(
                             pp_outputs.tensors,
                             all_gather_group=self.attn_tp_group,
                         )

-                if not self.pp_group.is_last_rank:
                     # send out reqs to the next stage
-                    dp_offset = self.dp_rank * self.attn_tp_size
+                    dp_offset = self.attn_dp_rank * self.attn_tp_size
                     if self.attn_tp_rank == 0:
                         point_to_point_pyobj(
                             recv_reqs,
@@ -819,7 +816,7 @@ class Scheduler(
                 recv_reqs = None
         else:
             if self.attn_tp_rank == 0:
-                dp_offset = self.dp_rank * self.attn_tp_size
+                dp_offset = self.attn_dp_rank * self.attn_tp_size
                 recv_reqs = point_to_point_pyobj(
                     [],
                     self.pp_rank * self.tp_size + dp_offset,
@@ -907,18 +904,9 @@ class Scheduler(
             fake_input_ids = [1] * seq_length
             recv_req.input_ids = fake_input_ids

-        # Handle custom logit processor passed to the request
-        custom_logit_processor = recv_req.custom_logit_processor
-        if (
-            not self.server_args.enable_custom_logit_processor
-            and custom_logit_processor is not None
-        ):
-            logger.warning(
-                "The SGLang server is not configured to enable custom logit processor."
-                "The custom logit processor passed in will be ignored."
-                "Please set --enable-custom-logits-processor to enable this feature."
-            )
-            custom_logit_processor = None
+        if recv_req.bootstrap_port is None:
+            # Use default bootstrap port
+            recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port

         req = Req(
             recv_req.rid,
@@ -931,7 +919,7 @@ class Scheduler(
             stream=recv_req.stream,
             lora_path=recv_req.lora_path,
             input_embeds=recv_req.input_embeds,
-            custom_logit_processor=custom_logit_processor,
+            custom_logit_processor=recv_req.custom_logit_processor,
             return_hidden_states=recv_req.return_hidden_states,
             eos_token_ids=self.model_config.hf_eos_token_id,
             bootstrap_host=recv_req.bootstrap_host,
@@ -1037,9 +1025,11 @@ class Scheduler(
             elif req.sampling_params.structural_tag:
                 key = ("structural_tag", req.sampling_params.structural_tag)

-            req.grammar = self.grammar_backend.get_cached_value(key)
-            if not req.grammar:
-                req.grammar = self.grammar_backend.get_future_value(key)
+            value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
+            req.grammar = value
+
+            if not cache_hit:
+                req.grammar_key = key
                 add_to_grammar_queue = True

         if add_to_grammar_queue:
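The grammar backend now answers a single get_cached_or_future_value call with a (value, cache_hit) pair instead of two separate lookups; the real implementation lives in the reworked sglang/srt/constrained/base_grammar_backend.py (+55 -72 in this release). A minimal sketch of that cache-or-future pattern, with illustrative names rather than the exact sglang API:

```python
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Tuple, Union


class GrammarCacheSketch:
    """Illustrative cache that hands back either a ready grammar or a Future."""

    def __init__(self, compile_fn, num_workers: int = 4):
        self._compile_fn = compile_fn
        self._cache = {}                       # key -> compiled grammar
        self._lock = threading.Lock()
        self._executor = ThreadPoolExecutor(num_workers)

    def get_cached_or_future_value(self, key) -> Tuple[Union[object, Future], bool]:
        with self._lock:
            if key in self._cache:
                # Cache hit: the request can skip the grammar queue entirely.
                return self._cache[key], True
            # Cache miss: start async compilation; the scheduler parks the
            # request in grammar_queue until future.result() is ready.
            return self._executor.submit(self._compile_fn, key), False

    def set_cache(self, key, value) -> None:
        with self._lock:
            self._cache[key] = value
```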
@@ -1129,9 +1119,6 @@ class Scheduler(
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
         )
-        self._largest_prefill_len = max(
-            self._largest_prefill_len, adder.log_input_tokens
-        )

         num_new_seq = len(can_run_list)
         f = (
@@ -1169,7 +1156,9 @@ class Scheduler(

             self.metrics_collector.log_stats(self.stats)

-    def log_decode_stats(self, running_batch=None):
+    def log_decode_stats(
+        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
+    ):
         batch = running_batch or self.running_batch

         gap_latency = time.time() - self.last_decode_stats_tic
@@ -1209,6 +1198,7 @@ class Scheduler(
             msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "

         msg += (
+            f"cuda graph: {can_run_cuda_graph}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
             f"#queue-req: {len(self.waiting_queue)}"
         )
@@ -1221,6 +1211,7 @@ class Scheduler(
             self.stats.cache_hit_rate = 0.0
             self.stats.gen_throughput = self.last_gen_throughput
             self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.stats.spec_accept_length = spec_accept_length
             self.metrics_collector.log_stats(self.stats)

@@ -1242,9 +1233,7 @@ class Scheduler(
                 f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                 f"{self.tree_cache.evictable_size()=}\n"
             )
-            warnings.warn(msg)
-            if crash_on_warnings():
-                raise ValueError(msg)
+            raise ValueError(msg)

         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             msg = (
@@ -1252,9 +1241,7 @@ class Scheduler(
                 f"available_size={len(self.req_to_token_pool.free_slots)}, "
                 f"total_size={self.req_to_token_pool.size}\n"
             )
-            warnings.warn(msg)
-            if crash_on_warnings():
-                raise ValueError(msg)
+            raise ValueError(msg)

         if (
             self.enable_metrics
@@ -1272,6 +1259,7 @@ class Scheduler(
             self.stats.token_usage = num_used / self.max_total_num_tokens
             self.stats.gen_throughput = 0
             self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.metrics_collector.log_stats(self.stats)

     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
@@ -1342,7 +1330,7 @@ class Scheduler(
             return None

         running_bs = len(self.running_batch.reqs)
-        # Igore the check if self.chunked_req is not None.
+        # Ignore the check if self.chunked_req is not None.
         # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
         # as the space for the chunked request has just been released.
         # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
@@ -1527,16 +1515,20 @@ class Scheduler(
         ):
             self.stop_profile()

+        if self.forward_sleep_time is not None:
+            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
+            time.sleep(self.forward_sleep_time)
+
         # Run forward
         if self.is_generation:
             if self.spec_algorithm.is_none():
                 model_worker_batch = batch.get_model_worker_batch()
                 if self.pp_group.is_last_rank:
-                    logits_output, next_token_ids = (
+                    logits_output, next_token_ids, can_run_cuda_graph = (
                         self.tp_worker.forward_batch_generation(model_worker_batch)
                     )
                 else:
-                    pp_hidden_states_proxy_tensors, _ = (
+                    pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
                         self.tp_worker.forward_batch_generation(model_worker_batch)
                     )
                 bid = model_worker_batch.bid
@@ -1546,6 +1538,7 @@ class Scheduler(
                     next_token_ids,
                     bid,
                     num_accepted_tokens,
+                    can_run_cuda_graph,
                 ) = self.draft_worker.forward_batch_speculative_generation(batch)
                 self.spec_num_total_accepted_tokens += (
                     num_accepted_tokens + batch.batch_size()
@@ -1579,6 +1572,7 @@ class Scheduler(
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                 bid=bid,
+                can_run_cuda_graph=can_run_cuda_graph,
             )
         else:  # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
@@ -1601,14 +1595,9 @@ class Scheduler(
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
                 self.tp_worker.resolve_last_batch_result(launch_done)
-                if batch.next_batch_sampling_info:
-                    batch.next_batch_sampling_info.update_regex_vocab_mask()
-                    self.current_stream.synchronize()
-                    batch.next_batch_sampling_info.sampling_info_done.set()
+                self.set_next_batch_sampling_info_done(batch)
         elif batch.forward_mode.is_dummy_first():
-            batch.next_batch_sampling_info.update_regex_vocab_mask()
-            self.current_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
+            self.set_next_batch_sampling_info_done(batch)

         if self.return_health_check_ct:
             # Return some signal for the health check.
@@ -1622,6 +1611,7 @@ class Scheduler(
             local_batch,
             dp_size=self.server_args.dp_size,
             attn_tp_size=self.attn_tp_size,
+            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
             tp_cpu_group=self.tp_cpu_group,
             get_idle_batch=self.get_idle_batch,
             disable_cuda_graph=self.server_args.disable_cuda_graph,
@@ -1634,6 +1624,7 @@ class Scheduler(
         local_batch: ScheduleBatch,
         dp_size,
         attn_tp_size: int,
+        moe_dense_tp_size: Optional[int],
         tp_cpu_group,
         get_idle_batch,
         disable_cuda_graph: bool,
@@ -1643,15 +1634,15 @@ class Scheduler(
         # Check if other DP workers have running batches
         if local_batch is None:
             num_tokens = 0
-            global_num_tokens_for_logprob = 0
+            num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
             if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
                 num_tokens = num_tokens * speculative_num_draft_tokens
-            global_num_tokens_for_logprob = num_tokens
+            num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
-            global_num_tokens_for_logprob = sum(
+            num_tokens_for_logprob = sum(
                 [
                     # We should have at least 1 token for sample in every case.
                     max(extend_len - logprob_start_len, 1)
@@ -1678,7 +1669,7 @@ class Scheduler(
             [
                 num_tokens,
                 can_cuda_graph,
-                global_num_tokens_for_logprob,
+                num_tokens_for_logprob,
                 is_extend_in_batch,
             ],
             dtype=torch.int64,
@@ -1701,8 +1692,15 @@ class Scheduler(
             local_batch = get_idle_batch()

         if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens
-            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
+            # TODO: handle the case when moe_dense_tp_size != 1
+            if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
+                local_batch.global_num_tokens = [num_tokens]
+                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
+            else:
+                local_batch.global_num_tokens = global_num_tokens
+                local_batch.global_num_tokens_for_logprob = (
+                    global_num_tokens_for_logprob
+                )

             # Check forward mode for cuda graph
             if not disable_cuda_graph:
@@ -1728,11 +1726,17 @@ class Scheduler(
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""

         num_ready_reqs = 0
+        num_abort_reqs = 0
         for req in self.grammar_queue:
             try:
-                req.grammar = req.grammar.result(timeout=0.05)
+                req.grammar = req.grammar.result(timeout=0.03)
+                if req.grammar:
+                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                 num_ready_reqs += 1
             except futures._base.TimeoutError:
+                req.grammar_wait_ct += 1
+                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
+                    num_abort_reqs = 1
                 break

         if self.server_args.enable_dp_attention:
@@ -1744,18 +1748,39 @@ class Scheduler(

         if tp_size > 1:
             # Sync across TP ranks to make sure they have the same number of ready requests
-            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
+            tensor = torch.tensor([num_ready_reqs, num_abort_reqs], dtype=torch.int32)
             torch.distributed.all_reduce(
                 tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
             )
-            num_ready_reqs_max = tensor.item()
+            num_ready_reqs_max, num_abort_reqs_max = tensor.tolist()
+
             for i in range(num_ready_reqs, num_ready_reqs_max):
-                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
-            num_ready_reqs = num_ready_reqs_max
+                req = self.grammar_queue[i]
+                req.grammar = req.grammar.result()
+                if req.grammar:
+                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
+
+            for i in range(num_ready_reqs, num_ready_reqs + num_abort_reqs_max):
+                req = self.grammar_queue[i]
+                req.grammar.cancel()
+                req.grammar = None
+                error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
+                logger.error(error_msg)
+                req.finished_reason = FINISH_ABORT(
+                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
+                )
+            num_ready_reqs = num_ready_reqs_max + num_abort_reqs_max

         self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]

+    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
+        if batch.next_batch_sampling_info:
+            if batch.next_batch_sampling_info.grammars is not None:
+                batch.next_batch_sampling_info.update_regex_vocab_mask()
+                self.current_stream.synchronize()
+            batch.next_batch_sampling_info.sampling_info_done.set()
+
     def watchdog_thread(self):
         """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
         self.watchdog_last_forward_ct = 0
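The grammar queue now counts how many 0.03-second waits a request has accumulated and aborts it once the total exceeds GRAMMAR_TIMEOUT, instead of blocking indefinitely on a slow compilation. A standalone sketch of that poll-count-then-cancel pattern with concurrent.futures; compile_grammar and pending are illustrative names, not sglang API:

```python
import concurrent.futures as futures
import time

POLL_INTERVAL = 0.03           # seconds per poll, mirroring the hunk above
GRAMMAR_TIMEOUT = 300.0        # overall budget per request
MAX_WAITS = GRAMMAR_TIMEOUT / POLL_INTERVAL


def compile_grammar(spec: str) -> str:
    time.sleep(0.1)            # stand-in for an expensive grammar compilation
    return f"compiled({spec})"


executor = futures.ThreadPoolExecutor(max_workers=2)
pending = [{"future": executor.submit(compile_grammar, s), "waits": 0}
           for s in ("json_schema_a", "regex_b")]

ready, aborted = [], []
while pending:
    entry = pending[0]           # only poll the head, like the scheduler loop
    try:
        ready.append(entry["future"].result(timeout=POLL_INTERVAL))
        pending.pop(0)
    except futures.TimeoutError:
        entry["waits"] += 1
        if entry["waits"] > MAX_WAITS:
            entry["future"].cancel()   # give up and mark the request as failed
            aborted.append(entry)
            pending.pop(0)

print(len(ready), "ready,", len(aborted), "aborted")
```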
@@ -1766,24 +1791,27 @@ class Scheduler(
             if self.cur_batch is not None:
                 if self.watchdog_last_forward_ct == self.forward_ct:
                     if current > self.watchdog_last_time + self.watchdog_timeout:
-                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                         break
                 else:
                     self.watchdog_last_forward_ct = self.forward_ct
                     self.watchdog_last_time = current
             time.sleep(self.watchdog_timeout // 2)

-        # Print batch size and memory pool info to check whether there are de-sync issues.
-        logger.error(
-            f"{self.cur_batch.batch_size()=}, "
-            f"{self.cur_batch.reqs=}, "
-            f"{self.token_to_kv_pool_allocator.available_size()=}, "
-            f"{self.tree_cache.evictable_size()=}, "
-        )
-        # Wait for some time so that the parent process can print the error.
+        if not disable_request_logging():
+            # Print batch size and memory pool info to check whether there are de-sync issues.
+            logger.error(
+                f"{self.cur_batch.batch_size()=}, "
+                f"{self.cur_batch.reqs=}, "
+                f"{self.token_to_kv_pool_allocator.available_size()=}, "
+                f"{self.tree_cache.evictable_size()=}, "
+            )
+
         pyspy_dump_schedulers()
+        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
         print(file=sys.stderr, flush=True)
         print(file=sys.stdout, flush=True)
+
+        # Wait for some time so that the parent process can print the error.
         time.sleep(5)
         self.parent_process.send_signal(signal.SIGQUIT)

@@ -1915,25 +1943,30 @@ class Scheduler(
         )

     def abort_request(self, recv_req: AbortReq):
+        # TODO(lmzheng): abort the requests in the grammar queue.
+
         # Delete requests in the waiting queue
         to_del = []
         for i, req in enumerate(self.waiting_queue):
             if req.rid.startswith(recv_req.rid):
                 to_del.append(i)
-                break

         # Sort in reverse order to avoid index issues when deleting
-        for i in sorted(to_del, reverse=True):
+        for i in reversed(to_del):
             req = self.waiting_queue.pop(i)
+            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
             logger.debug(f"Abort queued request. {req.rid=}")
-            return

         # Delete requests in the running batch
-        for req in self.running_batch.reqs:
+        if self.cur_batch is self.running_batch or self.cur_batch is None:
+            reqs = self.running_batch.reqs
+        else:
+            reqs = self.running_batch.reqs + self.cur_batch.reqs
+
+        for req in reqs:
             if req.rid.startswith(recv_req.rid) and not req.finished():
                 logger.debug(f"Abort running request. {req.rid=}")
                 req.to_abort = True
-                return

     def _pause_engine(self) -> Tuple[List[Req], int]:
         raise NotImplementedError()
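abort_request now collects every matching index (the early break/return statements are gone, so one abort can cover multiple requests sharing a rid prefix, e.g. parallel-sampling sub-requests) and deletes them back-to-front. Iterating the ascending index list in reverse keeps the not-yet-deleted indices valid; a small self-contained illustration of why:

```python
queue = ["a-0", "a-1", "b-0", "a-2", "c-0"]

# Collect indices of the requests to abort (rid prefix match).
to_del = [i for i, rid in enumerate(queue) if rid.startswith("a-")]  # [0, 1, 3]

# Popping from the front first would shift the later indices and delete the
# wrong items; deleting in reverse order avoids that entirely.
for i in reversed(to_del):
    queue.pop(i)

assert queue == ["b-0", "c-0"]
```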
@@ -2002,6 +2035,13 @@ class Scheduler(
         del self.stashed_model_static_state
         return ResumeMemoryOccupationReqOutput()

+    def slow_down(self, recv_req: SlowDownReqInput):
+        t = recv_req.forward_sleep_time
+        if t is not None and t <= 0:
+            t = None
+        self.forward_sleep_time = t
+        return SlowDownReqOutput()
+
     def profile(self, recv_req: ProfileReq):
         if recv_req.type == ProfileReqType.START_PROFILE:
             return self.start_profile(
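Together with the SlowDownReqInput handler registered earlier and the sleep added at the top of run_batch, this gives a runtime knob that throttles the scheduler by pausing before each forward pass; a non-positive forward_sleep_time clears it. A minimal sketch of just that plumbing, with the surrounding class heavily simplified:

```python
import time


class SchedulerSketch:
    """Only the slow-down plumbing from the hunks above, nothing else."""

    def __init__(self):
        self.forward_sleep_time = None  # None means run at full speed

    def slow_down(self, forward_sleep_time):
        # A non-positive value disables throttling again.
        if forward_sleep_time is not None and forward_sleep_time <= 0:
            forward_sleep_time = None
        self.forward_sleep_time = forward_sleep_time

    def run_batch(self, batch):
        if self.forward_sleep_time is not None:
            time.sleep(self.forward_sleep_time)  # artificial stall before the forward
        return f"ran {batch}"


s = SchedulerSketch()
s.slow_down(0.05)   # throttle: every batch now waits 50 ms
s.run_batch("demo")
s.slow_down(0)      # back to full speed
```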
@@ -2147,8 +2187,8 @@ class Scheduler(

     def get_print_prefix(self):
         prefix = ""
-        if self.dp_rank is not None:
-            prefix += f" DP{self.dp_rank}"
+        if self.attn_dp_rank is not None:
+            prefix += f" DP{self.attn_dp_rank}"
         if self.server_args.tp_size > 1:
             prefix += f" TP{self.tp_rank}"
         if self.pp_size > 1: