sglang 0.4.3.post4__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (124)
  1. sglang/bench_serving.py +1 -1
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/hf_transformers_utils.py +16 -1
  14. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  15. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  16. sglang/srt/layers/attention/triton_backend.py +1 -3
  17. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  18. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  19. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  20. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  21. sglang/srt/layers/attention/vision.py +43 -62
  22. sglang/srt/layers/linear.py +1 -1
  23. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  24. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  32. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  33. sglang/srt/layers/parameter.py +10 -0
  34. sglang/srt/layers/quantization/__init__.py +90 -68
  35. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  36. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  40. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/fp8.py +174 -106
  63. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  64. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  65. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  66. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  67. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  68. sglang/srt/layers/rotary_embedding.py +5 -3
  69. sglang/srt/layers/sampler.py +29 -35
  70. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  71. sglang/srt/lora/backend/__init__.py +9 -12
  72. sglang/srt/managers/cache_controller.py +72 -8
  73. sglang/srt/managers/image_processor.py +37 -631
  74. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  75. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  76. sglang/srt/managers/image_processors/llava.py +152 -0
  77. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  78. sglang/srt/managers/image_processors/mlama.py +60 -0
  79. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  80. sglang/srt/managers/io_struct.py +32 -15
  81. sglang/srt/managers/multi_modality_padding.py +134 -0
  82. sglang/srt/managers/schedule_batch.py +212 -117
  83. sglang/srt/managers/schedule_policy.py +40 -8
  84. sglang/srt/managers/scheduler.py +124 -665
  85. sglang/srt/managers/scheduler_output_processor_mixin.py +611 -0
  86. sglang/srt/managers/tokenizer_manager.py +6 -6
  87. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  88. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  89. sglang/srt/mem_cache/chunk_cache.py +12 -44
  90. sglang/srt/mem_cache/hiradix_cache.py +63 -34
  91. sglang/srt/mem_cache/memory_pool.py +78 -17
  92. sglang/srt/mem_cache/paged_allocator.py +283 -0
  93. sglang/srt/mem_cache/radix_cache.py +117 -36
  94. sglang/srt/model_executor/cuda_graph_runner.py +9 -4
  95. sglang/srt/model_executor/forward_batch_info.py +12 -8
  96. sglang/srt/model_executor/model_runner.py +63 -63
  97. sglang/srt/model_loader/loader.py +2 -1
  98. sglang/srt/model_loader/weight_utils.py +1 -1
  99. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  100. sglang/srt/models/deepseek_nextn.py +23 -3
  101. sglang/srt/models/deepseek_v2.py +25 -19
  102. sglang/srt/models/minicpmv.py +28 -89
  103. sglang/srt/models/mllama.py +1 -1
  104. sglang/srt/models/qwen2.py +0 -1
  105. sglang/srt/models/qwen2_5_vl.py +25 -50
  106. sglang/srt/models/qwen2_vl.py +33 -49
  107. sglang/srt/openai_api/adapter.py +37 -15
  108. sglang/srt/openai_api/protocol.py +8 -1
  109. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  110. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  111. sglang/srt/server_args.py +19 -11
  112. sglang/srt/speculative/eagle_worker.py +75 -39
  113. sglang/srt/utils.py +104 -9
  114. sglang/test/runners.py +104 -10
  115. sglang/test/test_block_fp8.py +106 -16
  116. sglang/test/test_custom_ops.py +88 -0
  117. sglang/test/test_utils.py +20 -4
  118. sglang/utils.py +0 -4
  119. sglang/version.py +1 -1
  120. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/METADATA +9 -10
  121. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/RECORD +124 -79
  122. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/WHEEL +1 -1
  123. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/LICENSE +0 -0
  124. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py

@@ -41,8 +41,6 @@ from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
-    BatchEmbeddingOut,
-    BatchTokenIDOut,
     CloseSessionReqInput,
     FlushCacheReq,
     GetInternalStateReq,
@@ -74,7 +72,6 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
-    BaseFinishReason,
     ImageInputs,
     Req,
     ScheduleBatch,
@@ -85,6 +82,9 @@ from sglang.srt.managers.schedule_policy import (
     PrefillAdder,
     SchedulePolicy,
 )
+from sglang.srt.managers.scheduler_output_processor_mixin import (
+    SchedulerOutputProcessorMixin,
+)
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
@@ -93,7 +93,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -103,6 +103,7 @@ from sglang.srt.utils import (
     crash_on_warnings,
     get_bool_env_var,
     get_zmq_socket,
+    kill_itself_when_parent_died,
     pyspy_dump_schedulers,
     set_gpu_proc_affinity,
     set_random_seed,
@@ -132,7 +133,7 @@ class EmbeddingBatchResult:
     bid: int


-class Scheduler:
+class Scheduler(SchedulerOutputProcessorMixin):
     """A scheduler that manages a tensor parallel GPU worker."""

     def __init__(
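Note: with this change the bulk of the result-handling code moves out of scheduler.py into the new sglang/srt/managers/scheduler_output_processor_mixin.py (see the large removal near the end of this diff), and Scheduler inherits it. A minimal sketch of the mixin pattern, using hypothetical class and method names rather than the real sglang ones:

```python
class OutputProcessorMixin:
    """Holds the result-processing methods formerly defined on the scheduler."""

    def process_batch_result_prefill(self, batch, result):
        # ... handle prefill results ...
        pass

    def process_batch_result_decode(self, batch, result):
        # ... handle decode results ...
        pass


class MiniScheduler(OutputProcessorMixin):
    """The scheduler keeps its scheduling loop and inherits the processors."""

    def step(self, batch, result):
        if batch.is_decode:
            self.process_batch_result_decode(batch, result)
        else:
            self.process_batch_result_prefill(batch, result)
```

The scheduling loop stays in the main class; only where the processing methods are defined changes.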
@@ -159,6 +160,7 @@ class Scheduler:
         )
         self.gpu_id = gpu_id
         self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
+        self.page_size = server_args.page_size

         # Distributed rank info
         self.dp_size = server_args.dp_size
@@ -270,17 +272,18 @@ class Scheduler:

         # Init running status
         self.waiting_queue: List[Req] = []
-        self.staging_reqs = {}
         # The running decoding batch for continuous batching
-        self.running_batch: Optional[ScheduleBatch] = None
+        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
         # The current forward batch
         self.cur_batch: Optional[ScheduleBatch] = None
-        # The current forward batch
+        # The last forward batch
         self.last_batch: Optional[ScheduleBatch] = None
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
+        self.num_prefill_tokens = 0
         self.last_decode_stats_tic = time.time()
+        self.last_prefill_stats_tic = time.time()
         self.return_health_check_ct = 0
         self.current_stream = torch.get_device_module(self.device).current_stream()
         if self.device == "cpu":
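Note: from here on, self.running_batch is never None. An always-present empty ScheduleBatch acts as a null object and also carries the batch_is_full flag that previously lived on the scheduler itself, so the `is None` checks later in this diff become `is_empty()` calls. A rough sketch of the pattern with simplified types (not the real ScheduleBatch):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Batch:
    reqs: List[str] = field(default_factory=list)
    batch_is_full: bool = False  # flag moved off the scheduler onto the batch

    def is_empty(self) -> bool:
        return len(self.reqs) == 0


# Before: running_batch: Optional[Batch] = None, guarded by `if running_batch is None`.
# After: always a Batch object, guarded by `is_empty()`.
running_batch = Batch()
num_running_reqs = len(running_batch.reqs)  # no None check needed
if not running_batch.is_empty():
    ...
```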
@@ -308,7 +311,11 @@ class Scheduler:
             self.grammar_backend = None

         # Init schedule policy and new token estimation
-        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
+        self.policy = SchedulePolicy(
+            self.schedule_policy,
+            self.tree_cache,
+            self.enable_hierarchical_cache,
+        )
         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
@@ -327,11 +334,6 @@ class Scheduler:
         ) / global_config.default_new_token_ratio_decay_steps
         self.new_token_ratio = self.init_new_token_ratio

-        # Tell whether the current running batch is full so that we can skip
-        # the check of whether to prefill new requests.
-        # This is an optimization to reduce the overhead of the prefill check.
-        self.batch_is_full = False
-
         # Init watchdog thread
         self.watchdog_timeout = server_args.watchdog_timeout
         t = threading.Thread(target=self.watchdog_thread, daemon=True)
@@ -431,11 +433,13 @@ class Scheduler:
             self.tree_cache = HiRadixCache(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                tp_cache_group=self.tp_worker.get_tp_cpu_group(),
             )
         else:
             self.tree_cache = RadixCache(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                page_size=self.page_size,
                 disable=server_args.disable_radix_cache,
             )

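Note: RadixCache now receives the configured page size (the file list above also adds sglang/srt/mem_cache/paged_allocator.py). This diff does not show the cache internals; as a rough illustration of what page alignment usually implies, a matched prefix can only be reused in whole pages. The helper below is hypothetical, not sglang API:

```python
def usable_prefix_len(matched_len: int, page_size: int) -> int:
    """Round a matched prefix down to a whole number of KV-cache pages."""
    return (matched_len // page_size) * page_size


assert usable_prefix_len(37, page_size=16) == 32
assert usable_prefix_len(15, page_size=16) == 0
```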
@@ -457,6 +461,7 @@ class Scheduler:
         # The largest context length (prefill + generation) of a single request
         self._largest_prefill_decode_len: int = 0
         self.last_gen_throughput: float = 0.0
+        self.last_input_throughput: float = 0.0
         self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
         self.spec_num_total_accepted_tokens = 0
         self.spec_num_total_forward_ct = 0
@@ -486,7 +491,7 @@ class Scheduler:
             result = self.run_batch(batch)
             self.process_batch_result(batch, result)
         else:
-            # When the server is idle, so self-check and re-init some states
+            # When the server is idle, do self-check and re-init some states
             self.check_memory()
             self.new_token_ratio = self.init_new_token_ratio

@@ -526,7 +531,7 @@ class Scheduler:
                 )
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
-                # When the server is idle, so self-check and re-init some states
+                # When the server is idle, do self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio

@@ -587,7 +592,7 @@ class Scheduler:
         for recv_req in recv_reqs:
             # If it is a health check generation request and there are running requests, ignore it.
             if is_health_check_generate_req(recv_req) and (
-                self.chunked_req is not None or self.running_batch is not None
+                self.chunked_req is not None or not self.running_batch.is_empty()
             ):
                 self.return_health_check_ct += 1
                 continue
@@ -767,6 +772,30 @@ class Scheduler:
         )
         req.tokenizer = self.tokenizer

+        # Handle multimodal inputs
+        if recv_req.image_inputs is not None:
+            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
+            # Expand a single image token into multiple dummy tokens for receiving image embeddings
+            req.origin_input_ids = self.pad_input_ids_func(
+                req.origin_input_ids, image_inputs
+            )
+            req.extend_image_inputs(image_inputs)
+
+            if len(req.origin_input_ids) >= self.max_req_input_len:
+                error_msg = (
+                    "Multimodal prompt is too long after expanding multimodal tokens. "
+                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
+                )
+                logger.error(error_msg)
+                req.origin_input_ids = [0]
+                req.image_inputs = None
+                req.sampling_params.max_new_tokens = 0
+                req.finished_reason = FINISH_ABORT(
+                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
+                )
+                self.waiting_queue.append(req)
+                return
+
         # Validate prompts length
         error_msg = validate_input_length(
             req,
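Note: the new block above expands multimodal placeholders before the prompt-length check, so a prompt that only becomes too long after image-token expansion is rejected with a 400 error instead of being scheduled. A toy illustration of the expansion step (hypothetical placeholder id and padding function, not the real image processors):

```python
IMAGE_TOKEN = -1  # hypothetical placeholder id


def pad_input_ids(input_ids, num_image_embeddings, pad_id=0):
    """Expand each image placeholder into `num_image_embeddings` dummy tokens."""
    out = []
    for tok in input_ids:
        if tok == IMAGE_TOKEN:
            out.extend([pad_id] * num_image_embeddings)
        else:
            out.append(tok)
    return out


prompt = [101, IMAGE_TOKEN, 2009]
expanded = pad_input_ids(prompt, num_image_embeddings=576)
assert len(expanded) == 2 + 576  # may now exceed max_req_input_len and be rejected
```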
@@ -787,6 +816,11 @@ class Scheduler:
         can_run_list: List[Req],
         running_bs: int,
     ):
+        gap_latency = time.time() - self.last_prefill_stats_tic
+        self.last_prefill_stats_tic = time.time()
+        self.last_input_throughput = self.num_prefill_tokens / gap_latency
+        self.num_prefill_tokens = 0
+
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
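Note: prefill logging now also reports input throughput, computed like the existing generation throughput: prefill tokens accumulated since the last prefill log, divided by the elapsed wall-clock time, after which the counter resets. A minimal sketch of the bookkeeping:

```python
import time

# Minimal sketch of the throughput bookkeeping added above.
num_prefill_tokens = 8192                     # tokens prefilled since the last log line
last_prefill_stats_tic = time.time() - 2.0    # pretend the last log was 2 s ago

gap_latency = time.time() - last_prefill_stats_tic
last_input_throughput = num_prefill_tokens / gap_latency  # roughly 4096 tokens/s
num_prefill_tokens = 0                        # reset for the next interval
```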
@@ -822,7 +856,7 @@ class Scheduler:
         self.last_decode_stats_tic = time.time()
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
         self.num_generated_tokens = 0
-        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
+        num_running_reqs = len(self.running_batch.reqs)
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
@@ -886,8 +920,10 @@ class Scheduler:
         )
         if memory_leak:
             msg = (
-                "KV cache pool leak detected!"
+                "KV cache pool leak detected! "
                 f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
+                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
+                f"{self.tree_cache.evictable_size()=}\n"
             )
             warnings.warn(msg)
             if crash_on_warnings():
@@ -913,7 +949,7 @@ class Scheduler:
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
         )
-        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
+        num_running_reqs = len(self.running_batch.reqs)
         self.stats.num_running_reqs = num_running_reqs
         self.stats.num_used_tokens = num_used
         self.stats.token_usage = num_used / self.max_total_num_tokens
@@ -931,14 +967,20 @@ class Scheduler:
                 self.tree_cache.cache_unfinished_req(self.chunked_req)
                 # chunked request keeps its rid but will get a new req_pool_idx
                 self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
-                self.batch_is_full = False
+                self.running_batch.batch_is_full = False

+            # Filter batch
+            last_bs = self.last_batch.batch_size()
             self.last_batch.filter_batch()
+            if self.last_batch.batch_size() < last_bs:
+                self.running_batch.batch_is_full = False
+
+            # Merge the new batch into the running batch
             if not self.last_batch.is_empty():
-                if self.running_batch is None:
+                if self.running_batch.is_empty():
                     self.running_batch = self.last_batch
                 else:
-                    # merge running_batch with prefill batch
+                    # Merge running_batch with prefill batch
                     self.running_batch.merge_batch(self.last_batch)

         new_batch = self.get_new_batch_prefill()
@@ -947,11 +989,11 @@ class Scheduler:
             ret = new_batch
         else:
             # Run decode
-            if self.running_batch is None:
-                ret = None
-            else:
+            if not self.running_batch.is_empty():
                 self.running_batch = self.update_running_batch(self.running_batch)
-                ret = self.running_batch
+                ret = self.running_batch if not self.running_batch.is_empty() else None
+            else:
+                ret = None

         # Handle DP attention
         if self.server_args.enable_dp_attention:
@@ -966,15 +1008,20 @@ class Scheduler:

         # Handle the cases where prefill is not allowed
         if (
-            self.batch_is_full or len(self.waiting_queue) == 0
+            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
         ) and self.chunked_req is None:
             return None

-        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
+        running_bs = len(self.running_batch.reqs)
         if running_bs >= self.max_running_requests:
-            self.batch_is_full = True
+            self.running_batch.batch_is_full = True
             return None

+        if self.enable_hierarchical_cache:
+            # check for completion of hierarchical cache activities to release memory
+            self.tree_cache.writing_check()
+            self.tree_cache.loading_check()
+
         # Get priority queue
         prefix_computed = self.policy.calc_priority(self.waiting_queue)

@@ -989,17 +1036,13 @@ class Scheduler:
             running_bs if self.is_mixed_chunk else 0,
         )

-        is_chunked = self.chunked_req is not None
-        if is_chunked:
+        if self.chunked_req is not None:
             self.chunked_req.init_next_round_input()
             self.chunked_req = adder.add_chunked_req(self.chunked_req)

         if self.lora_paths:
-            lora_set = (
-                set([req.lora_path for req in self.running_batch.reqs])
-                if self.running_batch is not None
-                else set([])
-            )
+            lora_set = set([req.lora_path for req in self.running_batch.reqs])
+
         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
@@ -1011,49 +1054,33 @@ class Scheduler:
                 )
                 > self.max_loras_per_batch
             ):
-                self.batch_is_full = True
+                self.running_batch.batch_is_full = True
                 break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
-                self.batch_is_full = True
+                self.running_batch.batch_is_full = True
                 break

-            req.init_next_round_input(None if prefix_computed else self.tree_cache)
+            req.init_next_round_input(
+                None if prefix_computed else self.tree_cache,
+                self.enable_hierarchical_cache,
+            )

-            if self.enable_hierarchical_cache and req.last_node is not None:
-                if req.last_node.evicted:
-                    # loading KV cache for the request
-                    req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
-                        req.last_node,
-                        req.prefix_indices,
-                        adder.rem_total_tokens,
-                    )
-                    if req.last_node.loading:
-                        # to prevent frequent cache invalidation
-                        if req.rid in self.staging_reqs:
-                            self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
-                        self.tree_cache.inc_lock_ref(req.last_node)
-                        self.staging_reqs[req.rid] = req.last_node
-                        continue
-                elif req.last_node.loading:
-                    if not self.tree_cache.loading_complete(req.last_node):
-                        continue
-
-                if req.rid in self.staging_reqs:
-                    self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
-                    del self.staging_reqs[req.rid]
-
-            res = adder.add_one_req(req, self.chunked_req)
+            res = adder.add_one_req(
+                req, self.chunked_req, self.enable_hierarchical_cache
+            )
             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
                     if self.enable_hierarchical_cache:
                         # Set batch_is_full after making sure there are requests that can be served
-                        self.batch_is_full = len(adder.can_run_list) > 0 or (
+                        self.running_batch.batch_is_full = len(
+                            adder.can_run_list
+                        ) > 0 or (
                             self.running_batch is not None
                             and not self.running_batch.is_empty()
                         )
                     else:
-                        self.batch_is_full = True
+                        self.running_batch.batch_is_full = True
                 break

         # Update waiting queue
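Note: the LoRA guard above now reads the adapter set directly from the always-present running batch; a request is admitted only if the number of distinct lora_path values across running and candidate requests stays within max_loras_per_batch. A simplified sketch of that check (hypothetical helper, not the sglang implementation):

```python
def exceeds_lora_slots(running_paths, candidate_paths, new_path, max_loras_per_batch):
    """True if admitting `new_path` would need more distinct adapters than allowed."""
    distinct = set(running_paths) | set(candidate_paths) | {new_path}
    return len(distinct) > max_loras_per_batch


assert not exceeds_lora_slots(["a"], ["b"], "a", max_loras_per_batch=2)
assert exceeds_lora_slots(["a"], ["b"], "c", max_loras_per_batch=2)
```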
@@ -1064,6 +1091,9 @@ class Scheduler:
             x for x in self.waiting_queue if x not in set(can_run_list)
         ]

+        if self.enable_hierarchical_cache:
+            self.tree_cache.read_to_load_cache()
+
         if adder.new_chunked_req is not None:
             assert self.chunked_req is None
             self.chunked_req = adder.new_chunked_req
@@ -1091,7 +1121,7 @@ class Scheduler:
         # Mixed-style chunked prefill
         if (
             self.is_mixed_chunk
-            and self.running_batch is not None
+            and not self.running_batch.is_empty()
             and not (new_batch.return_logprob or self.running_batch.return_logprob)
         ):
             # TODO (lianmin): support return_logprob + mixed chunked prefill
@@ -1100,7 +1130,9 @@ class Scheduler:
             self.running_batch.prepare_for_decode()
             new_batch.mix_with_running(self.running_batch)
             new_batch.decoding_reqs = self.running_batch.reqs
-            self.running_batch = None
+            self.running_batch = ScheduleBatch(
+                reqs=[], batch_is_full=self.running_batch.batch_is_full
+            )
         else:
             new_batch.decoding_reqs = None

@@ -1112,8 +1144,8 @@ class Scheduler:

         batch.filter_batch()
         if batch.is_empty():
-            self.batch_is_full = False
-            return None
+            batch.batch_is_full = False
+            return batch

         # Check if decode out of memory
         if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
@@ -1137,7 +1169,7 @@ class Scheduler:
             )

         if batch.batch_size() < initial_bs:
-            self.batch_is_full = False
+            batch.batch_is_full = False

         # Update batch tensors
         batch.prepare_for_decode()
@@ -1212,8 +1244,6 @@ class Scheduler:
     ):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
-            if batch.is_empty():
-                self.running_batch = None
         elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_idle():
@@ -1235,578 +1265,6 @@ class Scheduler:
             self.return_health_check_ct -= 1
             self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

-    def process_batch_result_prefill(
-        self,
-        batch: ScheduleBatch,
-        result: Union[GenerationBatchResult, EmbeddingBatchResult],
-    ):
-        skip_stream_req = None
-
-        if self.is_generation:
-            (
-                logits_output,
-                next_token_ids,
-                extend_input_len_per_req,
-                extend_logprob_start_len_per_req,
-                bid,
-            ) = (
-                result.logits_output,
-                result.next_token_ids,
-                result.extend_input_len_per_req,
-                result.extend_logprob_start_len_per_req,
-                result.bid,
-            )
-
-            if self.enable_overlap:
-                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
-            else:
-                # Move next_token_ids and logprobs to cpu
-                next_token_ids = next_token_ids.tolist()
-                if batch.return_logprob:
-                    if logits_output.next_token_logprobs is not None:
-                        logits_output.next_token_logprobs = (
-                            logits_output.next_token_logprobs.tolist()
-                        )
-                    if logits_output.input_token_logprobs is not None:
-                        logits_output.input_token_logprobs = tuple(
-                            logits_output.input_token_logprobs.tolist()
-                        )
-
-            hidden_state_offset = 0
-
-            # Check finish conditions
-            logprob_pt = 0
-            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
-                if req.is_retracted:
-                    continue
-
-                if self.is_mixed_chunk and self.enable_overlap and req.finished():
-                    # Free the one delayed token for the mixed decode batch
-                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
-                    self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
-                    continue
-
-                if req.is_chunked <= 0:
-                    # req output_ids are set here
-                    req.output_ids.append(next_token_id)
-                    req.check_finished()
-
-                    if req.finished():
-                        self.tree_cache.cache_finished_req(req)
-                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
-                        # This updates radix so others can match
-                        self.tree_cache.cache_unfinished_req(req)
-
-                    if req.return_logprob:
-                        assert extend_logprob_start_len_per_req is not None
-                        assert extend_input_len_per_req is not None
-                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
-                        extend_input_len = extend_input_len_per_req[i]
-                        num_input_logprobs = extend_input_len - extend_logprob_start_len
-                        self.add_logprob_return_values(
-                            i,
-                            req,
-                            logprob_pt,
-                            next_token_ids,
-                            num_input_logprobs,
-                            logits_output,
-                        )
-                        logprob_pt += num_input_logprobs
-
-                    if (
-                        req.return_hidden_states
-                        and logits_output.hidden_states is not None
-                    ):
-                        req.hidden_states.append(
-                            logits_output.hidden_states[
-                                hidden_state_offset : (
-                                    hidden_state_offset := hidden_state_offset
-                                    + len(req.origin_input_ids)
-                                )
-                            ]
-                            .cpu()
-                            .clone()
-                        )
-
-                    if req.grammar is not None:
-                        req.grammar.accept_token(next_token_id)
-                        req.grammar.finished = req.finished()
-                else:
-                    # being chunked reqs' prefill is not finished
-                    req.is_chunked -= 1
-                    # There is only at most one request being currently chunked.
-                    # Because this request does not finish prefill,
-                    # we don't want to stream the request currently being chunked.
-                    skip_stream_req = req
-
-                    # Incrementally update input logprobs.
-                    if req.return_logprob:
-                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
-                        extend_input_len = extend_input_len_per_req[i]
-                        if extend_logprob_start_len < extend_input_len:
-                            # Update input logprobs.
-                            num_input_logprobs = (
-                                extend_input_len - extend_logprob_start_len
-                            )
-                            self.add_input_logprob_return_values(
-                                i,
-                                req,
-                                logits_output,
-                                logprob_pt,
-                                num_input_logprobs,
-                                last_prefill_chunk=False,
-                            )
-                            logprob_pt += num_input_logprobs
-
-            if batch.next_batch_sampling_info:
-                batch.next_batch_sampling_info.update_regex_vocab_mask()
-                self.current_stream.synchronize()
-                batch.next_batch_sampling_info.sampling_info_done.set()
-
-        else:  # embedding or reward model
-            embeddings, bid = result.embeddings, result.bid
-            embeddings = embeddings.tolist()
-
-            # Check finish conditions
-            for i, req in enumerate(batch.reqs):
-                if req.is_retracted:
-                    continue
-
-                req.embedding = embeddings[i]
-                if req.is_chunked <= 0:
-                    # Dummy output token for embedding models
-                    req.output_ids.append(0)
-                    req.check_finished()
-
-                    if req.finished():
-                        self.tree_cache.cache_finished_req(req)
-                    else:
-                        self.tree_cache.cache_unfinished_req(req)
-                else:
-                    # being chunked reqs' prefill is not finished
-                    req.is_chunked -= 1
-
-        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
-
-    def process_batch_result_decode(
-        self,
-        batch: ScheduleBatch,
-        result: GenerationBatchResult,
-    ):
-        logits_output, next_token_ids, bid = (
-            result.logits_output,
-            result.next_token_ids,
-            result.bid,
-        )
-        self.num_generated_tokens += len(batch.reqs)
-
-        if self.enable_overlap:
-            assert batch.spec_algorithm.is_none()
-            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
-            next_token_logprobs = logits_output.next_token_logprobs
-        elif batch.spec_algorithm.is_none():
-            # spec decoding handles output logprobs inside verify process.
-            next_token_ids = next_token_ids.tolist()
-            if batch.return_logprob:
-                next_token_logprobs = logits_output.next_token_logprobs.tolist()
-
-        self.token_to_kv_pool_allocator.free_group_begin()
-
-        # Check finish condition
-        # NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
-        # We should ignore using next_token_ids for spec decoding cases.
-        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
-            if req.is_retracted:
-                continue
-
-            if self.enable_overlap and req.finished():
-                # Free the one delayed token
-                self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
-                continue
-
-            if batch.spec_algorithm.is_none():
-                # speculative worker will solve the output_ids in speculative decoding
-                req.output_ids.append(next_token_id)
-
-            req.check_finished()
-            if req.finished():
-                self.tree_cache.cache_finished_req(req)
-
-            if req.return_logprob and batch.spec_algorithm.is_none():
-                # speculative worker handles logprob in speculative decoding
-                req.output_token_logprobs_val.append(next_token_logprobs[i])
-                req.output_token_logprobs_idx.append(next_token_id)
-                if req.top_logprobs_num > 0:
-                    req.output_top_logprobs_val.append(
-                        logits_output.next_token_top_logprobs_val[i]
-                    )
-                    req.output_top_logprobs_idx.append(
-                        logits_output.next_token_top_logprobs_idx[i]
-                    )
-                if req.token_ids_logprob is not None:
-                    req.output_token_ids_logprobs_val.append(
-                        logits_output.next_token_token_ids_logprobs_val[i]
-                    )
-                    req.output_token_ids_logprobs_idx.append(
-                        logits_output.next_token_token_ids_logprobs_idx[i]
-                    )
-
-            if req.return_hidden_states and logits_output.hidden_states is not None:
-                req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
-
-            if req.grammar is not None and batch.spec_algorithm.is_none():
-                req.grammar.accept_token(next_token_id)
-                req.grammar.finished = req.finished()
-
-        if batch.next_batch_sampling_info:
-            batch.next_batch_sampling_info.update_regex_vocab_mask()
-            self.current_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
-
-        self.stream_output(batch.reqs, batch.return_logprob)
-
-        self.token_to_kv_pool_allocator.free_group_end()
-
-        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
-        if (
-            self.attn_tp_rank == 0
-            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
-        ):
-            self.log_decode_stats()
-
-    def add_input_logprob_return_values(
-        self,
-        i: int,
-        req: Req,
-        output: LogitsProcessorOutput,
-        logprob_pt: int,
-        num_input_logprobs: int,
-        last_prefill_chunk: bool,  # If True, it means prefill is finished.
-    ):
-        """Incrementally add input logprobs to `req`.
-
-        Args:
-            i: The request index in a batch.
-            req: The request. Input logprobs inside req are modified as a
-                consequence of the API
-            fill_ids: The prefill ids processed.
-            output: Logit processor output that's used to compute input logprobs
-            last_prefill_chunk: True if it is the last prefill (when chunked).
-                Some of input logprob operation should only happen at the last
-                prefill (e.g., computing input token logprobs).
-        """
-        assert output.input_token_logprobs is not None
-        if req.input_token_logprobs is None:
-            req.input_token_logprobs = []
-        if req.temp_input_top_logprobs_val is None:
-            req.temp_input_top_logprobs_val = []
-        if req.temp_input_top_logprobs_idx is None:
-            req.temp_input_top_logprobs_idx = []
-        if req.temp_input_token_ids_logprobs_val is None:
-            req.temp_input_token_ids_logprobs_val = []
-        if req.temp_input_token_ids_logprobs_idx is None:
-            req.temp_input_token_ids_logprobs_idx = []
-
-        if req.input_token_logprobs_val is not None:
-            # The input logprob has been already computed. It only happens
-            # upon retract.
-            if req.top_logprobs_num > 0:
-                assert req.input_token_logprobs_val is not None
-            return
-
-        # Important for the performance.
-        assert isinstance(output.input_token_logprobs, tuple)
-        input_token_logprobs: Tuple[int] = output.input_token_logprobs
-        input_token_logprobs = input_token_logprobs[
-            logprob_pt : logprob_pt + num_input_logprobs
-        ]
-        req.input_token_logprobs.extend(input_token_logprobs)
-
-        if req.top_logprobs_num > 0:
-            req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i])
-            req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i])
-
-        if req.token_ids_logprob is not None:
-            req.temp_input_token_ids_logprobs_val.append(
-                output.input_token_ids_logprobs_val[i]
-            )
-            req.temp_input_token_ids_logprobs_idx.append(
-                output.input_token_ids_logprobs_idx[i]
-            )
-
-        if last_prefill_chunk:
-            input_token_logprobs = req.input_token_logprobs
-            req.input_token_logprobs = None
-            assert req.input_token_logprobs_val is None
-            assert req.input_token_logprobs_idx is None
-            assert req.input_top_logprobs_val is None
-            assert req.input_top_logprobs_idx is None
-
-            # Compute input_token_logprobs_val
-            # Always pad the first one with None.
-            req.input_token_logprobs_val = [None]
-            req.input_token_logprobs_val.extend(input_token_logprobs)
-            # The last input logprob is for sampling, so just pop it out.
-            req.input_token_logprobs_val.pop()
-
-            # Compute input_token_logprobs_idx
-            input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
-            # Clip the padded hash values from image tokens.
-            # Otherwise, it will lead to detokenization errors.
-            input_token_logprobs_idx = [
-                x if x < self.model_config.vocab_size - 1 else 0
-                for x in input_token_logprobs_idx
-            ]
-            req.input_token_logprobs_idx = input_token_logprobs_idx
-
-            if req.top_logprobs_num > 0:
-                req.input_top_logprobs_val = [None]
-                req.input_top_logprobs_idx = [None]
-                assert len(req.temp_input_token_ids_logprobs_val) == len(
-                    req.temp_input_token_ids_logprobs_idx
-                )
-                for val, idx in zip(
-                    req.temp_input_top_logprobs_val,
-                    req.temp_input_top_logprobs_idx,
-                    strict=True,
-                ):
-                    req.input_top_logprobs_val.extend(val)
-                    req.input_top_logprobs_idx.extend(idx)
-
-                # Last token is a sample token.
-                req.input_top_logprobs_val.pop()
-                req.input_top_logprobs_idx.pop()
-                req.temp_input_top_logprobs_idx = None
-                req.temp_input_top_logprobs_val = None
-
-            if req.token_ids_logprob is not None:
-                req.input_token_ids_logprobs_val = [None]
-                req.input_token_ids_logprobs_idx = [None]
-
-                for val, idx in zip(
-                    req.temp_input_token_ids_logprobs_val,
-                    req.temp_input_token_ids_logprobs_idx,
-                    strict=True,
-                ):
-                    req.input_token_ids_logprobs_val.extend(val)
-                    req.input_token_ids_logprobs_idx.extend(idx)
-
-                # Last token is a sample token.
-                req.input_token_ids_logprobs_val.pop()
-                req.input_token_ids_logprobs_idx.pop()
-                req.temp_input_token_ids_logprobs_idx = None
-                req.temp_input_token_ids_logprobs_val = None
-
-            if req.return_logprob:
-                relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
-                assert len(req.input_token_logprobs_val) == relevant_tokens_len
-                assert len(req.input_token_logprobs_idx) == relevant_tokens_len
-                if req.top_logprobs_num > 0:
-                    assert len(req.input_top_logprobs_val) == relevant_tokens_len
-                    assert len(req.input_top_logprobs_idx) == relevant_tokens_len
-                if req.token_ids_logprob is not None:
-                    assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len
-                    assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len
-
-    def add_logprob_return_values(
-        self,
-        i: int,
-        req: Req,
-        pt: int,
-        next_token_ids: List[int],
-        num_input_logprobs: int,
-        output: LogitsProcessorOutput,
-    ):
-        """Attach logprobs to the return values."""
-        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
-        req.output_token_logprobs_idx.append(next_token_ids[i])
-
-        self.add_input_logprob_return_values(
-            i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
-        )
-
-        if req.top_logprobs_num > 0:
-            req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
-            req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
-
-        if req.token_ids_logprob is not None:
-            req.output_token_ids_logprobs_val.append(
-                output.next_token_token_ids_logprobs_val[i]
-            )
-            req.output_token_ids_logprobs_idx.append(
-                output.next_token_token_ids_logprobs_idx[i]
-            )
-
-        return num_input_logprobs
-
-    def stream_output(
-        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
-    ):
-        """Stream the output to detokenizer."""
-        rids = []
-        finished_reasons: List[BaseFinishReason] = []
-
-        if self.is_generation:
-            decoded_texts = []
-            decode_ids_list = []
-            read_offsets = []
-            output_ids = []
-
-            skip_special_tokens = []
-            spaces_between_special_tokens = []
-            no_stop_trim = []
-            prompt_tokens = []
-            completion_tokens = []
-            cached_tokens = []
-            spec_verify_ct = []
-            output_hidden_states = None
-
-            if return_logprob:
-                input_token_logprobs_val = []
-                input_token_logprobs_idx = []
-                output_token_logprobs_val = []
-                output_token_logprobs_idx = []
-                input_top_logprobs_val = []
-                input_top_logprobs_idx = []
-                output_top_logprobs_val = []
-                output_top_logprobs_idx = []
-                input_token_ids_logprobs_val = []
-                input_token_ids_logprobs_idx = []
-                output_token_ids_logprobs_val = []
-                output_token_ids_logprobs_idx = []
-            else:
-                input_token_logprobs_val = input_token_logprobs_idx = (
-                    output_token_logprobs_val
-                ) = output_token_logprobs_idx = input_top_logprobs_val = (
-                    input_top_logprobs_idx
-                ) = output_top_logprobs_val = output_top_logprobs_idx = (
-                    input_token_ids_logprobs_val
-                ) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = (
-                    output_token_ids_logprobs_idx
-                ) = None
-
-            for req in reqs:
-                if req is skip_req:
-                    continue
-
-                # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
-                if self.model_config.is_multimodal_gen and req.to_abort:
-                    continue
-
-                if (
-                    req.finished()
-                    # If stream, follow the given stream_interval
-                    or (req.stream and len(req.output_ids) % self.stream_interval == 0)
-                    # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
-                    # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
-                    # always increase one-by-one.
-                    or (
-                        not req.stream
-                        and len(req.output_ids) % 50 == 0
-                        and not self.model_config.is_multimodal_gen
-                    )
-                ):
-                    rids.append(req.rid)
-                    finished_reasons.append(
-                        req.finished_reason.to_json() if req.finished_reason else None
-                    )
-                    decoded_texts.append(req.decoded_text)
-                    decode_ids, read_offset = req.init_incremental_detokenize()
-                    decode_ids_list.append(decode_ids)
-                    read_offsets.append(read_offset)
-                    if self.skip_tokenizer_init:
-                        output_ids.append(req.output_ids)
-                    skip_special_tokens.append(req.sampling_params.skip_special_tokens)
-                    spaces_between_special_tokens.append(
-                        req.sampling_params.spaces_between_special_tokens
-                    )
-                    no_stop_trim.append(req.sampling_params.no_stop_trim)
-
-                    prompt_tokens.append(len(req.origin_input_ids))
-                    completion_tokens.append(len(req.output_ids))
-                    cached_tokens.append(req.cached_tokens)
-
-                    if not self.spec_algorithm.is_none():
-                        spec_verify_ct.append(req.spec_verify_ct)
-
-                    if return_logprob:
-                        input_token_logprobs_val.append(req.input_token_logprobs_val)
-                        input_token_logprobs_idx.append(req.input_token_logprobs_idx)
-                        output_token_logprobs_val.append(req.output_token_logprobs_val)
-                        output_token_logprobs_idx.append(req.output_token_logprobs_idx)
-                        input_top_logprobs_val.append(req.input_top_logprobs_val)
-                        input_top_logprobs_idx.append(req.input_top_logprobs_idx)
-                        output_top_logprobs_val.append(req.output_top_logprobs_val)
-                        output_top_logprobs_idx.append(req.output_top_logprobs_idx)
-                        input_token_ids_logprobs_val.append(
-                            req.input_token_ids_logprobs_val
-                        )
-                        input_token_ids_logprobs_idx.append(
-                            req.input_token_ids_logprobs_idx
-                        )
-                        output_token_ids_logprobs_val.append(
-                            req.output_token_ids_logprobs_val
-                        )
-                        output_token_ids_logprobs_idx.append(
-                            req.output_token_ids_logprobs_idx
-                        )
-
-                    if req.return_hidden_states:
-                        if output_hidden_states is None:
-                            output_hidden_states = []
-                        output_hidden_states.append(req.hidden_states)
-
-            # Send to detokenizer
-            if rids:
-                if self.model_config.is_multimodal_gen:
-                    raise NotImplementedError()
-                self.send_to_detokenizer.send_pyobj(
-                    BatchTokenIDOut(
-                        rids,
-                        finished_reasons,
-                        decoded_texts,
-                        decode_ids_list,
-                        read_offsets,
-                        output_ids,
-                        skip_special_tokens,
-                        spaces_between_special_tokens,
-                        no_stop_trim,
-                        prompt_tokens,
-                        completion_tokens,
-                        cached_tokens,
-                        spec_verify_ct,
-                        input_token_logprobs_val,
-                        input_token_logprobs_idx,
-                        output_token_logprobs_val,
-                        output_token_logprobs_idx,
-                        input_top_logprobs_val,
-                        input_top_logprobs_idx,
-                        output_top_logprobs_val,
-                        output_top_logprobs_idx,
-                        input_token_ids_logprobs_val,
-                        input_token_ids_logprobs_idx,
-                        output_token_ids_logprobs_val,
-                        output_token_ids_logprobs_idx,
-                        output_hidden_states,
-                    )
-                )
-        else:  # embedding or reward model
-            embeddings = []
-            prompt_tokens = []
-            cached_tokens = []
-            for req in reqs:
-                if req.finished():
-                    rids.append(req.rid)
-                    finished_reasons.append(req.finished_reason.to_json())
-                    embeddings.append(req.embedding)
-                    prompt_tokens.append(len(req.origin_input_ids))
-                    cached_tokens.append(req.cached_tokens)
-            self.send_to_detokenizer.send_pyobj(
-                BatchEmbeddingOut(
-                    rids, finished_reasons, embeddings, prompt_tokens, cached_tokens
-                )
-            )
-
     def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
         # Check if other DP workers have running batches
         if local_batch is None:
@@ -1926,9 +1384,7 @@ class Scheduler:

     def flush_cache(self):
         """Flush the memory pool and cache."""
-        if len(self.waiting_queue) == 0 and (
-            self.running_batch is None or len(self.running_batch.reqs) == 0
-        ):
+        if len(self.waiting_queue) == 0 and self.running_batch.is_empty():
             self.cur_batch = None
             self.last_batch = None
             self.tree_cache.reset()
@@ -1954,7 +1410,7 @@ class Scheduler:
             logging.warning(
                 f"Cache not flushed because there are pending requests. "
                 f"#queue-req: {len(self.waiting_queue)}, "
-                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+                f"#running-req: {len(self.running_batch.reqs)}"
             )
             if_success = False
         return if_success
@@ -2004,24 +1460,24 @@ class Scheduler:

     def abort_request(self, recv_req: AbortReq):
         # Delete requests in the waiting queue
-        to_del = None
+        to_del = []
         for i, req in enumerate(self.waiting_queue):
-            if req.rid == recv_req.rid:
-                to_del = i
+            if req.rid.startswith(recv_req.rid):
+                to_del.append(i)
                 break

-        if to_del is not None:
-            del self.waiting_queue[to_del]
+        # Sort in reverse order to avoid index issues when deleting
+        for i in sorted(to_del, reverse=True):
+            req = self.waiting_queue.pop(i)
            logger.debug(f"Abort queued request. {req.rid=}")
            return

         # Delete requests in the running batch
-        if self.running_batch:
-            for req in self.running_batch.reqs:
-                if req.rid == recv_req.rid and not req.finished():
-                    logger.debug(f"Abort running request. {req.rid=}")
-                    req.to_abort = True
-                    break
+        for req in self.running_batch.reqs:
+            if req.rid.startswith(recv_req.rid) and not req.finished():
+                logger.debug(f"Abort running request. {req.rid=}")
+                req.to_abort = True
+                return

     def _pause_engine(self) -> Tuple[List[Req], int]:
         raise NotImplementedError()
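Note: abort_request now matches request ids by prefix (so a single abort can cover a family of related ids) and collects matching queue indices into a list that is deleted in reverse order, which keeps the remaining indices valid while popping. The loop in the diff still breaks after the first match; the sketch below shows the same reverse-order deletion generalized to several matches (hypothetical helper):

```python
def pop_matching(queue, rid_prefix):
    """Remove queued items whose rid starts with `rid_prefix`."""
    to_del = [i for i, req in enumerate(queue) if req["rid"].startswith(rid_prefix)]
    removed = []
    # Delete from the back so earlier indices are not shifted by the pops.
    for i in sorted(to_del, reverse=True):
        removed.append(queue.pop(i))
    return removed


queue = [{"rid": "abc-1"}, {"rid": "xyz-2"}, {"rid": "abc-3"}]
assert [r["rid"] for r in pop_matching(queue, "abc")] == ["abc-3", "abc-1"]
assert [r["rid"] for r in queue] == ["xyz-2"]
```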
@@ -2228,9 +1684,16 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
+
+    # Generate the prefix
+    if dp_rank is None:
+        prefix = f" TP{tp_rank}"
+    else:
+        prefix = f" DP{dp_rank} TP{tp_rank}"
+
     # Config the process
     # kill_itself_when_parent_died() # This is disabled because it does not work for `--dp 2`
-    setproctitle.setproctitle(f"sglang::scheduler_{dp_rank}")
+    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
     faulthandler.enable()
     parent_process = psutil.Process().parent()

@@ -2239,10 +1702,6 @@ def run_scheduler_process(
         dp_rank = int(os.environ["SGLANG_DP_RANK"])

     # Configure the logger
-    if dp_rank is None:
-        prefix = f" TP{tp_rank}"
-    else:
-        prefix = f" DP{dp_rank} TP{tp_rank}"
     configure_logger(server_args, prefix=prefix)
     suppress_other_loggers()
