sglang 0.4.3.post3__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. sglang/bench_serving.py +2 -2
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/hf_transformers_utils.py +16 -1
  14. sglang/srt/layers/attention/flashinfer_backend.py +95 -49
  15. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  16. sglang/srt/layers/attention/triton_backend.py +5 -5
  17. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  18. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  19. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  20. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  21. sglang/srt/layers/attention/vision.py +43 -62
  22. sglang/srt/layers/linear.py +1 -1
  23. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  24. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  32. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  33. sglang/srt/layers/parameter.py +10 -0
  34. sglang/srt/layers/quantization/__init__.py +90 -68
  35. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  36. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  40. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/fp8.py +174 -106
  63. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  64. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  65. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  66. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  67. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  68. sglang/srt/layers/rotary_embedding.py +5 -3
  69. sglang/srt/layers/sampler.py +29 -35
  70. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  71. sglang/srt/lora/backend/__init__.py +9 -12
  72. sglang/srt/managers/cache_controller.py +72 -8
  73. sglang/srt/managers/image_processor.py +37 -631
  74. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  75. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  76. sglang/srt/managers/image_processors/llava.py +152 -0
  77. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  78. sglang/srt/managers/image_processors/mlama.py +60 -0
  79. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  80. sglang/srt/managers/io_struct.py +33 -15
  81. sglang/srt/managers/multi_modality_padding.py +134 -0
  82. sglang/srt/managers/schedule_batch.py +212 -117
  83. sglang/srt/managers/schedule_policy.py +40 -8
  84. sglang/srt/managers/scheduler.py +258 -782
  85. sglang/srt/managers/scheduler_output_processor_mixin.py +611 -0
  86. sglang/srt/managers/tokenizer_manager.py +7 -6
  87. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  88. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  89. sglang/srt/mem_cache/chunk_cache.py +12 -44
  90. sglang/srt/mem_cache/hiradix_cache.py +63 -34
  91. sglang/srt/mem_cache/memory_pool.py +112 -46
  92. sglang/srt/mem_cache/paged_allocator.py +283 -0
  93. sglang/srt/mem_cache/radix_cache.py +117 -36
  94. sglang/srt/metrics/collector.py +8 -0
  95. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  96. sglang/srt/model_executor/forward_batch_info.py +12 -8
  97. sglang/srt/model_executor/model_runner.py +153 -134
  98. sglang/srt/model_loader/loader.py +2 -1
  99. sglang/srt/model_loader/weight_utils.py +1 -1
  100. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  101. sglang/srt/models/deepseek_nextn.py +23 -3
  102. sglang/srt/models/deepseek_v2.py +25 -19
  103. sglang/srt/models/minicpmv.py +28 -89
  104. sglang/srt/models/mllama.py +1 -1
  105. sglang/srt/models/qwen2.py +0 -1
  106. sglang/srt/models/qwen2_5_vl.py +25 -50
  107. sglang/srt/models/qwen2_vl.py +33 -49
  108. sglang/srt/openai_api/adapter.py +37 -15
  109. sglang/srt/openai_api/protocol.py +8 -1
  110. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  111. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  112. sglang/srt/server_args.py +19 -20
  113. sglang/srt/speculative/build_eagle_tree.py +6 -1
  114. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -11
  115. sglang/srt/speculative/eagle_utils.py +2 -1
  116. sglang/srt/speculative/eagle_worker.py +109 -38
  117. sglang/srt/utils.py +104 -9
  118. sglang/test/runners.py +104 -10
  119. sglang/test/test_block_fp8.py +106 -16
  120. sglang/test/test_custom_ops.py +88 -0
  121. sglang/test/test_utils.py +20 -4
  122. sglang/utils.py +0 -4
  123. sglang/version.py +1 -1
  124. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/METADATA +9 -9
  125. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/RECORD +128 -83
  126. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/WHEEL +1 -1
  127. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/LICENSE +0 -0
  128. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/top_level.txt +0 -0
@@ -41,8 +41,6 @@ from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
-    BatchEmbeddingOut,
-    BatchTokenIDOut,
     CloseSessionReqInput,
     FlushCacheReq,
     GetInternalStateReq,
@@ -74,7 +72,6 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
-    BaseFinishReason,
     ImageInputs,
     Req,
     ScheduleBatch,
@@ -85,6 +82,9 @@ from sglang.srt.managers.schedule_policy import (
     PrefillAdder,
     SchedulePolicy,
 )
+from sglang.srt.managers.scheduler_output_processor_mixin import (
+    SchedulerOutputProcessorMixin,
+)
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
@@ -93,7 +93,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -103,6 +103,7 @@ from sglang.srt.utils import (
     crash_on_warnings,
     get_bool_env_var,
     get_zmq_socket,
+    kill_itself_when_parent_died,
     pyspy_dump_schedulers,
     set_gpu_proc_affinity,
     set_random_seed,
@@ -132,7 +133,7 @@ class EmbeddingBatchResult:
     bid: int
 
 
-class Scheduler:
+class Scheduler(SchedulerOutputProcessorMixin):
     """A scheduler that manages a tensor parallel GPU worker."""
 
     def __init__(
@@ -159,17 +160,7 @@ class Scheduler:
         )
         self.gpu_id = gpu_id
         self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
-        self.decode_mem_cache_buf_multiplier = (
-            (
-                self.server_args.speculative_num_draft_tokens
-                + (
-                    self.server_args.speculative_eagle_topk
-                    * self.server_args.speculative_num_draft_tokens
-                )
-            )
-            if not self.spec_algorithm.is_none()
-            else 1
-        )
+        self.page_size = server_args.page_size
 
         # Distributed rank info
         self.dp_size = server_args.dp_size
@@ -208,42 +199,12 @@ class Scheduler:
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
 
         # Init tokenizer
-        self.model_config = ModelConfig(
-            server_args.model_path,
-            trust_remote_code=server_args.trust_remote_code,
-            revision=server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
-        )
-        self.is_generation = self.model_config.is_generation
-
-        if server_args.skip_tokenizer_init:
-            self.tokenizer = self.processor = None
-        else:
-            if self.model_config.is_multimodal:
-                self.processor = get_processor(
-                    server_args.tokenizer_path,
-                    tokenizer_mode=server_args.tokenizer_mode,
-                    trust_remote_code=server_args.trust_remote_code,
-                    revision=server_args.revision,
-                )
-                self.tokenizer = self.processor.tokenizer
-            else:
-                self.tokenizer = get_tokenizer(
-                    server_args.tokenizer_path,
-                    tokenizer_mode=server_args.tokenizer_mode,
-                    trust_remote_code=server_args.trust_remote_code,
-                    revision=server_args.revision,
-                )
+        self.init_tokenizer()
 
         # Check whether overlap can be enabled
         if not self.is_generation:
             self.enable_overlap = False
             logger.info("Overlap scheduler is disabled for embedding models.")
-
         if self.model_config.is_multimodal:
             self.enable_overlap = False
             logger.info("Overlap scheduler is disabled for multimodal models.")
@@ -274,10 +235,8 @@ class Scheduler:
                 target_worker=self.tp_worker,
                 dp_rank=dp_rank,
             )
