sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (131)
  1. sglang/bench_serving.py +1 -1
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/function_call_parser.py +33 -2
  14. sglang/srt/hf_transformers_utils.py +16 -1
  15. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  17. sglang/srt/layers/attention/triton_backend.py +1 -3
  18. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  21. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  22. sglang/srt/layers/attention/vision.py +43 -62
  23. sglang/srt/layers/dp_attention.py +30 -2
  24. sglang/srt/layers/elementwise.py +411 -0
  25. sglang/srt/layers/linear.py +1 -1
  26. sglang/srt/layers/logits_processor.py +1 -0
  27. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  28. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  37. sglang/srt/layers/moe/router.py +342 -0
  38. sglang/srt/layers/parameter.py +10 -0
  39. sglang/srt/layers/quantization/__init__.py +90 -68
  40. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  41. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/fp8.py +174 -106
  68. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  69. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  70. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  71. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  72. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  73. sglang/srt/layers/rotary_embedding.py +5 -3
  74. sglang/srt/layers/sampler.py +29 -35
  75. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  76. sglang/srt/lora/backend/__init__.py +9 -12
  77. sglang/srt/managers/cache_controller.py +74 -8
  78. sglang/srt/managers/data_parallel_controller.py +1 -1
  79. sglang/srt/managers/image_processor.py +37 -631
  80. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  81. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  82. sglang/srt/managers/image_processors/llava.py +152 -0
  83. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  84. sglang/srt/managers/image_processors/mlama.py +60 -0
  85. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  86. sglang/srt/managers/io_struct.py +32 -15
  87. sglang/srt/managers/multi_modality_padding.py +134 -0
  88. sglang/srt/managers/schedule_batch.py +213 -118
  89. sglang/srt/managers/schedule_policy.py +40 -8
  90. sglang/srt/managers/scheduler.py +176 -683
  91. sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
  92. sglang/srt/managers/tokenizer_manager.py +6 -6
  93. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  94. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  95. sglang/srt/mem_cache/chunk_cache.py +12 -44
  96. sglang/srt/mem_cache/hiradix_cache.py +71 -34
  97. sglang/srt/mem_cache/memory_pool.py +81 -17
  98. sglang/srt/mem_cache/paged_allocator.py +283 -0
  99. sglang/srt/mem_cache/radix_cache.py +117 -36
  100. sglang/srt/model_executor/cuda_graph_runner.py +68 -20
  101. sglang/srt/model_executor/forward_batch_info.py +23 -10
  102. sglang/srt/model_executor/model_runner.py +63 -63
  103. sglang/srt/model_loader/loader.py +2 -1
  104. sglang/srt/model_loader/weight_utils.py +1 -1
  105. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  106. sglang/srt/models/deepseek_nextn.py +23 -3
  107. sglang/srt/models/deepseek_v2.py +200 -191
  108. sglang/srt/models/grok.py +374 -119
  109. sglang/srt/models/minicpmv.py +28 -89
  110. sglang/srt/models/mllama.py +1 -1
  111. sglang/srt/models/qwen2.py +0 -1
  112. sglang/srt/models/qwen2_5_vl.py +25 -50
  113. sglang/srt/models/qwen2_vl.py +33 -49
  114. sglang/srt/openai_api/adapter.py +59 -35
  115. sglang/srt/openai_api/protocol.py +8 -1
  116. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  117. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  118. sglang/srt/server_args.py +24 -16
  119. sglang/srt/speculative/eagle_worker.py +75 -39
  120. sglang/srt/utils.py +104 -9
  121. sglang/test/runners.py +104 -10
  122. sglang/test/test_block_fp8.py +106 -16
  123. sglang/test/test_custom_ops.py +88 -0
  124. sglang/test/test_utils.py +20 -4
  125. sglang/utils.py +0 -4
  126. sglang/version.py +1 -1
  127. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
  128. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
  129. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
  130. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
  131. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
@@ -41,8 +41,6 @@ from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
41
41
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
42
42
  from sglang.srt.managers.io_struct import (
43
43
  AbortReq,
44
- BatchEmbeddingOut,
45
- BatchTokenIDOut,
46
44
  CloseSessionReqInput,
47
45
  FlushCacheReq,
48
46
  GetInternalStateReq,
@@ -74,7 +72,6 @@ from sglang.srt.managers.io_struct import (
74
72
  )
75
73
  from sglang.srt.managers.schedule_batch import (
76
74
  FINISH_ABORT,
77
- BaseFinishReason,
78
75
  ImageInputs,
79
76
  Req,
80
77
  ScheduleBatch,
@@ -85,6 +82,9 @@ from sglang.srt.managers.schedule_policy import (
85
82
  PrefillAdder,
86
83
  SchedulePolicy,
87
84
  )
85
+ from sglang.srt.managers.scheduler_output_processor_mixin import (
86
+ SchedulerOutputProcessorMixin,
87
+ )
88
88
  from sglang.srt.managers.session_controller import Session
89
89
  from sglang.srt.managers.tp_worker import TpModelWorker
90
90
  from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
@@ -93,7 +93,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
93
93
  from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
94
94
  from sglang.srt.mem_cache.radix_cache import RadixCache
95
95
  from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
96
- from sglang.srt.model_executor.forward_batch_info import ForwardMode
96
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
97
97
  from sglang.srt.server_args import PortArgs, ServerArgs
98
98
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
99
99
  from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -103,6 +103,7 @@ from sglang.srt.utils import (
103
103
  crash_on_warnings,
104
104
  get_bool_env_var,
105
105
  get_zmq_socket,
106
+ kill_itself_when_parent_died,
106
107
  pyspy_dump_schedulers,
107
108
  set_gpu_proc_affinity,
108
109
  set_random_seed,
@@ -132,7 +133,7 @@ class EmbeddingBatchResult:
132
133
  bid: int
133
134
 
134
135
 
135
- class Scheduler:
136
+ class Scheduler(SchedulerOutputProcessorMixin):
136
137
  """A scheduler that manages a tensor parallel GPU worker."""
137
138
 
138
139
  def __init__(
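
The hunk above is the key structural change in this file: Scheduler now inherits from SchedulerOutputProcessorMixin, and the several-hundred-line result-processing methods deleted further down in this diff (process_batch_result_prefill, process_batch_result_decode, stream_output, and the logprob helpers) now live in the new sglang/srt/managers/scheduler_output_processor_mixin.py (+614 lines in the file list above). A minimal sketch of the pattern, with simplified stand-in names rather than the real sglang classes:

    class OutputProcessorMixin:
        # Assumes the host class provides the attributes these methods touch
        # (the real mixin relies on the scheduler's tree cache, detokenizer
        # queue, etc.).
        def process_batch_result_decode(self, batch, result):
            for req, token_id in zip(batch.reqs, result.next_token_ids):
                req.output_ids.append(token_id)

    class Scheduler(OutputProcessorMixin):
        def process_batch_result(self, batch, result):
            if batch.forward_mode.is_decode():
                # Inherited from the mixin, so scheduler.py keeps only the
                # scheduling loop itself.
                self.process_batch_result_decode(batch, result)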
@@ -159,6 +160,7 @@ class Scheduler:
159
160
  )
160
161
  self.gpu_id = gpu_id
161
162
  self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
163
+ self.page_size = server_args.page_size
162
164
 
163
165
  # Distributed rank info
164
166
  self.dp_size = server_args.dp_size
@@ -270,17 +272,18 @@ class Scheduler:
270
272
 
271
273
  # Init running status
272
274
  self.waiting_queue: List[Req] = []
273
- self.staging_reqs = {}
274
275
  # The running decoding batch for continuous batching
275
- self.running_batch: Optional[ScheduleBatch] = None
276
+ self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
276
277
  # The current forward batch
277
278
  self.cur_batch: Optional[ScheduleBatch] = None
278
- # The current forward batch
279
+ # The last forward batch
279
280
  self.last_batch: Optional[ScheduleBatch] = None
280
281
  self.forward_ct = 0
281
282
  self.forward_ct_decode = 0
282
283
  self.num_generated_tokens = 0
284
+ self.num_prefill_tokens = 0
283
285
  self.last_decode_stats_tic = time.time()
286
+ self.last_prefill_stats_tic = time.time()
284
287
  self.return_health_check_ct = 0
285
288
  self.current_stream = torch.get_device_module(self.device).current_stream()
286
289
  if self.device == "cpu":
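
Note how `self.running_batch` changes here from an Optional that is None when the server is idle to an always-present empty ScheduleBatch, which also absorbs the old scheduler-level `batch_is_full` flag; the later hunks therefore replace `is None` checks with `is_empty()`. A small self-contained sketch of this null-object style (MiniBatch is an illustrative stand-in, not the real ScheduleBatch):

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class MiniBatch:
        reqs: List[object] = field(default_factory=list)
        batch_is_full: bool = False

        def is_empty(self) -> bool:
            return len(self.reqs) == 0

    running_batch = MiniBatch()           # was: running_batch = None
    if not running_batch.is_empty():      # was: if running_batch is not None
        pass                              # decode path
    running_batch.batch_is_full = False   # the flag now lives on the batch itself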
@@ -308,7 +311,11 @@ class Scheduler:
308
311
  self.grammar_backend = None
309
312
 
310
313
  # Init schedule policy and new token estimation
311
- self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
314
+ self.policy = SchedulePolicy(
315
+ self.schedule_policy,
316
+ self.tree_cache,
317
+ self.enable_hierarchical_cache,
318
+ )
312
319
  assert (
313
320
  server_args.schedule_conservativeness >= 0
314
321
  ), "Invalid schedule_conservativeness"
@@ -327,11 +334,6 @@ class Scheduler:
327
334
  ) / global_config.default_new_token_ratio_decay_steps
328
335
  self.new_token_ratio = self.init_new_token_ratio
329
336
 
330
- # Tell whether the current running batch is full so that we can skip
331
- # the check of whether to prefill new requests.
332
- # This is an optimization to reduce the overhead of the prefill check.
333
- self.batch_is_full = False
334
-
335
337
  # Init watchdog thread
336
338
  self.watchdog_timeout = server_args.watchdog_timeout
337
339
  t = threading.Thread(target=self.watchdog_thread, daemon=True)
@@ -431,11 +433,14 @@ class Scheduler:
431
433
  self.tree_cache = HiRadixCache(
432
434
  req_to_token_pool=self.req_to_token_pool,
433
435
  token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
436
+ tp_cache_group=self.tp_worker.get_tp_cpu_group(),
437
+ page_size=self.page_size,
434
438
  )
435
439
  else:
436
440
  self.tree_cache = RadixCache(
437
441
  req_to_token_pool=self.req_to_token_pool,
438
442
  token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
443
+ page_size=self.page_size,
439
444
  disable=server_args.disable_radix_cache,
440
445
  )
441
446
 
@@ -457,6 +462,7 @@ class Scheduler:
457
462
  # The largest context length (prefill + generation) of a single request
458
463
  self._largest_prefill_decode_len: int = 0
459
464
  self.last_gen_throughput: float = 0.0
465
+ self.last_input_throughput: float = 0.0
460
466
  self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
461
467
  self.spec_num_total_accepted_tokens = 0
462
468
  self.spec_num_total_forward_ct = 0
@@ -486,7 +492,7 @@ class Scheduler:
486
492
  result = self.run_batch(batch)
487
493
  self.process_batch_result(batch, result)
488
494
  else:
489
- # When the server is idle, so self-check and re-init some states
495
+ # When the server is idle, do self-check and re-init some states
490
496
  self.check_memory()
491
497
  self.new_token_ratio = self.init_new_token_ratio
492
498
 
@@ -526,7 +532,7 @@ class Scheduler:
526
532
  )
527
533
  self.process_batch_result(tmp_batch, tmp_result)
528
534
  elif batch is None:
529
- # When the server is idle, so self-check and re-init some states
535
+ # When the server is idle, do self-check and re-init some states
530
536
  self.check_memory()
531
537
  self.new_token_ratio = self.init_new_token_ratio
532
538
 
@@ -587,7 +593,7 @@ class Scheduler:
587
593
  for recv_req in recv_reqs:
588
594
  # If it is a health check generation request and there are running requests, ignore it.
589
595
  if is_health_check_generate_req(recv_req) and (
590
- self.chunked_req is not None or self.running_batch is not None
596
+ self.chunked_req is not None or not self.running_batch.is_empty()
591
597
  ):
592
598
  self.return_health_check_ct += 1
593
599
  continue
@@ -767,6 +773,30 @@ class Scheduler:
767
773
  )
768
774
  req.tokenizer = self.tokenizer
769
775
 
776
+ # Handle multimodal inputs
777
+ if recv_req.image_inputs is not None:
778
+ image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
779
+ # Expand a single image token into multiple dummy tokens for receiving image embeddings
780
+ req.origin_input_ids = self.pad_input_ids_func(
781
+ req.origin_input_ids, image_inputs
782
+ )
783
+ req.extend_image_inputs(image_inputs)
784
+
785
+ if len(req.origin_input_ids) >= self.max_req_input_len:
786
+ error_msg = (
787
+ "Multimodal prompt is too long after expanding multimodal tokens. "
788
+ f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
789
+ )
790
+ logger.error(error_msg)
791
+ req.origin_input_ids = [0]
792
+ req.image_inputs = None
793
+ req.sampling_params.max_new_tokens = 0
794
+ req.finished_reason = FINISH_ABORT(
795
+ error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
796
+ )
797
+ self.waiting_queue.append(req)
798
+ return
799
+
770
800
  # Validate prompts length
771
801
  error_msg = validate_input_length(
772
802
  req,
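
The new multimodal branch above calls `pad_input_ids_func` to expand a single image placeholder token into enough dummy tokens to hold the image embeddings, then aborts the request if the expanded prompt no longer fits. A sketch of what such an expansion typically looks like (the function and parameter names here are illustrative, not sglang's actual API):

    def pad_input_ids(input_ids, image_token_id, num_image_embeddings, pad_token_id=0):
        # Replace each image placeholder with a run of dummy tokens; the image
        # embeddings are later written into these positions.
        out = []
        for tok in input_ids:
            if tok == image_token_id:
                out.extend([pad_token_id] * num_image_embeddings)
            else:
                out.append(tok)
        return out

    # e.g. a 3-token prompt with one image token and 4 image embeddings -> 6 tokens
    assert len(pad_input_ids([1, 99, 2], image_token_id=99, num_image_embeddings=4)) == 6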
@@ -787,6 +817,11 @@ class Scheduler:
787
817
  can_run_list: List[Req],
788
818
  running_bs: int,
789
819
  ):
820
+ gap_latency = time.time() - self.last_prefill_stats_tic
821
+ self.last_prefill_stats_tic = time.time()
822
+ self.last_input_throughput = self.num_prefill_tokens / gap_latency
823
+ self.num_prefill_tokens = 0
824
+
790
825
  num_used = self.max_total_num_tokens - (
791
826
  self.token_to_kv_pool_allocator.available_size()
792
827
  + self.tree_cache.evictable_size()
@@ -822,7 +857,7 @@ class Scheduler:
822
857
  self.last_decode_stats_tic = time.time()
823
858
  self.last_gen_throughput = self.num_generated_tokens / gap_latency
824
859
  self.num_generated_tokens = 0
825
- num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
860
+ num_running_reqs = len(self.running_batch.reqs)
826
861
  num_used = self.max_total_num_tokens - (
827
862
  self.token_to_kv_pool_allocator.available_size()
828
863
  + self.tree_cache.evictable_size()
@@ -886,8 +921,10 @@ class Scheduler:
886
921
  )
887
922
  if memory_leak:
888
923
  msg = (
889
- "KV cache pool leak detected!"
924
+ "KV cache pool leak detected! "
890
925
  f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
926
+ f"{self.token_to_kv_pool_allocator.available_size()=}\n"
927
+ f"{self.tree_cache.evictable_size()=}\n"
891
928
  )
892
929
  warnings.warn(msg)
893
930
  if crash_on_warnings():
@@ -913,7 +950,7 @@ class Scheduler:
913
950
  self.token_to_kv_pool_allocator.available_size()
914
951
  + self.tree_cache.evictable_size()
915
952
  )
916
- num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
953
+ num_running_reqs = len(self.running_batch.reqs)
917
954
  self.stats.num_running_reqs = num_running_reqs
918
955
  self.stats.num_used_tokens = num_used
919
956
  self.stats.token_usage = num_used / self.max_total_num_tokens
@@ -931,14 +968,20 @@ class Scheduler:
931
968
  self.tree_cache.cache_unfinished_req(self.chunked_req)
932
969
  # chunked request keeps its rid but will get a new req_pool_idx
933
970
  self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
934
- self.batch_is_full = False
971
+ self.running_batch.batch_is_full = False
935
972
 
973
+ # Filter batch
974
+ last_bs = self.last_batch.batch_size()
936
975
  self.last_batch.filter_batch()
976
+ if self.last_batch.batch_size() < last_bs:
977
+ self.running_batch.batch_is_full = False
978
+
979
+ # Merge the new batch into the running batch
937
980
  if not self.last_batch.is_empty():
938
- if self.running_batch is None:
981
+ if self.running_batch.is_empty():
939
982
  self.running_batch = self.last_batch
940
983
  else:
941
- # merge running_batch with prefill batch
984
+ # Merge running_batch with prefill batch
942
985
  self.running_batch.merge_batch(self.last_batch)
943
986
 
944
987
  new_batch = self.get_new_batch_prefill()
@@ -947,15 +990,15 @@ class Scheduler:
947
990
  ret = new_batch
948
991
  else:
949
992
  # Run decode
950
- if self.running_batch is None:
951
- ret = None
952
- else:
993
+ if not self.running_batch.is_empty():
953
994
  self.running_batch = self.update_running_batch(self.running_batch)
954
- ret = self.running_batch
995
+ ret = self.running_batch if not self.running_batch.is_empty() else None
996
+ else:
997
+ ret = None
955
998
 
956
999
  # Handle DP attention
957
1000
  if self.server_args.enable_dp_attention:
958
- ret = self.prepare_dp_attn_batch(ret)
1001
+ ret, _ = self.prepare_dp_attn_batch(ret)
959
1002
 
960
1003
  return ret
961
1004
 
@@ -966,15 +1009,20 @@ class Scheduler:
966
1009
 
967
1010
  # Handle the cases where prefill is not allowed
968
1011
  if (
969
- self.batch_is_full or len(self.waiting_queue) == 0
1012
+ self.running_batch.batch_is_full or len(self.waiting_queue) == 0
970
1013
  ) and self.chunked_req is None:
971
1014
  return None
972
1015
 
973
- running_bs = len(self.running_batch.reqs) if self.running_batch else 0
1016
+ running_bs = len(self.running_batch.reqs)
974
1017
  if running_bs >= self.max_running_requests:
975
- self.batch_is_full = True
1018
+ self.running_batch.batch_is_full = True
976
1019
  return None
977
1020
 
1021
+ if self.enable_hierarchical_cache:
1022
+ # check for completion of hierarchical cache activities to release memory
1023
+ self.tree_cache.writing_check()
1024
+ self.tree_cache.loading_check()
1025
+
978
1026
  # Get priority queue
979
1027
  prefix_computed = self.policy.calc_priority(self.waiting_queue)
980
1028
 
@@ -989,17 +1037,13 @@ class Scheduler:
989
1037
  running_bs if self.is_mixed_chunk else 0,
990
1038
  )
991
1039
 
992
- is_chunked = self.chunked_req is not None
993
- if is_chunked:
1040
+ if self.chunked_req is not None:
994
1041
  self.chunked_req.init_next_round_input()
995
1042
  self.chunked_req = adder.add_chunked_req(self.chunked_req)
996
1043
 
997
1044
  if self.lora_paths:
998
- lora_set = (
999
- set([req.lora_path for req in self.running_batch.reqs])
1000
- if self.running_batch is not None
1001
- else set([])
1002
- )
1045
+ lora_set = set([req.lora_path for req in self.running_batch.reqs])
1046
+
1003
1047
  # Get requests from the waiting queue to a new prefill batch
1004
1048
  for req in self.waiting_queue:
1005
1049
  if (
@@ -1011,49 +1055,33 @@ class Scheduler:
1011
1055
  )
1012
1056
  > self.max_loras_per_batch
1013
1057
  ):
1014
- self.batch_is_full = True
1058
+ self.running_batch.batch_is_full = True
1015
1059
  break
1016
1060
 
1017
1061
  if running_bs + len(adder.can_run_list) >= self.max_running_requests:
1018
- self.batch_is_full = True
1062
+ self.running_batch.batch_is_full = True
1019
1063
  break
1020
1064
 
1021
- req.init_next_round_input(None if prefix_computed else self.tree_cache)
1065
+ req.init_next_round_input(
1066
+ None if prefix_computed else self.tree_cache,
1067
+ self.enable_hierarchical_cache,
1068
+ )
1022
1069
 
1023
- if self.enable_hierarchical_cache and req.last_node is not None:
1024
- if req.last_node.evicted:
1025
- # loading KV cache for the request
1026
- req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
1027
- req.last_node,
1028
- req.prefix_indices,
1029
- adder.rem_total_tokens,
1030
- )
1031
- if req.last_node.loading:
1032
- # to prevent frequent cache invalidation
1033
- if req.rid in self.staging_reqs:
1034
- self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
1035
- self.tree_cache.inc_lock_ref(req.last_node)
1036
- self.staging_reqs[req.rid] = req.last_node
1037
- continue
1038
- elif req.last_node.loading:
1039
- if not self.tree_cache.loading_complete(req.last_node):
1040
- continue
1041
-
1042
- if req.rid in self.staging_reqs:
1043
- self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
1044
- del self.staging_reqs[req.rid]
1045
-
1046
- res = adder.add_one_req(req, self.chunked_req)
1070
+ res = adder.add_one_req(
1071
+ req, self.chunked_req, self.enable_hierarchical_cache
1072
+ )
1047
1073
  if res != AddReqResult.CONTINUE:
1048
1074
  if res == AddReqResult.NO_TOKEN:
1049
1075
  if self.enable_hierarchical_cache:
1050
1076
  # Set batch_is_full after making sure there are requests that can be served
1051
- self.batch_is_full = len(adder.can_run_list) > 0 or (
1077
+ self.running_batch.batch_is_full = len(
1078
+ adder.can_run_list
1079
+ ) > 0 or (
1052
1080
  self.running_batch is not None
1053
1081
  and not self.running_batch.is_empty()
1054
1082
  )
1055
1083
  else:
1056
- self.batch_is_full = True
1084
+ self.running_batch.batch_is_full = True
1057
1085
  break
1058
1086
 
1059
1087
  # Update waiting queue
@@ -1064,6 +1092,9 @@ class Scheduler:
1064
1092
  x for x in self.waiting_queue if x not in set(can_run_list)
1065
1093
  ]
1066
1094
 
1095
+ if self.enable_hierarchical_cache:
1096
+ self.tree_cache.read_to_load_cache()
1097
+
1067
1098
  if adder.new_chunked_req is not None:
1068
1099
  assert self.chunked_req is None
1069
1100
  self.chunked_req = adder.new_chunked_req
@@ -1091,7 +1122,7 @@ class Scheduler:
1091
1122
  # Mixed-style chunked prefill
1092
1123
  if (
1093
1124
  self.is_mixed_chunk
1094
- and self.running_batch is not None
1125
+ and not self.running_batch.is_empty()
1095
1126
  and not (new_batch.return_logprob or self.running_batch.return_logprob)
1096
1127
  ):
1097
1128
  # TODO (lianmin): support return_logprob + mixed chunked prefill
@@ -1100,7 +1131,9 @@ class Scheduler:
1100
1131
  self.running_batch.prepare_for_decode()
1101
1132
  new_batch.mix_with_running(self.running_batch)
1102
1133
  new_batch.decoding_reqs = self.running_batch.reqs
1103
- self.running_batch = None
1134
+ self.running_batch = ScheduleBatch(
1135
+ reqs=[], batch_is_full=self.running_batch.batch_is_full
1136
+ )
1104
1137
  else:
1105
1138
  new_batch.decoding_reqs = None
1106
1139
 
@@ -1112,8 +1145,8 @@ class Scheduler:
1112
1145
 
1113
1146
  batch.filter_batch()
1114
1147
  if batch.is_empty():
1115
- self.batch_is_full = False
1116
- return None
1148
+ batch.batch_is_full = False
1149
+ return batch
1117
1150
 
1118
1151
  # Check if decode out of memory
1119
1152
  if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
@@ -1137,7 +1170,7 @@ class Scheduler:
1137
1170
  )
1138
1171
 
1139
1172
  if batch.batch_size() < initial_bs:
1140
- self.batch_is_full = False
1173
+ batch.batch_is_full = False
1141
1174
 
1142
1175
  # Update batch tensors
1143
1176
  batch.prepare_for_decode()
@@ -1212,8 +1245,6 @@ class Scheduler:
1212
1245
  ):
1213
1246
  if batch.forward_mode.is_decode():
1214
1247
  self.process_batch_result_decode(batch, result)
1215
- if batch.is_empty():
1216
- self.running_batch = None
1217
1248
  elif batch.forward_mode.is_extend():
1218
1249
  self.process_batch_result_prefill(batch, result)
1219
1250
  elif batch.forward_mode.is_idle():
@@ -1235,615 +1266,76 @@ class Scheduler:
1235
1266
  self.return_health_check_ct -= 1
1236
1267
  self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
1237
1268
 
1238
- def process_batch_result_prefill(
1239
- self,
1240
- batch: ScheduleBatch,
1241
- result: Union[GenerationBatchResult, EmbeddingBatchResult],
1242
- ):
1243
- skip_stream_req = None
1244
-
1245
- if self.is_generation:
1246
- (
1247
- logits_output,
1248
- next_token_ids,
1249
- extend_input_len_per_req,
1250
- extend_logprob_start_len_per_req,
1251
- bid,
1252
- ) = (
1253
- result.logits_output,
1254
- result.next_token_ids,
1255
- result.extend_input_len_per_req,
1256
- result.extend_logprob_start_len_per_req,
1257
- result.bid,
1258
- )
1259
-
1260
- if self.enable_overlap:
1261
- logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
1262
- else:
1263
- # Move next_token_ids and logprobs to cpu
1264
- next_token_ids = next_token_ids.tolist()
1265
- if batch.return_logprob:
1266
- if logits_output.next_token_logprobs is not None:
1267
- logits_output.next_token_logprobs = (
1268
- logits_output.next_token_logprobs.tolist()
1269
- )
1270
- if logits_output.input_token_logprobs is not None:
1271
- logits_output.input_token_logprobs = tuple(
1272
- logits_output.input_token_logprobs.tolist()
1273
- )
1274
-
1275
- hidden_state_offset = 0
1276
-
1277
- # Check finish conditions
1278
- logprob_pt = 0
1279
- for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
1280
- if req.is_retracted:
1281
- continue
1282
-
1283
- if self.is_mixed_chunk and self.enable_overlap and req.finished():
1284
- # Free the one delayed token for the mixed decode batch
1285
- j = len(batch.out_cache_loc) - len(batch.reqs) + i
1286
- self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
1287
- continue
1288
-
1289
- if req.is_chunked <= 0:
1290
- # req output_ids are set here
1291
- req.output_ids.append(next_token_id)
1292
- req.check_finished()
1293
-
1294
- if req.finished():
1295
- self.tree_cache.cache_finished_req(req)
1296
- elif not batch.decoding_reqs or req not in batch.decoding_reqs:
1297
- # This updates radix so others can match
1298
- self.tree_cache.cache_unfinished_req(req)
1299
-
1300
- if req.return_logprob:
1301
- assert extend_logprob_start_len_per_req is not None
1302
- assert extend_input_len_per_req is not None
1303
- extend_logprob_start_len = extend_logprob_start_len_per_req[i]
1304
- extend_input_len = extend_input_len_per_req[i]
1305
- num_input_logprobs = extend_input_len - extend_logprob_start_len
1306
- self.add_logprob_return_values(
1307
- i,
1308
- req,
1309
- logprob_pt,
1310
- next_token_ids,
1311
- num_input_logprobs,
1312
- logits_output,
1313
- )
1314
- logprob_pt += num_input_logprobs
1315
-
1316
- if (
1317
- req.return_hidden_states
1318
- and logits_output.hidden_states is not None
1319
- ):
1320
- req.hidden_states.append(
1321
- logits_output.hidden_states[
1322
- hidden_state_offset : (
1323
- hidden_state_offset := hidden_state_offset
1324
- + len(req.origin_input_ids)
1325
- )
1326
- ]
1327
- .cpu()
1328
- .clone()
1329
- )
1330
-
1331
- if req.grammar is not None:
1332
- req.grammar.accept_token(next_token_id)
1333
- req.grammar.finished = req.finished()
1334
- else:
1335
- # being chunked reqs' prefill is not finished
1336
- req.is_chunked -= 1
1337
- # There is only at most one request being currently chunked.
1338
- # Because this request does not finish prefill,
1339
- # we don't want to stream the request currently being chunked.
1340
- skip_stream_req = req
1341
-
1342
- # Incrementally update input logprobs.
1343
- if req.return_logprob:
1344
- extend_logprob_start_len = extend_logprob_start_len_per_req[i]
1345
- extend_input_len = extend_input_len_per_req[i]
1346
- if extend_logprob_start_len < extend_input_len:
1347
- # Update input logprobs.
1348
- num_input_logprobs = (
1349
- extend_input_len - extend_logprob_start_len
1350
- )
1351
- self.add_input_logprob_return_values(
1352
- i,
1353
- req,
1354
- logits_output,
1355
- logprob_pt,
1356
- num_input_logprobs,
1357
- last_prefill_chunk=False,
1358
- )
1359
- logprob_pt += num_input_logprobs
1360
-
1361
- if batch.next_batch_sampling_info:
1362
- batch.next_batch_sampling_info.update_regex_vocab_mask()
1363
- self.current_stream.synchronize()
1364
- batch.next_batch_sampling_info.sampling_info_done.set()
1365
-
1366
- else: # embedding or reward model
1367
- embeddings, bid = result.embeddings, result.bid
1368
- embeddings = embeddings.tolist()
1369
-
1370
- # Check finish conditions
1371
- for i, req in enumerate(batch.reqs):
1372
- if req.is_retracted:
1373
- continue
1374
-
1375
- req.embedding = embeddings[i]
1376
- if req.is_chunked <= 0:
1377
- # Dummy output token for embedding models
1378
- req.output_ids.append(0)
1379
- req.check_finished()
1380
-
1381
- if req.finished():
1382
- self.tree_cache.cache_finished_req(req)
1383
- else:
1384
- self.tree_cache.cache_unfinished_req(req)
1385
- else:
1386
- # being chunked reqs' prefill is not finished
1387
- req.is_chunked -= 1
1388
-
1389
- self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
1390
-
1391
- def process_batch_result_decode(
1392
- self,
1393
- batch: ScheduleBatch,
1394
- result: GenerationBatchResult,
1395
- ):
1396
- logits_output, next_token_ids, bid = (
1397
- result.logits_output,
1398
- result.next_token_ids,
1399
- result.bid,
1400
- )
1401
- self.num_generated_tokens += len(batch.reqs)
1402
-
1403
- if self.enable_overlap:
1404
- assert batch.spec_algorithm.is_none()
1405
- logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
1406
- next_token_logprobs = logits_output.next_token_logprobs
1407
- elif batch.spec_algorithm.is_none():
1408
- # spec decoding handles output logprobs inside verify process.
1409
- next_token_ids = next_token_ids.tolist()
1410
- if batch.return_logprob:
1411
- next_token_logprobs = logits_output.next_token_logprobs.tolist()
1412
-
1413
- self.token_to_kv_pool_allocator.free_group_begin()
1414
-
1415
- # Check finish condition
1416
- # NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
1417
- # We should ignore using next_token_ids for spec decoding cases.
1418
- for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
1419
- if req.is_retracted:
1420
- continue
1421
-
1422
- if self.enable_overlap and req.finished():
1423
- # Free the one delayed token
1424
- self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
1425
- continue
1426
-
1427
- if batch.spec_algorithm.is_none():
1428
- # speculative worker will solve the output_ids in speculative decoding
1429
- req.output_ids.append(next_token_id)
1430
-
1431
- req.check_finished()
1432
- if req.finished():
1433
- self.tree_cache.cache_finished_req(req)
1434
-
1435
- if req.return_logprob and batch.spec_algorithm.is_none():
1436
- # speculative worker handles logprob in speculative decoding
1437
- req.output_token_logprobs_val.append(next_token_logprobs[i])
1438
- req.output_token_logprobs_idx.append(next_token_id)
1439
- if req.top_logprobs_num > 0:
1440
- req.output_top_logprobs_val.append(
1441
- logits_output.next_token_top_logprobs_val[i]
1442
- )
1443
- req.output_top_logprobs_idx.append(
1444
- logits_output.next_token_top_logprobs_idx[i]
1445
- )
1446
- if req.token_ids_logprob is not None:
1447
- req.output_token_ids_logprobs_val.append(
1448
- logits_output.next_token_token_ids_logprobs_val[i]
1449
- )
1450
- req.output_token_ids_logprobs_idx.append(
1451
- logits_output.next_token_token_ids_logprobs_idx[i]
1452
- )
1453
-
1454
- if req.return_hidden_states and logits_output.hidden_states is not None:
1455
- req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
1456
-
1457
- if req.grammar is not None and batch.spec_algorithm.is_none():
1458
- req.grammar.accept_token(next_token_id)
1459
- req.grammar.finished = req.finished()
1460
-
1461
- if batch.next_batch_sampling_info:
1462
- batch.next_batch_sampling_info.update_regex_vocab_mask()
1463
- self.current_stream.synchronize()
1464
- batch.next_batch_sampling_info.sampling_info_done.set()
1465
-
1466
- self.stream_output(batch.reqs, batch.return_logprob)
1467
-
1468
- self.token_to_kv_pool_allocator.free_group_end()
1469
-
1470
- self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
1471
- if (
1472
- self.attn_tp_rank == 0
1473
- and self.forward_ct_decode % self.server_args.decode_log_interval == 0
1474
- ):
1475
- self.log_decode_stats()
1476
-
1477
- def add_input_logprob_return_values(
1478
- self,
1479
- i: int,
1480
- req: Req,
1481
- output: LogitsProcessorOutput,
1482
- logprob_pt: int,
1483
- num_input_logprobs: int,
1484
- last_prefill_chunk: bool, # If True, it means prefill is finished.
1485
- ):
1486
- """Incrementally add input logprobs to `req`.
1487
-
1488
- Args:
1489
- i: The request index in a batch.
1490
- req: The request. Input logprobs inside req are modified as a
1491
- consequence of the API
1492
- fill_ids: The prefill ids processed.
1493
- output: Logit processor output that's used to compute input logprobs
1494
- last_prefill_chunk: True if it is the last prefill (when chunked).
1495
- Some of input logprob operation should only happen at the last
1496
- prefill (e.g., computing input token logprobs).
1497
- """
1498
- assert output.input_token_logprobs is not None
1499
- if req.input_token_logprobs is None:
1500
- req.input_token_logprobs = []
1501
- if req.temp_input_top_logprobs_val is None:
1502
- req.temp_input_top_logprobs_val = []
1503
- if req.temp_input_top_logprobs_idx is None:
1504
- req.temp_input_top_logprobs_idx = []
1505
- if req.temp_input_token_ids_logprobs_val is None:
1506
- req.temp_input_token_ids_logprobs_val = []
1507
- if req.temp_input_token_ids_logprobs_idx is None:
1508
- req.temp_input_token_ids_logprobs_idx = []
1509
-
1510
- if req.input_token_logprobs_val is not None:
1511
- # The input logprob has been already computed. It only happens
1512
- # upon retract.
1513
- if req.top_logprobs_num > 0:
1514
- assert req.input_token_logprobs_val is not None
1515
- return
1516
-
1517
- # Important for the performance.
1518
- assert isinstance(output.input_token_logprobs, tuple)
1519
- input_token_logprobs: Tuple[int] = output.input_token_logprobs
1520
- input_token_logprobs = input_token_logprobs[
1521
- logprob_pt : logprob_pt + num_input_logprobs
1522
- ]
1523
- req.input_token_logprobs.extend(input_token_logprobs)
1524
-
1525
- if req.top_logprobs_num > 0:
1526
- req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i])
1527
- req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i])
1528
-
1529
- if req.token_ids_logprob is not None:
1530
- req.temp_input_token_ids_logprobs_val.append(
1531
- output.input_token_ids_logprobs_val[i]
1532
- )
1533
- req.temp_input_token_ids_logprobs_idx.append(
1534
- output.input_token_ids_logprobs_idx[i]
1535
- )
1536
-
1537
- if last_prefill_chunk:
1538
- input_token_logprobs = req.input_token_logprobs
1539
- req.input_token_logprobs = None
1540
- assert req.input_token_logprobs_val is None
1541
- assert req.input_token_logprobs_idx is None
1542
- assert req.input_top_logprobs_val is None
1543
- assert req.input_top_logprobs_idx is None
1544
-
1545
- # Compute input_token_logprobs_val
1546
- # Always pad the first one with None.
1547
- req.input_token_logprobs_val = [None]
1548
- req.input_token_logprobs_val.extend(input_token_logprobs)
1549
- # The last input logprob is for sampling, so just pop it out.
1550
- req.input_token_logprobs_val.pop()
1551
-
1552
- # Compute input_token_logprobs_idx
1553
- input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
1554
- # Clip the padded hash values from image tokens.
1555
- # Otherwise, it will lead to detokenization errors.
1556
- input_token_logprobs_idx = [
1557
- x if x < self.model_config.vocab_size - 1 else 0
1558
- for x in input_token_logprobs_idx
1559
- ]
1560
- req.input_token_logprobs_idx = input_token_logprobs_idx
1561
-
1562
- if req.top_logprobs_num > 0:
1563
- req.input_top_logprobs_val = [None]
1564
- req.input_top_logprobs_idx = [None]
1565
- assert len(req.temp_input_token_ids_logprobs_val) == len(
1566
- req.temp_input_token_ids_logprobs_idx
1567
- )
1568
- for val, idx in zip(
1569
- req.temp_input_top_logprobs_val,
1570
- req.temp_input_top_logprobs_idx,
1571
- strict=True,
1572
- ):
1573
- req.input_top_logprobs_val.extend(val)
1574
- req.input_top_logprobs_idx.extend(idx)
1575
-
1576
- # Last token is a sample token.
1577
- req.input_top_logprobs_val.pop()
1578
- req.input_top_logprobs_idx.pop()
1579
- req.temp_input_top_logprobs_idx = None
1580
- req.temp_input_top_logprobs_val = None
1581
-
1582
- if req.token_ids_logprob is not None:
1583
- req.input_token_ids_logprobs_val = [None]
1584
- req.input_token_ids_logprobs_idx = [None]
1585
-
1586
- for val, idx in zip(
1587
- req.temp_input_token_ids_logprobs_val,
1588
- req.temp_input_token_ids_logprobs_idx,
1589
- strict=True,
1590
- ):
1591
- req.input_token_ids_logprobs_val.extend(val)
1592
- req.input_token_ids_logprobs_idx.extend(idx)
1593
-
1594
- # Last token is a sample token.
1595
- req.input_token_ids_logprobs_val.pop()
1596
- req.input_token_ids_logprobs_idx.pop()
1597
- req.temp_input_token_ids_logprobs_idx = None
1598
- req.temp_input_token_ids_logprobs_val = None
1599
-
1600
- if req.return_logprob:
1601
- relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
1602
- assert len(req.input_token_logprobs_val) == relevant_tokens_len
1603
- assert len(req.input_token_logprobs_idx) == relevant_tokens_len
1604
- if req.top_logprobs_num > 0:
1605
- assert len(req.input_top_logprobs_val) == relevant_tokens_len
1606
- assert len(req.input_top_logprobs_idx) == relevant_tokens_len
1607
- if req.token_ids_logprob is not None:
1608
- assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len
1609
- assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len
1610
-
1611
- def add_logprob_return_values(
1612
- self,
1613
- i: int,
1614
- req: Req,
1615
- pt: int,
1616
- next_token_ids: List[int],
1617
- num_input_logprobs: int,
1618
- output: LogitsProcessorOutput,
1619
- ):
1620
- """Attach logprobs to the return values."""
1621
- req.output_token_logprobs_val.append(output.next_token_logprobs[i])
1622
- req.output_token_logprobs_idx.append(next_token_ids[i])
1623
-
1624
- self.add_input_logprob_return_values(
1625
- i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
1626
- )
1627
-
1628
- if req.top_logprobs_num > 0:
1629
- req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
1630
- req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
1631
-
1632
- if req.token_ids_logprob is not None:
1633
- req.output_token_ids_logprobs_val.append(
1634
- output.next_token_token_ids_logprobs_val[i]
1635
- )
1636
- req.output_token_ids_logprobs_idx.append(
1637
- output.next_token_token_ids_logprobs_idx[i]
1638
- )
1639
-
1640
- return num_input_logprobs
1641
-
1642
- def stream_output(
1643
- self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
1644
- ):
1645
- """Stream the output to detokenizer."""
1646
- rids = []
1647
- finished_reasons: List[BaseFinishReason] = []
1648
-
1649
- if self.is_generation:
1650
- decoded_texts = []
1651
- decode_ids_list = []
1652
- read_offsets = []
1653
- output_ids = []
1654
-
1655
- skip_special_tokens = []
1656
- spaces_between_special_tokens = []
1657
- no_stop_trim = []
1658
- prompt_tokens = []
1659
- completion_tokens = []
1660
- cached_tokens = []
1661
- spec_verify_ct = []
1662
- output_hidden_states = None
1663
-
1664
- if return_logprob:
1665
- input_token_logprobs_val = []
1666
- input_token_logprobs_idx = []
1667
- output_token_logprobs_val = []
1668
- output_token_logprobs_idx = []
1669
- input_top_logprobs_val = []
1670
- input_top_logprobs_idx = []
1671
- output_top_logprobs_val = []
1672
- output_top_logprobs_idx = []
1673
- input_token_ids_logprobs_val = []
1674
- input_token_ids_logprobs_idx = []
1675
- output_token_ids_logprobs_val = []
1676
- output_token_ids_logprobs_idx = []
1677
- else:
1678
- input_token_logprobs_val = input_token_logprobs_idx = (
1679
- output_token_logprobs_val
1680
- ) = output_token_logprobs_idx = input_top_logprobs_val = (
1681
- input_top_logprobs_idx
1682
- ) = output_top_logprobs_val = output_top_logprobs_idx = (
1683
- input_token_ids_logprobs_val
1684
- ) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = (
1685
- output_token_ids_logprobs_idx
1686
- ) = None
1687
-
1688
- for req in reqs:
1689
- if req is skip_req:
1690
- continue
1691
-
1692
- # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
1693
- if self.model_config.is_multimodal_gen and req.to_abort:
1694
- continue
1695
-
1696
- if (
1697
- req.finished()
1698
- # If stream, follow the given stream_interval
1699
- or (req.stream and len(req.output_ids) % self.stream_interval == 0)
1700
- # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
1701
- # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
1702
- # always increase one-by-one.
1703
- or (
1704
- not req.stream
1705
- and len(req.output_ids) % 50 == 0
1706
- and not self.model_config.is_multimodal_gen
1707
- )
1708
- ):
1709
- rids.append(req.rid)
1710
- finished_reasons.append(
1711
- req.finished_reason.to_json() if req.finished_reason else None
1712
- )
1713
- decoded_texts.append(req.decoded_text)
1714
- decode_ids, read_offset = req.init_incremental_detokenize()
1715
- decode_ids_list.append(decode_ids)
1716
- read_offsets.append(read_offset)
1717
- if self.skip_tokenizer_init:
1718
- output_ids.append(req.output_ids)
1719
- skip_special_tokens.append(req.sampling_params.skip_special_tokens)
1720
- spaces_between_special_tokens.append(
1721
- req.sampling_params.spaces_between_special_tokens
1722
- )
1723
- no_stop_trim.append(req.sampling_params.no_stop_trim)
1724
-
1725
- prompt_tokens.append(len(req.origin_input_ids))
1726
- completion_tokens.append(len(req.output_ids))
1727
- cached_tokens.append(req.cached_tokens)
1728
-
1729
- if not self.spec_algorithm.is_none():
1730
- spec_verify_ct.append(req.spec_verify_ct)
1731
-
1732
- if return_logprob:
1733
- input_token_logprobs_val.append(req.input_token_logprobs_val)
1734
- input_token_logprobs_idx.append(req.input_token_logprobs_idx)
1735
- output_token_logprobs_val.append(req.output_token_logprobs_val)
1736
- output_token_logprobs_idx.append(req.output_token_logprobs_idx)
1737
- input_top_logprobs_val.append(req.input_top_logprobs_val)
1738
- input_top_logprobs_idx.append(req.input_top_logprobs_idx)
1739
- output_top_logprobs_val.append(req.output_top_logprobs_val)
1740
- output_top_logprobs_idx.append(req.output_top_logprobs_idx)
1741
- input_token_ids_logprobs_val.append(
1742
- req.input_token_ids_logprobs_val
1743
- )
1744
- input_token_ids_logprobs_idx.append(
1745
- req.input_token_ids_logprobs_idx
1746
- )
1747
- output_token_ids_logprobs_val.append(
1748
- req.output_token_ids_logprobs_val
1749
- )
1750
- output_token_ids_logprobs_idx.append(
1751
- req.output_token_ids_logprobs_idx
1752
- )
1753
-
1754
- if req.return_hidden_states:
1755
- if output_hidden_states is None:
1756
- output_hidden_states = []
1757
- output_hidden_states.append(req.hidden_states)
1758
-
1759
- # Send to detokenizer
1760
- if rids:
1761
- if self.model_config.is_multimodal_gen:
1762
- raise NotImplementedError()
1763
- self.send_to_detokenizer.send_pyobj(
1764
- BatchTokenIDOut(
1765
- rids,
1766
- finished_reasons,
1767
- decoded_texts,
1768
- decode_ids_list,
1769
- read_offsets,
1770
- output_ids,
1771
- skip_special_tokens,
1772
- spaces_between_special_tokens,
1773
- no_stop_trim,
1774
- prompt_tokens,
1775
- completion_tokens,
1776
- cached_tokens,
1777
- spec_verify_ct,
1778
- input_token_logprobs_val,
1779
- input_token_logprobs_idx,
1780
- output_token_logprobs_val,
1781
- output_token_logprobs_idx,
1782
- input_top_logprobs_val,
1783
- input_top_logprobs_idx,
1784
- output_top_logprobs_val,
1785
- output_top_logprobs_idx,
1786
- input_token_ids_logprobs_val,
1787
- input_token_ids_logprobs_idx,
1788
- output_token_ids_logprobs_val,
1789
- output_token_ids_logprobs_idx,
1790
- output_hidden_states,
1791
- )
1792
- )
1793
- else: # embedding or reward model
1794
- embeddings = []
1795
- prompt_tokens = []
1796
- cached_tokens = []
1797
- for req in reqs:
1798
- if req.finished():
1799
- rids.append(req.rid)
1800
- finished_reasons.append(req.finished_reason.to_json())
1801
- embeddings.append(req.embedding)
1802
- prompt_tokens.append(len(req.origin_input_ids))
1803
- cached_tokens.append(req.cached_tokens)
1804
- self.send_to_detokenizer.send_pyobj(
1805
- BatchEmbeddingOut(
1806
- rids, finished_reasons, embeddings, prompt_tokens, cached_tokens
1807
- )
1808
- )
1809
-
1810
1269
  def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
1811
1270
  # Check if other DP workers have running batches
1812
1271
  if local_batch is None:
1813
1272
  num_tokens = 0
1273
+ global_num_tokens_for_logprob = 0
1814
1274
  elif local_batch.forward_mode.is_decode():
1815
1275
  num_tokens = local_batch.batch_size()
1276
+ if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
1277
+ num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
1278
+ global_num_tokens_for_logprob = num_tokens
1816
1279
  else:
1817
1280
  num_tokens = local_batch.extend_num_tokens
1281
+ global_num_tokens_for_logprob = sum(
1282
+ [
1283
+ # We should have at least 1 token for sample in every case.
1284
+ max(extend_len - logprob_start_len, 1)
1285
+ for logprob_start_len, extend_len in zip(
1286
+ local_batch.extend_logprob_start_lens, local_batch.extend_lens
1287
+ )
1288
+ ]
1289
+ )
1290
+
1291
+ if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
1292
+ can_cuda_graph = 1
1293
+ else:
1294
+ can_cuda_graph = 0
1818
1295
 
1819
- local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
1820
- global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
1296
+ if not self.spec_algorithm.is_none():
1297
+ # TODO(sang): Support cuda graph when idle batch is there.
1298
+ if local_batch is None or local_batch.forward_mode.is_idle():
1299
+ can_cuda_graph = 0
1300
+
1301
+ is_extend_in_batch = (
1302
+ local_batch.forward_mode.is_extend() if local_batch else False
1303
+ )
1304
+ local_info = torch.tensor(
1305
+ [
1306
+ num_tokens,
1307
+ can_cuda_graph,
1308
+ global_num_tokens_for_logprob,
1309
+ is_extend_in_batch,
1310
+ ],
1311
+ dtype=torch.int64,
1312
+ )
1313
+ global_info = torch.empty(
1314
+ (self.server_args.dp_size, self.attn_tp_size, 4),
1315
+ dtype=torch.int64,
1316
+ )
1821
1317
  torch.distributed.all_gather_into_tensor(
1822
- global_num_tokens,
1823
- local_num_tokens,
1318
+ global_info.flatten(),
1319
+ local_info,
1824
1320
  group=self.tp_cpu_group,
1825
1321
  )
1322
+ global_num_tokens = global_info[:, 0, 0].tolist()
1323
+ can_cuda_graph = min(global_info[:, 0, 1].tolist())
1324
+ global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
1325
+ is_extend_in_batch = global_info[:, 0, 3].tolist()
1826
1326
 
1827
- if local_batch is None and global_num_tokens.max().item() > 0:
1327
+ if local_batch is None and max(global_num_tokens) > 0:
1828
1328
  local_batch = self.get_idle_batch()
1829
1329
 
1830
1330
  if local_batch is not None:
1831
- local_batch.global_num_tokens = global_num_tokens.tolist()
1331
+ local_batch.global_num_tokens = global_num_tokens
1332
+ local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
1832
1333
 
1833
1334
  # Check forward mode for cuda graph
1834
1335
  if not self.server_args.disable_cuda_graph:
1835
- forward_mode_state = torch.tensor(
1836
- (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
1837
- dtype=torch.int32,
1838
- )
1839
- torch.distributed.all_reduce(
1840
- forward_mode_state,
1841
- op=torch.distributed.ReduceOp.MIN,
1842
- group=self.tp_cpu_group,
1843
- )
1844
- local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
1336
+ local_batch.can_run_dp_cuda_graph = can_cuda_graph
1845
1337
 
1846
- return local_batch
1338
+ return local_batch, any(is_extend_in_batch)
1847
1339
 
1848
1340
  def get_idle_batch(self):
1849
1341
  idle_batch = ScheduleBatch.init_new(
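
The rewritten `prepare_dp_attn_batch` above replaces an all_gather of token counts plus a separate all_reduce of the CUDA-graph flag with a single collective: each rank packs four scalars (token count, CUDA-graph eligibility, logprob token count, extend-in-batch flag) into one int64 tensor and all ranks gather them at once. A condensed sketch of that exchange, assuming an already-initialized process group whose world size is dp_size * attn_tp_size (the helper name and argument list are illustrative):

    import torch
    import torch.distributed as dist

    def gather_dp_batch_info(num_tokens, can_cuda_graph, num_logprob_tokens,
                             is_extend_in_batch, group, dp_size, attn_tp_size):
        local_info = torch.tensor(
            [num_tokens, can_cuda_graph, num_logprob_tokens, is_extend_in_batch],
            dtype=torch.int64,
        )
        global_info = torch.empty((dp_size, attn_tp_size, 4), dtype=torch.int64)
        # One collective carries all four fields for every rank.
        dist.all_gather_into_tensor(global_info.flatten(), local_info, group=group)
        return (
            global_info[:, 0, 0].tolist(),       # tokens per DP group
            min(global_info[:, 0, 1].tolist()),  # CUDA graph only if every rank agrees
            global_info[:, 0, 2].tolist(),       # logprob token counts
            any(global_info[:, 0, 3].tolist()),  # does any rank have an extend batch?
        )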
@@ -1926,9 +1418,7 @@ class Scheduler:
1926
1418
 
1927
1419
  def flush_cache(self):
1928
1420
  """Flush the memory pool and cache."""
1929
- if len(self.waiting_queue) == 0 and (
1930
- self.running_batch is None or len(self.running_batch.reqs) == 0
1931
- ):
1421
+ if len(self.waiting_queue) == 0 and self.running_batch.is_empty():
1932
1422
  self.cur_batch = None
1933
1423
  self.last_batch = None
1934
1424
  self.tree_cache.reset()
@@ -1954,7 +1444,7 @@ class Scheduler:
1954
1444
  logging.warning(
1955
1445
  f"Cache not flushed because there are pending requests. "
1956
1446
  f"#queue-req: {len(self.waiting_queue)}, "
1957
- f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
1447
+ f"#running-req: {len(self.running_batch.reqs)}"
1958
1448
  )
1959
1449
  if_success = False
1960
1450
  return if_success
@@ -2004,24 +1494,24 @@ class Scheduler:
2004
1494
 
2005
1495
  def abort_request(self, recv_req: AbortReq):
2006
1496
  # Delete requests in the waiting queue
2007
- to_del = None
1497
+ to_del = []
2008
1498
  for i, req in enumerate(self.waiting_queue):
2009
- if req.rid == recv_req.rid:
2010
- to_del = i
1499
+ if req.rid.startswith(recv_req.rid):
1500
+ to_del.append(i)
2011
1501
  break
2012
1502
 
2013
- if to_del is not None:
2014
- del self.waiting_queue[to_del]
1503
+ # Sort in reverse order to avoid index issues when deleting
1504
+ for i in sorted(to_del, reverse=True):
1505
+ req = self.waiting_queue.pop(i)
2015
1506
  logger.debug(f"Abort queued request. {req.rid=}")
2016
1507
  return
2017
1508
 
2018
1509
  # Delete requests in the running batch
2019
- if self.running_batch:
2020
- for req in self.running_batch.reqs:
2021
- if req.rid == recv_req.rid and not req.finished():
2022
- logger.debug(f"Abort running request. {req.rid=}")
2023
- req.to_abort = True
2024
- break
1510
+ for req in self.running_batch.reqs:
1511
+ if req.rid.startswith(recv_req.rid) and not req.finished():
1512
+ logger.debug(f"Abort running request. {req.rid=}")
1513
+ req.to_abort = True
1514
+ return
2025
1515
 
2026
1516
  def _pause_engine(self) -> Tuple[List[Req], int]:
2027
1517
  raise NotImplementedError()
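
`abort_request` above now matches requests whose rid starts with the incoming id rather than requiring an exact match, so one abort message can cover a group of related sub-requests, and queued matches are popped in reverse index order so earlier deletions do not shift the remaining indices. A tiny illustration of that deletion order (the queue contents and rid prefix here are made up):

    queue = ["abc-0", "xyz-1", "abc-2", "abc-3"]
    to_del = [i for i, rid in enumerate(queue) if rid.startswith("abc")]
    for i in sorted(to_del, reverse=True):   # delete from the back: 3, 2, 0
        queue.pop(i)
    assert queue == ["xyz-1"]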
@@ -2228,9 +1718,16 @@ def run_scheduler_process(
2228
1718
  dp_rank: Optional[int],
2229
1719
  pipe_writer,
2230
1720
  ):
1721
+
1722
+ # Generate the prefix
1723
+ if dp_rank is None:
1724
+ prefix = f" TP{tp_rank}"
1725
+ else:
1726
+ prefix = f" DP{dp_rank} TP{tp_rank}"
1727
+
2231
1728
  # Config the process
2232
1729
  # kill_itself_when_parent_died() # This is disabled because it does not work for `--dp 2`
2233
- setproctitle.setproctitle(f"sglang::scheduler_{dp_rank}")
1730
+ setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
2234
1731
  faulthandler.enable()
2235
1732
  parent_process = psutil.Process().parent()
2236
1733
 
@@ -2239,10 +1736,6 @@ def run_scheduler_process(
2239
1736
  dp_rank = int(os.environ["SGLANG_DP_RANK"])
2240
1737
 
2241
1738
  # Configure the logger
2242
- if dp_rank is None:
2243
- prefix = f" TP{tp_rank}"
2244
- else:
2245
- prefix = f" DP{dp_rank} TP{tp_rank}"
2246
1739
  configure_logger(server_args, prefix=prefix)
2247
1740
  suppress_other_loggers()
2248
1741