-            self.prefill_only_one_req = True
         else:
             self.draft_worker = None
-            self.prefill_only_one_req = False
 
         # Get token and memory info from the model worker
         (
@@ -309,64 +268,28 @@ class Scheduler:
         )
 
         # Init memory pool and cache
-        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
-            self.tp_worker.get_memory_pool()
-        )
-
-        if (
-            server_args.chunked_prefill_size is not None
-            and server_args.disable_radix_cache
-        ):
-            self.tree_cache = ChunkCache(
-                req_to_token_pool=self.req_to_token_pool,
-                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-            )
-        else:
-            if self.enable_hierarchical_cache:
-                self.tree_cache = HiRadixCache(
-                    req_to_token_pool=self.req_to_token_pool,
-                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-                )
-            else:
-                self.tree_cache = RadixCache(
-                    req_to_token_pool=self.req_to_token_pool,
-                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-                    disable=server_args.disable_radix_cache,
-                )
-
-        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
+        self.init_memory_pool_and_cache()
 
         # Init running status
         self.waiting_queue: List[Req] = []
-        self.staging_reqs = {}
         # The running decoding batch for continuous batching
-        self.running_batch: Optional[ScheduleBatch] = None
+        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
         # The current forward batch
         self.cur_batch: Optional[ScheduleBatch] = None
-        # The current forward batch
+        # The last forward batch
         self.last_batch: Optional[ScheduleBatch] = None
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
-        self.spec_num_total_accepted_tokens = 0
-        self.spec_num_total_forward_ct = 0
-        self.cum_spec_accept_length = 0
-        self.cum_spec_accept_count = 0
+        self.num_prefill_tokens = 0
         self.last_decode_stats_tic = time.time()
+        self.last_prefill_stats_tic = time.time()
         self.return_health_check_ct = 0
         self.current_stream = torch.get_device_module(self.device).current_stream()
         if self.device == "cpu":
             self.current_stream.synchronize = lambda: None  # No-op for CPU
 
-        # For metrics only.
-        # The largest prefill length of a single request
-        self._largest_prefill_len: int = 0
-        # The largest context length (prefill + generation) of a single request
-        self._largest_prefill_decode_len: int = 0
-        self.last_gen_throughput: float = 0.0
-        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
-
-        # Session info
+        # Init session info
         self.sessions: Dict[str, Session] = {}
 
         # Init chunked prefill
@@ -387,11 +310,15 @@ class Scheduler:
         else:
             self.grammar_backend = None
 
-        # Init new token estimation
+        # Init schedule policy and new token estimation
+        self.policy = SchedulePolicy(
+            self.schedule_policy,
+            self.tree_cache,
+            self.enable_hierarchical_cache,
+        )
         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
-
         self.init_new_token_ratio = min(
             global_config.default_init_new_token_ratio
             * server_args.schedule_conservativeness,
@@ -407,11 +334,6 @@ class Scheduler:
         ) / global_config.default_new_token_ratio_decay_steps
         self.new_token_ratio = self.init_new_token_ratio
 
-        # Tell whether the current running batch is full so that we can skip
-        # the check of whether to prefill new requests.
-        # This is an optimization to reduce the overhead of the prefill check.
-        self.batch_is_full = False
-
         # Init watchdog thread
         self.watchdog_timeout = server_args.watchdog_timeout
         t = threading.Thread(target=self.watchdog_thread, daemon=True)
@@ -430,14 +352,7 @@ class Scheduler:
         self.profiler_target_forward_ct: Optional[int] = None
 
         # Init metrics stats
-        self.stats = SchedulerStats()
-        if self.enable_metrics:
-            self.metrics_collector = SchedulerMetricsCollector(
-                labels={
-                    "model_name": self.server_args.served_model_name,
-                    # TODO: Add lora name/path in the future,
-                },
-            )
+        self.init_metrics()
 
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
@@ -460,39 +375,107 @@ class Scheduler:
                 (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                 (ProfileReq, self.profile),
                 (GetInternalStateReq, self.get_internal_state),
+                (SetInternalStateReq, self.set_internal_state),
             ]
         )
 
-    def watchdog_thread(self):
-        """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
-        self.watchdog_last_forward_ct = 0
-        self.watchdog_last_time = time.time()
+    def init_tokenizer(self):
+        server_args = self.server_args
 
-        while True:
-            current = time.time()
-            if self.cur_batch is not None:
-                if self.watchdog_last_forward_ct == self.forward_ct:
-                    if current > self.watchdog_last_time + self.watchdog_timeout:
-                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
-                        break
-                else:
-                    self.watchdog_last_forward_ct = self.forward_ct
-                    self.watchdog_last_time = current
-            time.sleep(self.watchdog_timeout // 2)
+        self.model_config = ModelConfig(
+            server_args.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
+            context_length=server_args.context_length,
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
+        )
+        self.is_generation = self.model_config.is_generation
 
-        # Print batch size and memory pool info to check whether there are de-sync issues.
-        logger.error(
-            f"{self.cur_batch.batch_size()=}, "
-            f"{self.cur_batch.reqs=}, "
-            f"{self.token_to_kv_pool.available_size()=}, "
-            f"{self.tree_cache.evictable_size()=}, "
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = self.processor = None
+        else:
+            if self.model_config.is_multimodal:
+                self.processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                    revision=server_args.revision,
+                )
+                self.tokenizer = self.processor.tokenizer
+            else:
+                self.tokenizer = get_tokenizer(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                    revision=server_args.revision,
+                )
+
+    def init_memory_pool_and_cache(self):
+        server_args = self.server_args
+
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            self.tp_worker.get_memory_pool()
+        )
+
+        if (
+            server_args.chunked_prefill_size is not None
+            and server_args.disable_radix_cache
+        ):
+            self.tree_cache = ChunkCache(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+            )
+        else:
+            if self.enable_hierarchical_cache:
+                self.tree_cache = HiRadixCache(
+                    req_to_token_pool=self.req_to_token_pool,
+                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                    tp_cache_group=self.tp_worker.get_tp_cpu_group(),
+                )
+            else:
+                self.tree_cache = RadixCache(
+                    req_to_token_pool=self.req_to_token_pool,
+                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                    page_size=self.page_size,
+                    disable=server_args.disable_radix_cache,
+                )
+
+        self.decode_mem_cache_buf_multiplier = (
+            1
+            if self.spec_algorithm.is_none()
+            else (
+                server_args.speculative_num_draft_tokens
+                + (
+                    server_args.speculative_eagle_topk
+                    * server_args.speculative_num_steps
+                )
+            )
         )
-        # Wait for some time so that the parent process can print the error.
-        pyspy_dump_schedulers()
-        print(file=sys.stderr, flush=True)
-        print(file=sys.stdout, flush=True)
-        time.sleep(5)
-        self.parent_process.send_signal(signal.SIGQUIT)
+
+    def init_metrics(self):
+        # The largest prefill length of a single request
+        self._largest_prefill_len: int = 0
+        # The largest context length (prefill + generation) of a single request
+        self._largest_prefill_decode_len: int = 0
+        self.last_gen_throughput: float = 0.0
+        self.last_input_throughput: float = 0.0
+        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
+        self.spec_num_total_accepted_tokens = 0
+        self.spec_num_total_forward_ct = 0
+        self.cum_spec_accept_length = 0
+        self.cum_spec_accept_count = 0
+        self.stats = SchedulerStats()
+        if self.enable_metrics:
+            engine_type = "unified"
+            self.metrics_collector = SchedulerMetricsCollector(
+                labels={
+                    "model_name": self.server_args.served_model_name,
+                    "engine_type": engine_type,
+                },
+            )
 
     @torch.no_grad()
     def event_loop_normal(self):
@@ -508,7 +491,7 @@ class Scheduler:
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
             else:
-                # When the server is idle, so self-check and re-init some states
+                # When the server is idle, do self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
 
@@ -548,7 +531,7 @@ class Scheduler:
                 )
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
-                # When the server is idle, so self-check and re-init some states
+                # When the server is idle, do self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
 
@@ -609,7 +592,7 @@ class Scheduler:
         for recv_req in recv_reqs:
             # If it is a health check generation request and there are running requests, ignore it.
             if is_health_check_generate_req(recv_req) and (
-                self.chunked_req is not None or self.running_batch is not None
+                self.chunked_req is not None or not self.running_batch.is_empty()
             ):
                 self.return_health_check_ct += 1
                 continue
@@ -789,6 +772,30 @@ class Scheduler:
         )
         req.tokenizer = self.tokenizer
 
+        # Handle multimodal inputs
+        if recv_req.image_inputs is not None:
+            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
+            # Expand a single image token into multiple dummy tokens for receiving image embeddings
+            req.origin_input_ids = self.pad_input_ids_func(
+                req.origin_input_ids, image_inputs
+            )
+            req.extend_image_inputs(image_inputs)
+
+            if len(req.origin_input_ids) >= self.max_req_input_len:
+                error_msg = (
+                    "Multimodal prompt is too long after expanding multimodal tokens. "
+                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
+                )
+                logger.error(error_msg)
+                req.origin_input_ids = [0]
+                req.image_inputs = None
+                req.sampling_params.max_new_tokens = 0
+                req.finished_reason = FINISH_ABORT(
+                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
+                )
+                self.waiting_queue.append(req)
+                return
+
         # Validate prompts length
         error_msg = validate_input_length(
             req,
@@ -809,6 +816,11 @@ class Scheduler:
         can_run_list: List[Req],
         running_bs: int,
     ):
+        gap_latency = time.time() - self.last_prefill_stats_tic
+        self.last_prefill_stats_tic = time.time()
+        self.last_input_throughput = self.num_prefill_tokens / gap_latency
+        self.num_prefill_tokens = 0
+
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
@@ -844,7 +856,7 @@ class Scheduler:
         self.last_decode_stats_tic = time.time()
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
         self.num_generated_tokens = 0
-        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
+        num_running_reqs = len(self.running_batch.reqs)
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
@@ -908,8 +920,10 @@ class Scheduler:
         )
         if memory_leak:
             msg = (
-                "KV cache pool leak detected!"
+                "KV cache pool leak detected! "
                 f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
+                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
+                f"{self.tree_cache.evictable_size()=}\n"
             )
             warnings.warn(msg)
             if crash_on_warnings():
@@ -932,10 +946,10 @@ class Scheduler:
         ):
             # During idle time, also collect metrics every 30 seconds.
             num_used = self.max_total_num_tokens - (
-                self.token_to_kv_pool.available_size()
+                self.token_to_kv_pool_allocator.available_size()
                 + self.tree_cache.evictable_size()
             )
-            num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
+            num_running_reqs = len(self.running_batch.reqs)
             self.stats.num_running_reqs = num_running_reqs
             self.stats.num_used_tokens = num_used
             self.stats.token_usage = num_used / self.max_total_num_tokens
@@ -953,14 +967,20 @@ class Scheduler:
                 self.tree_cache.cache_unfinished_req(self.chunked_req)
                 # chunked request keeps its rid but will get a new req_pool_idx
                 self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
-                self.batch_is_full = False
+                self.running_batch.batch_is_full = False
 
+            # Filter batch
+            last_bs = self.last_batch.batch_size()
             self.last_batch.filter_batch()
+            if self.last_batch.batch_size() < last_bs:
+                self.running_batch.batch_is_full = False
+
+            # Merge the new batch into the running batch
             if not self.last_batch.is_empty():
-                if self.running_batch is None:
+                if self.running_batch.is_empty():
                     self.running_batch = self.last_batch
                 else:
-                    # merge running_batch with prefill batch
+                    # Merge running_batch with prefill batch
                     self.running_batch.merge_batch(self.last_batch)
 
         new_batch = self.get_new_batch_prefill()
@@ -969,11 +989,11 @@ class Scheduler:
             ret = new_batch
         else:
             # Run decode
-            if self.running_batch is None:
-                ret = None
-            else:
+            if not self.running_batch.is_empty():
                 self.running_batch = self.update_running_batch(self.running_batch)
-                ret = self.running_batch
+                ret = self.running_batch if not self.running_batch.is_empty() else None
+            else:
+                ret = None
 
         # Handle DP attention
         if self.server_args.enable_dp_attention:
@@ -988,15 +1008,20 @@ class Scheduler:
 
         # Handle the cases where prefill is not allowed
         if (
-            self.batch_is_full or len(self.waiting_queue) == 0
+            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
         ) and self.chunked_req is None:
             return None
 
-        running_bs = len(self.running_batch.reqs) if self.running_batch else 0
+        running_bs = len(self.running_batch.reqs)
         if running_bs >= self.max_running_requests:
-            self.batch_is_full = True
+            self.running_batch.batch_is_full = True
             return None
 
+        if self.enable_hierarchical_cache:
+            # check for completion of hierarchical cache activities to release memory
+            self.tree_cache.writing_check()
+            self.tree_cache.loading_check()
+
         # Get priority queue
         prefix_computed = self.policy.calc_priority(self.waiting_queue)
 
@@ -1011,17 +1036,13 @@ class Scheduler:
             running_bs if self.is_mixed_chunk else 0,
         )
 
-        is_chunked = self.chunked_req is not None
-        if is_chunked:
+        if self.chunked_req is not None:
             self.chunked_req.init_next_round_input()
             self.chunked_req = adder.add_chunked_req(self.chunked_req)
 
         if self.lora_paths:
-            lora_set = (
-                set([req.lora_path for req in self.running_batch.reqs])
-                if self.running_batch is not None
-                else set([])
-            )
+            lora_set = set([req.lora_path for req in self.running_batch.reqs])
+
         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
@@ -1033,51 +1054,33 @@ class Scheduler:
                 )
                 > self.max_loras_per_batch
             ):
-                self.batch_is_full = True
+                self.running_batch.batch_is_full = True
                 break
 
             if running_bs + len(adder.can_run_list) >= self.max_running_requests:
-                self.batch_is_full = True
+                self.running_batch.batch_is_full = True
                 break
 
-            req.init_next_round_input(None if prefix_computed else self.tree_cache)
+            req.init_next_round_input(
+                None if prefix_computed else self.tree_cache,
+                self.enable_hierarchical_cache,
+            )
 
-            if self.enable_hierarchical_cache and req.last_node is not None:
-                if req.last_node.evicted:
-                    # loading KV cache for the request
-                    req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
-                        req.last_node,
-                        req.prefix_indices,
-                        adder.rem_total_tokens,
-                    )
-                    if req.last_node.loading:
-                        # to prevent frequent cache invalidation
-                        if req.rid in self.staging_reqs:
-                            self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
-                        self.tree_cache.inc_lock_ref(req.last_node)
-                        self.staging_reqs[req.rid] = req.last_node
-                        continue
-                elif req.last_node.loading:
-                    if not self.tree_cache.loading_complete(req.last_node):
-                        continue
-
-                if req.rid in self.staging_reqs:
-                    self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
-                    del self.staging_reqs[req.rid]
-
-            res = adder.add_one_req(req, self.chunked_req)
+            res = adder.add_one_req(
+                req, self.chunked_req, self.enable_hierarchical_cache
+            )
             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
                     if self.enable_hierarchical_cache:
                         # Set batch_is_full after making sure there are requests that can be served
-                        self.batch_is_full = len(adder.can_run_list) > 0 or (
+                        self.running_batch.batch_is_full = len(
+                            adder.can_run_list
+                        ) > 0 or (
                             self.running_batch is not None
                             and not self.running_batch.is_empty()
                         )
                     else:
-                        self.batch_is_full = True
-                break
-            if self.prefill_only_one_req:
+                        self.running_batch.batch_is_full = True
                 break
 
         # Update waiting queue
@@ -1088,6 +1091,9 @@ class Scheduler:
             x for x in self.waiting_queue if x not in set(can_run_list)
         ]
 
+        if self.enable_hierarchical_cache:
+            self.tree_cache.read_to_load_cache()
+
         if adder.new_chunked_req is not None:
             assert self.chunked_req is None
             self.chunked_req = adder.new_chunked_req
@@ -1115,7 +1121,7 @@ class Scheduler:
         # Mixed-style chunked prefill
         if (
             self.is_mixed_chunk
-            and self.running_batch is not None
+            and not self.running_batch.is_empty()
             and not (new_batch.return_logprob or self.running_batch.return_logprob)
         ):
             # TODO (lianmin): support return_logprob + mixed chunked prefill
@@ -1124,7 +1130,9 @@ class Scheduler:
             self.running_batch.prepare_for_decode()
             new_batch.mix_with_running(self.running_batch)
             new_batch.decoding_reqs = self.running_batch.reqs
-            self.running_batch = None
+            self.running_batch = ScheduleBatch(
+                reqs=[], batch_is_full=self.running_batch.batch_is_full
+            )
         else:
             new_batch.decoding_reqs = None
 
@@ -1136,8 +1144,8 @@ class Scheduler:
 
         batch.filter_batch()
         if batch.is_empty():
-            self.batch_is_full = False
-            return None
+            batch.batch_is_full = False
+            return batch
 
         # Check if decode out of memory
         if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
@@ -1161,7 +1169,7 @@ class Scheduler:
         )
 
         if batch.batch_size() < initial_bs:
-            self.batch_is_full = False
+            batch.batch_is_full = False
 
         # Update batch tensors
         batch.prepare_for_decode()
@@ -1180,6 +1188,7 @@ class Scheduler:
         ):
             self.stop_profile()
 
+        # Run forward
         if self.is_generation:
             if self.spec_algorithm.is_none():
                 model_worker_batch = batch.get_model_worker_batch()
@@ -1200,6 +1209,7 @@ class Scheduler:
                 self.spec_num_total_forward_ct += batch.batch_size()
                 self.num_generated_tokens += num_accepted_tokens
             batch.output_ids = next_token_ids
+
             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that
             # we can use the correct values in output processing.
@@ -1233,10 +1243,7 @@ class Scheduler:
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
     ):
         if batch.forward_mode.is_decode():
-            assert isinstance(result, GenerationBatchResult)
             self.process_batch_result_decode(batch, result)
-            if batch.is_empty():
-                self.running_batch = None
         elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_idle():
@@ -1258,571 +1265,6 @@ class Scheduler:
             self.return_health_check_ct -= 1
             self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
 
-    def process_batch_result_prefill(
-        self,
-        batch: ScheduleBatch,
-        result: Union[GenerationBatchResult, EmbeddingBatchResult],
-    ):
-        skip_stream_req = None
-
-        if self.is_generation:
-            (
-                logits_output,
-                next_token_ids,
-                extend_input_len_per_req,
-                extend_logprob_start_len_per_req,
-                bid,
-            ) = (
-                result.logits_output,
-                result.next_token_ids,
-                result.extend_input_len_per_req,
-                result.extend_logprob_start_len_per_req,
-                result.bid,
-            )
-
-            if self.enable_overlap:
-                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
-            else:
-                # Move next_token_ids and logprobs to cpu
-                next_token_ids = next_token_ids.tolist()
-                if batch.return_logprob:
-                    if logits_output.next_token_logprobs is not None:
-                        logits_output.next_token_logprobs = (
-                            logits_output.next_token_logprobs.tolist()
-                        )
-                    if logits_output.input_token_logprobs is not None:
-                        logits_output.input_token_logprobs = tuple(
-                            logits_output.input_token_logprobs.tolist()
-                        )
-
-            hidden_state_offset = 0
-
-            # Check finish conditions
-            logprob_pt = 0
-            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
-                if req.is_retracted:
-                    continue
-
-                if self.is_mixed_chunk and self.enable_overlap and req.finished():
-                    # Free the one delayed token for the mixed decode batch
-                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
-                    self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
-                    continue
-
-                if req.is_chunked <= 0:
-                    # req output_ids are set here
-                    req.output_ids.append(next_token_id)
-                    req.check_finished()
-
-                    if req.finished():
-                        self.tree_cache.cache_finished_req(req)
-                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
-                        # This updates radix so others can match
-                        self.tree_cache.cache_unfinished_req(req)
-
-                    if req.return_logprob:
-                        assert extend_logprob_start_len_per_req is not None
-                        assert extend_input_len_per_req is not None
-                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
-                        extend_input_len = extend_input_len_per_req[i]
-                        num_input_logprobs = extend_input_len - extend_logprob_start_len
-                        self.add_logprob_return_values(
-                            i,
-                            req,
-                            logprob_pt,
-                            next_token_ids,
-                            num_input_logprobs,
-                            logits_output,
-                        )
-                        logprob_pt += num_input_logprobs
-
-                    if (
-                        req.return_hidden_states
-                        and logits_output.hidden_states is not None
-                    ):
-                        req.hidden_states.append(
-                            logits_output.hidden_states[
-                                hidden_state_offset : (
-                                    hidden_state_offset := hidden_state_offset
-                                    + len(req.origin_input_ids)
-                                )
-                            ]
-                            .cpu()
-                            .clone()
-                        )
-
-                    if req.grammar is not None:
-                        req.grammar.accept_token(next_token_id)
-                        req.grammar.finished = req.finished()
-                else:
-                    # being chunked reqs' prefill is not finished
-                    req.is_chunked -= 1
-                    # There is only at most one request being currently chunked.
-                    # Because this request does not finish prefill,
-                    # we don't want to stream the request currently being chunked.
-                    skip_stream_req = req
-
-                    # Incrementally update input logprobs.
-                    if req.return_logprob:
-                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
-                        extend_input_len = extend_input_len_per_req[i]
-                        if extend_logprob_start_len < extend_input_len:
-                            # Update input logprobs.
-                            num_input_logprobs = (
-                                extend_input_len - extend_logprob_start_len
-                            )
-                            self.add_input_logprob_return_values(
-                                i,
-                                req,
-                                logits_output,
-                                logprob_pt,
-                                num_input_logprobs,
-                                last_prefill_chunk=False,
-                            )
-                            logprob_pt += num_input_logprobs
-
-            if batch.next_batch_sampling_info:
-                batch.next_batch_sampling_info.update_regex_vocab_mask()
-                self.current_stream.synchronize()
-                batch.next_batch_sampling_info.sampling_info_done.set()
-
-        else:  # embedding or reward model
-            embeddings, bid = result.embeddings, result.bid
-            embeddings = embeddings.tolist()
-
-            # Check finish conditions
-            for i, req in enumerate(batch.reqs):
-                if req.is_retracted:
-                    continue
-
-                req.embedding = embeddings[i]
-                if req.is_chunked <= 0:
-                    # Dummy output token for embedding models
-                    req.output_ids.append(0)
-                    req.check_finished()
-
-                    if req.finished():
-                        self.tree_cache.cache_finished_req(req)
-                    else:
-                        self.tree_cache.cache_unfinished_req(req)
-                else:
-                    # being chunked reqs' prefill is not finished
-                    req.is_chunked -= 1
-
-        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
-
-    def process_batch_result_decode(
-        self,
-        batch: ScheduleBatch,
-        result: GenerationBatchResult,
-    ):
-        logits_output, next_token_ids, bid = (
-            result.logits_output,
-            result.next_token_ids,
-            result.bid,
-        )
-        self.num_generated_tokens += len(batch.reqs)
-
-        if self.enable_overlap:
-            assert batch.spec_algorithm.is_none()
-            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
-            next_token_logprobs = logits_output.next_token_logprobs
-        elif batch.spec_algorithm.is_none():
-            # spec decoding handles output logprobs inside verify process.
-            next_token_ids = next_token_ids.tolist()
-            if batch.return_logprob:
-                next_token_logprobs = logits_output.next_token_logprobs.tolist()
-
-        self.token_to_kv_pool_allocator.free_group_begin()
-
-        # Check finish condition
-        # NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
-        # We should ignore using next_token_ids for spec decoding cases.
-        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
-            if req.is_retracted:
-                continue
-
-            if self.enable_overlap and req.finished():
-                # Free the one delayed token
-                self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
-                continue
-
-            if batch.spec_algorithm.is_none():
-                # speculative worker will solve the output_ids in speculative decoding
-                req.output_ids.append(next_token_id)
-
-            req.check_finished()
-            if req.finished():
-                self.tree_cache.cache_finished_req(req)
-
-            if req.return_logprob and batch.spec_algorithm.is_none():
-                # speculative worker handles logprob in speculative decoding
-                req.output_token_logprobs_val.append(next_token_logprobs[i])
-                req.output_token_logprobs_idx.append(next_token_id)
-                if req.top_logprobs_num > 0:
-                    req.output_top_logprobs_val.append(
-                        logits_output.next_token_top_logprobs_val[i]
-                    )
-                    req.output_top_logprobs_idx.append(
-                        logits_output.next_token_top_logprobs_idx[i]
-                    )
-                if req.token_ids_logprob is not None:
-                    req.output_token_ids_logprobs_val.append(
-                        logits_output.next_token_token_ids_logprobs_val[i]
-                    )
-                    req.output_token_ids_logprobs_idx.append(
-                        logits_output.next_token_token_ids_logprobs_idx[i]
-                    )
-
-            if req.return_hidden_states and logits_output.hidden_states is not None:
-                req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
-
-            if req.grammar is not None and batch.spec_algorithm.is_none():
-                req.grammar.accept_token(next_token_id)
-                req.grammar.finished = req.finished()
-
-        if batch.next_batch_sampling_info:
-            batch.next_batch_sampling_info.update_regex_vocab_mask()
-            self.current_stream.synchronize()
-            batch.next_batch_sampling_info.sampling_info_done.set()
-        self.stream_output(batch.reqs, batch.return_logprob)
-
-        self.token_to_kv_pool_allocator.free_group_end()
-
-        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
-        if (
-            self.attn_tp_rank == 0
-            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
-        ):
-            self.log_decode_stats()
-
-    def add_input_logprob_return_values(
-        self,
-        i: int,
-        req: Req,
-        output: LogitsProcessorOutput,
-        logprob_pt: int,
-        num_input_logprobs: int,
-        last_prefill_chunk: bool,  # If True, it means prefill is finished.
-    ):
-        """Incrementally add input logprobs to `req`.
-
-        Args:
-            i: The request index in a batch.
-            req: The request. Input logprobs inside req are modified as a
-                consequence of the API
-            fill_ids: The prefill ids processed.
-            output: Logit processor output that's used to compute input logprobs
-            last_prefill_chunk: True if it is the last prefill (when chunked).
-                Some of input logprob operation should only happen at the last
-                prefill (e.g., computing input token logprobs).
-        """
-        assert output.input_token_logprobs is not None
-        if req.input_token_logprobs is None:
-            req.input_token_logprobs = []
-        if req.temp_input_top_logprobs_val is None:
-            req.temp_input_top_logprobs_val = []
-        if req.temp_input_top_logprobs_idx is None:
-            req.temp_input_top_logprobs_idx = []
-        if req.temp_input_token_ids_logprobs_val is None:
-            req.temp_input_token_ids_logprobs_val = []
-        if req.temp_input_token_ids_logprobs_idx is None:
-            req.temp_input_token_ids_logprobs_idx = []
-
-        if req.input_token_logprobs_val is not None:
-            # The input logprob has been already computed. It only happens
-            # upon retract.
-            if req.top_logprobs_num > 0:
-                assert req.input_token_logprobs_val is not None
-            return
-
-        # Important for the performance.
-        assert isinstance(output.input_token_logprobs, tuple)
-        input_token_logprobs: Tuple[int] = output.input_token_logprobs
-        input_token_logprobs = input_token_logprobs[
-            logprob_pt : logprob_pt + num_input_logprobs
-        ]
-        req.input_token_logprobs.extend(input_token_logprobs)
-
-        if req.top_logprobs_num > 0:
-            req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i])
-            req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i])
-
-        if req.token_ids_logprob is not None:
-            req.temp_input_token_ids_logprobs_val.append(
-                output.input_token_ids_logprobs_val[i]
-            )
-            req.temp_input_token_ids_logprobs_idx.append(
-                output.input_token_ids_logprobs_idx[i]
-            )
-
-        if last_prefill_chunk:
-            input_token_logprobs = req.input_token_logprobs
-            req.input_token_logprobs = None
-            assert req.input_token_logprobs_val is None
-            assert req.input_token_logprobs_idx is None
-            assert req.input_top_logprobs_val is None
-            assert req.input_top_logprobs_idx is None
-
-            # Compute input_token_logprobs_val
-            # Always pad the first one with None.
-            req.input_token_logprobs_val = [None]
-            req.input_token_logprobs_val.extend(input_token_logprobs)
-            # The last input logprob is for sampling, so just pop it out.
-            req.input_token_logprobs_val.pop()
-
-            # Compute input_token_logprobs_idx
-            input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
-            # Clip the padded hash values from image tokens.
-            # Otherwise, it will lead to detokenization errors.
-            input_token_logprobs_idx = [
-                x if x < self.model_config.vocab_size - 1 else 0
-                for x in input_token_logprobs_idx
-            ]
-            req.input_token_logprobs_idx = input_token_logprobs_idx
-
-            if req.top_logprobs_num > 0:
-                req.input_top_logprobs_val = [None]
-                req.input_top_logprobs_idx = [None]
-                assert len(req.temp_input_token_ids_logprobs_val) == len(
-                    req.temp_input_token_ids_logprobs_idx
-                )
-                for val, idx in zip(
-                    req.temp_input_top_logprobs_val, req.temp_input_top_logprobs_idx
-                ):
-                    req.input_top_logprobs_val.extend(val)
-                    req.input_top_logprobs_idx.extend(idx)
-
-                # Last token is a sample token.
-                req.input_top_logprobs_val.pop()
-                req.input_top_logprobs_idx.pop()
-                req.temp_input_top_logprobs_idx = None
-                req.temp_input_top_logprobs_val = None
-
-            if req.token_ids_logprob is not None:
-                req.input_token_ids_logprobs_val = [None]
-                req.input_token_ids_logprobs_idx = [None]
-
-                for val, idx in zip(
-                    req.temp_input_token_ids_logprobs_val,
-                    req.temp_input_token_ids_logprobs_idx,
-                    strict=True,
-                ):
-                    req.input_token_ids_logprobs_val.extend(val)
-                    req.input_token_ids_logprobs_idx.extend(idx)
-
-                # Last token is a sample token.
-                req.input_token_ids_logprobs_val.pop()
-                req.input_token_ids_logprobs_idx.pop()
-                req.temp_input_token_ids_logprobs_idx = None
-                req.temp_input_token_ids_logprobs_val = None
-
-            if req.return_logprob:
-                relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
-                assert len(req.input_token_logprobs_val) == relevant_tokens_len
-                assert len(req.input_token_logprobs_idx) == relevant_tokens_len
-                if req.top_logprobs_num > 0:
-                    assert len(req.input_top_logprobs_val) == relevant_tokens_len
-                    assert len(req.input_top_logprobs_idx) == relevant_tokens_len
-                if req.token_ids_logprob is not None:
-                    assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len
-                    assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len
-
-    def add_logprob_return_values(
-        self,
-        i: int,
-        req: Req,
-        pt: int,
-        next_token_ids: List[int],
-        num_input_logprobs: int,
-        output: LogitsProcessorOutput,
-    ):
-        """Attach logprobs to the return values."""
-        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
-        req.output_token_logprobs_idx.append(next_token_ids[i])
-
-        self.add_input_logprob_return_values(
-            i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
-        )
-
-        if req.top_logprobs_num > 0:
-            req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
-            req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
-
-        if req.token_ids_logprob is not None:
-            req.output_token_ids_logprobs_val.append(
-                output.next_token_token_ids_logprobs_val[i]
-            )
-            req.output_token_ids_logprobs_idx.append(
-                output.next_token_token_ids_logprobs_idx[i]
-            )
-
-        return num_input_logprobs
-
-    def stream_output(
-        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
-    ):
-        """Stream the output to detokenizer."""
-        rids = []
-        finished_reasons: List[BaseFinishReason] = []
-
-        if self.is_generation:
-            decoded_texts = []
-            decode_ids_list = []
-            read_offsets = []
-            output_ids = []
-
-            skip_special_tokens = []
-            spaces_between_special_tokens = []
-            no_stop_trim = []
-            prompt_tokens = []
-            completion_tokens = []
-            cached_tokens = []
-            spec_verify_ct = []
-            output_hidden_states = None
-
-            if return_logprob:
-                input_token_logprobs_val = []
-                input_token_logprobs_idx = []
-                output_token_logprobs_val = []
-                output_token_logprobs_idx = []
-                input_top_logprobs_val = []
-                input_top_logprobs_idx = []
-                output_top_logprobs_val = []
-                output_top_logprobs_idx = []
-                input_token_ids_logprobs_val = []
-                input_token_ids_logprobs_idx = []
-                output_token_ids_logprobs_val = []
-                output_token_ids_logprobs_idx = []
-            else:
-                input_token_logprobs_val = input_token_logprobs_idx = (
-                    output_token_logprobs_val
-                ) = output_token_logprobs_idx = input_top_logprobs_val = (
-                    input_top_logprobs_idx
-                ) = output_top_logprobs_val = output_top_logprobs_idx = (
-                    input_token_ids_logprobs_val
-                ) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = (
-                    output_token_ids_logprobs_idx
-                ) = None
-
-            for req in reqs:
-                if req is skip_req:
-                    continue
-
-                # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
-                if self.model_config.is_multimodal_gen and req.to_abort:
-                    continue
-
-                if (
-                    req.finished()
-                    # If stream, follow the given stream_interval
-                    or (req.stream and len(req.output_ids) % self.stream_interval == 0)
-                    # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
-                    # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
-                    # always increase one-by-one.
-                    or (
-                        not req.stream
-                        and len(req.output_ids) % 50 == 0
-                        and not self.model_config.is_multimodal_gen
-                    )
-                ):
-                    rids.append(req.rid)
-                    finished_reasons.append(
-                        req.finished_reason.to_json() if req.finished_reason else None
-                    )
-                    decoded_texts.append(req.decoded_text)
-                    decode_ids, read_offset = req.init_incremental_detokenize()
-                    decode_ids_list.append(decode_ids)
-                    read_offsets.append(read_offset)
-                    if self.skip_tokenizer_init:
-                        output_ids.append(req.output_ids)
-                    skip_special_tokens.append(req.sampling_params.skip_special_tokens)
-                    spaces_between_special_tokens.append(
-                        req.sampling_params.spaces_between_special_tokens
-                    )
-                    no_stop_trim.append(req.sampling_params.no_stop_trim)
-
-                    prompt_tokens.append(len(req.origin_input_ids))
-                    completion_tokens.append(len(req.output_ids))
-                    cached_tokens.append(req.cached_tokens)
-
-                    if not self.spec_algorithm.is_none():
-                        spec_verify_ct.append(req.spec_verify_ct)
-
-                    if return_logprob:
-                        input_token_logprobs_val.append(req.input_token_logprobs_val)
-                        input_token_logprobs_idx.append(req.input_token_logprobs_idx)
-                        output_token_logprobs_val.append(req.output_token_logprobs_val)
-                        output_token_logprobs_idx.append(req.output_token_logprobs_idx)
-                        input_top_logprobs_val.append(req.input_top_logprobs_val)
-                        input_top_logprobs_idx.append(req.input_top_logprobs_idx)
-                        output_top_logprobs_val.append(req.output_top_logprobs_val)
-                        output_top_logprobs_idx.append(req.output_top_logprobs_idx)
-                        input_token_ids_logprobs_val.append(
-                            req.input_token_ids_logprobs_val
-                        )
-                        input_token_ids_logprobs_idx.append(
-                            req.input_token_ids_logprobs_idx
-                        )
-                        output_token_ids_logprobs_val.append(
-                            req.output_token_ids_logprobs_val
-                        )
-                        output_token_ids_logprobs_idx.append(
-                            req.output_token_ids_logprobs_idx
-                        )
-
-                    if req.return_hidden_states:
-                        if output_hidden_states is None:
-                            output_hidden_states = []
-                        output_hidden_states.append(req.hidden_states)
-
-            # Send to detokenizer
-            if rids:
-                if self.model_config.is_multimodal_gen:
-                    raise NotImplementedError()
-                self.send_to_detokenizer.send_pyobj(
-                    BatchTokenIDOut(
-                        rids,
-                        finished_reasons,
-                        decoded_texts,
-                        decode_ids_list,
-                        read_offsets,
-                        output_ids,
-                        skip_special_tokens,
-                        spaces_between_special_tokens,
-                        no_stop_trim,
-                        prompt_tokens,
-                        completion_tokens,
-                        cached_tokens,
-                        spec_verify_ct,
-                        input_token_logprobs_val,
-                        input_token_logprobs_idx,
-                        output_token_logprobs_val,
-                        output_token_logprobs_idx,
-                        input_top_logprobs_val,
-                        input_top_logprobs_idx,
-                        output_top_logprobs_val,
-                        output_top_logprobs_idx,
-                        input_token_ids_logprobs_val,
-                        input_token_ids_logprobs_idx,
-                        output_token_ids_logprobs_val,
-                        output_token_ids_logprobs_idx,
-                        output_hidden_states,
-                    )
-                )
-        else:  # embedding or reward model
-            embeddings = []
-            prompt_tokens = []
-            for req in reqs:
-                if req.finished():
-                    rids.append(req.rid)
-                    finished_reasons.append(req.finished_reason.to_json())
-                    embeddings.append(req.embedding)
-                    prompt_tokens.append(len(req.origin_input_ids))
-            self.send_to_detokenizer.send_pyobj(
-                BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
-            )
-
     def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
         # Check if other DP workers have running batches
         if local_batch is None:
@@ -1906,18 +1348,46 @@ class Scheduler:
             self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]
 
+    def watchdog_thread(self):
+        """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
+        self.watchdog_last_forward_ct = 0
+        self.watchdog_last_time = time.time()
+
+        while True:
+            current = time.time()
+            if self.cur_batch is not None:
+                if self.watchdog_last_forward_ct == self.forward_ct:
+                    if current > self.watchdog_last_time + self.watchdog_timeout:
+                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
+                        break
+                else:
+                    self.watchdog_last_forward_ct = self.forward_ct
+                    self.watchdog_last_time = current
+            time.sleep(self.watchdog_timeout // 2)
+
+        # Print batch size and memory pool info to check whether there are de-sync issues.
+        logger.error(
+            f"{self.cur_batch.batch_size()=}, "
+            f"{self.cur_batch.reqs=}, "
+            f"{self.token_to_kv_pool_allocator.available_size()=}, "
+            f"{self.tree_cache.evictable_size()=}, "
+        )
+        # Wait for some time so that the parent process can print the error.
+        pyspy_dump_schedulers()
+        print(file=sys.stderr, flush=True)
+        print(file=sys.stdout, flush=True)
+        time.sleep(5)
+        self.parent_process.send_signal(signal.SIGQUIT)
+
     def flush_cache_wrapped(self, recv_req: FlushCacheReq):
         self.flush_cache()
 
     def flush_cache(self):
         """Flush the memory pool and cache."""
-        if len(self.waiting_queue) == 0 and (
-            self.running_batch is None or len(self.running_batch.reqs) == 0
-        ):
+        if len(self.waiting_queue) == 0 and self.running_batch.is_empty():
             self.cur_batch = None
             self.last_batch = None
             self.tree_cache.reset()
-            self.tree_cache_metrics = {"total": 0, "hit": 0}
             if self.grammar_backend:
                 self.grammar_backend.reset()
             self.req_to_token_pool.clear()
@@ -1940,7 +1410,7 @@ class Scheduler:
             logging.warning(
                 f"Cache not flushed because there are pending requests. "
                 f"#queue-req: {len(self.waiting_queue)}, "
-                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+                f"#running-req: {len(self.running_batch.reqs)}"
             )
             if_success = False
         return if_success
@@ -1990,24 +1460,27 @@ class Scheduler:
 
     def abort_request(self, recv_req: AbortReq):
         # Delete requests in the waiting queue
-        to_del = None
+        to_del = []
         for i, req in enumerate(self.waiting_queue):
-            if req.rid == recv_req.rid:
-                to_del = i
+            if req.rid.startswith(recv_req.rid):
+                to_del.append(i)
                 break
 
-        if to_del is not None:
-            del self.waiting_queue[to_del]
+        # Sort in reverse order to avoid index issues when deleting
+        for i in sorted(to_del, reverse=True):
+            req = self.waiting_queue.pop(i)
             logger.debug(f"Abort queued request. {req.rid=}")
             return
 
         # Delete requests in the running batch
-        if self.running_batch:
-            for req in self.running_batch.reqs:
-                if req.rid == recv_req.rid and not req.finished():
-                    logger.debug(f"Abort running request. {req.rid=}")
-                    req.to_abort = True
-                    break
+        for req in self.running_batch.reqs:
+            if req.rid.startswith(recv_req.rid) and not req.finished():
+                logger.debug(f"Abort running request. {req.rid=}")
+                req.to_abort = True
+                return
+
+    def _pause_engine(self) -> Tuple[List[Req], int]:
+        raise NotImplementedError()
 
     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         """In-place update of the weights from disk."""
@@ -2211,9 +1684,16 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
+
+    # Generate the prefix
+    if dp_rank is None:
+        prefix = f" TP{tp_rank}"
+    else:
+        prefix = f" DP{dp_rank} TP{tp_rank}"
+
     # Config the process
     # kill_itself_when_parent_died() # This is disabled because it does not work for `--dp 2`
-    setproctitle.setproctitle(f"sglang::scheduler_{dp_rank}")
+    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
     faulthandler.enable()
     parent_process = psutil.Process().parent()
 
@@ -2222,10 +1702,6 @@ def run_scheduler_process(
         dp_rank = int(os.environ["SGLANG_DP_RANK"])
 
     # Configure the logger
-    if dp_rank is None:
-        prefix = f" TP{tp_rank}"
-    else:
-        prefix = f" DP{dp_rank} TP{tp_rank}"
     configure_logger(server_args, prefix=prefix)
     suppress_other_loggers()