sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl

This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (128)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/_custom_ops.py +29 -1
  3. sglang/srt/configs/internvl.py +3 -0
  4. sglang/srt/configs/model_config.py +5 -1
  5. sglang/srt/constrained/base_grammar_backend.py +10 -2
  6. sglang/srt/constrained/xgrammar_backend.py +7 -5
  7. sglang/srt/conversation.py +17 -2
  8. sglang/srt/debug_utils/__init__.py +0 -0
  9. sglang/srt/debug_utils/dump_comparator.py +131 -0
  10. sglang/srt/debug_utils/dumper.py +108 -0
  11. sglang/srt/debug_utils/text_comparator.py +172 -0
  12. sglang/srt/disaggregation/common/conn.py +34 -6
  13. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  14. sglang/srt/disaggregation/mini_lb.py +3 -2
  15. sglang/srt/disaggregation/mooncake/conn.py +65 -20
  16. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  17. sglang/srt/disaggregation/nixl/conn.py +17 -13
  18. sglang/srt/disaggregation/prefill.py +13 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  21. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  23. sglang/srt/distributed/parallel_state.py +70 -15
  24. sglang/srt/entrypoints/engine.py +5 -9
  25. sglang/srt/entrypoints/http_server.py +20 -32
  26. sglang/srt/entrypoints/openai/protocol.py +3 -3
  27. sglang/srt/entrypoints/openai/serving_chat.py +148 -72
  28. sglang/srt/function_call/base_format_detector.py +74 -12
  29. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  30. sglang/srt/function_call/ebnf_composer.py +105 -66
  31. sglang/srt/function_call/function_call_parser.py +6 -4
  32. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  33. sglang/srt/function_call/kimik2_detector.py +41 -16
  34. sglang/srt/function_call/llama32_detector.py +6 -3
  35. sglang/srt/function_call/mistral_detector.py +11 -3
  36. sglang/srt/function_call/pythonic_detector.py +16 -14
  37. sglang/srt/function_call/qwen25_detector.py +12 -3
  38. sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
  39. sglang/srt/layers/activation.py +11 -3
  40. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  41. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  42. sglang/srt/layers/attention/vision.py +56 -8
  43. sglang/srt/layers/communicator.py +12 -12
  44. sglang/srt/layers/dp_attention.py +72 -24
  45. sglang/srt/layers/layernorm.py +26 -1
  46. sglang/srt/layers/logits_processor.py +46 -25
  47. sglang/srt/layers/moe/ep_moe/layer.py +172 -206
  48. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
  51. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  52. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  53. sglang/srt/layers/moe/topk.py +88 -34
  54. sglang/srt/layers/multimodal.py +11 -8
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
  56. sglang/srt/layers/quantization/fp8.py +25 -247
  57. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  58. sglang/srt/layers/quantization/modelopt_quant.py +33 -14
  59. sglang/srt/layers/quantization/unquant.py +24 -76
  60. sglang/srt/layers/quantization/utils.py +0 -9
  61. sglang/srt/layers/quantization/w4afp8.py +68 -17
  62. sglang/srt/layers/radix_attention.py +5 -3
  63. sglang/srt/lora/lora_manager.py +133 -169
  64. sglang/srt/lora/lora_registry.py +188 -0
  65. sglang/srt/lora/mem_pool.py +2 -2
  66. sglang/srt/managers/cache_controller.py +62 -13
  67. sglang/srt/managers/io_struct.py +19 -1
  68. sglang/srt/managers/mm_utils.py +154 -35
  69. sglang/srt/managers/multimodal_processor.py +3 -14
  70. sglang/srt/managers/schedule_batch.py +27 -11
  71. sglang/srt/managers/scheduler.py +48 -26
  72. sglang/srt/managers/tokenizer_manager.py +62 -28
  73. sglang/srt/managers/tp_worker.py +5 -4
  74. sglang/srt/mem_cache/allocator.py +67 -7
  75. sglang/srt/mem_cache/hicache_storage.py +17 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +35 -18
  77. sglang/srt/mem_cache/memory_pool_host.py +3 -0
  78. sglang/srt/model_executor/cuda_graph_runner.py +61 -25
  79. sglang/srt/model_executor/forward_batch_info.py +201 -29
  80. sglang/srt/model_executor/model_runner.py +109 -37
  81. sglang/srt/models/deepseek_v2.py +63 -30
  82. sglang/srt/models/glm4_moe.py +1035 -0
  83. sglang/srt/models/glm4_moe_nextn.py +167 -0
  84. sglang/srt/models/interns1.py +328 -0
  85. sglang/srt/models/internvl.py +143 -47
  86. sglang/srt/models/llava.py +9 -5
  87. sglang/srt/models/minicpmo.py +4 -1
  88. sglang/srt/models/mllama4.py +10 -3
  89. sglang/srt/models/qwen2_moe.py +2 -6
  90. sglang/srt/models/qwen3_moe.py +6 -8
  91. sglang/srt/multimodal/processors/base_processor.py +20 -6
  92. sglang/srt/multimodal/processors/clip.py +2 -2
  93. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  94. sglang/srt/multimodal/processors/gemma3.py +2 -2
  95. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  96. sglang/srt/multimodal/processors/internvl.py +21 -8
  97. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  98. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  99. sglang/srt/multimodal/processors/llava.py +4 -4
  100. sglang/srt/multimodal/processors/minicpm.py +2 -3
  101. sglang/srt/multimodal/processors/mlama.py +2 -2
  102. sglang/srt/multimodal/processors/mllama4.py +18 -111
  103. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  104. sglang/srt/multimodal/processors/pixtral.py +2 -2
  105. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  106. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  107. sglang/srt/multimodal/processors/vila.py +3 -1
  108. sglang/srt/reasoning_parser.py +48 -5
  109. sglang/srt/sampling/sampling_batch_info.py +6 -5
  110. sglang/srt/server_args.py +132 -60
  111. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  112. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
  113. sglang/srt/speculative/eagle_utils.py +51 -23
  114. sglang/srt/speculative/eagle_worker.py +59 -44
  115. sglang/srt/two_batch_overlap.py +9 -5
  116. sglang/srt/utils.py +113 -69
  117. sglang/srt/weight_sync/utils.py +119 -0
  118. sglang/test/runners.py +4 -0
  119. sglang/test/test_activation.py +50 -1
  120. sglang/test/test_utils.py +65 -5
  121. sglang/utils.py +19 -0
  122. sglang/version.py +1 -1
  123. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
  124. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
  125. sglang/srt/debug_utils.py +0 -74
  126. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
  127. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
  128. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/hiradix_cache.py

@@ -50,6 +50,7 @@ class HiRadixCache(RadixCache):
  raise ValueError(f"HiRadixCache only supports MHA and MLA yet")

  self.tp_group = tp_cache_group
+ self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
  self.enable_storage = hicache_storage_backend is not None
  # todo: customizable storage prefetch threshold
  self.prefetch_threshold = 256
@@ -59,6 +60,7 @@ class HiRadixCache(RadixCache):
  token_to_kv_pool_allocator,
  self.token_to_kv_pool_host,
  page_size,
+ self.tp_group,
  load_cache_event=self.load_cache_event,
  write_policy=hicache_write_policy,
  io_backend=hicache_io_backend,
@@ -153,7 +155,7 @@ class HiRadixCache(RadixCache):
  queue_size = torch.tensor(
  self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
  )
- if torch.distributed.get_world_size(group=self.tp_group) > 1:
+ if self.tp_world_size > 1:
  # synchrnoize TP workers to make the same update to radix cache
  torch.distributed.all_reduce(
  queue_size,
@@ -353,7 +355,7 @@ class HiRadixCache(RadixCache):
  queue_size = torch.tensor(
  self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
  )
- if torch.distributed.get_world_size(group=self.tp_group) > 1:
+ if self.tp_world_size > 1:
  # synchrnoize TP workers to make the same update to hiradix cache
  torch.distributed.all_reduce(
  queue_size,
@@ -363,16 +365,18 @@ class HiRadixCache(RadixCache):
  for _ in range(queue_size.item()):
  req_id = self.cache_controller.prefetch_revoke_queue.get()
  if req_id in self.ongoing_prefetch:
- last_host_node, _, host_indices, _ = self.ongoing_prefetch[req_id]
+ last_host_node, _, _, _ = self.ongoing_prefetch[req_id]
  last_host_node.release_host()
- self.cache_controller.mem_pool_host.free(host_indices)
  del self.ongoing_prefetch[req_id]
+ else:
+ # the revoked operation already got terminated
+ pass

  def check_backup_progress(self):
  queue_size = torch.tensor(
  self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
  )
- if torch.distributed.get_world_size(group=self.tp_group) > 1:
+ if self.tp_world_size > 1:
  # synchrnoize TP workers to make the same update to hiradix cache
  torch.distributed.all_reduce(
  queue_size,
@@ -380,9 +384,15 @@ class HiRadixCache(RadixCache):
  group=self.tp_group,
  )
  for _ in range(queue_size.item()):
- ack_id, hash_value = self.cache_controller.ack_backup_queue.get()
- self.ongoing_backup[ack_id].hash_value = hash_value
- self.ongoing_backup[ack_id].release_host()
+ ack_id, hash_value, completed_tokens = (
+ self.cache_controller.ack_backup_queue.get()
+ )
+ host_node = self.ongoing_backup[ack_id]
+ if completed_tokens < len(host_node.key):
+ # backup is only partially successful, split the node
+ new_node = self._split_node(host_node.key, host_node, completed_tokens)
+ new_node.hash_value = hash_value
+ host_node.release_host()
  del self.ongoing_backup[ack_id]

  def check_prefetch_progress(self, req_id: str):
@@ -395,20 +405,24 @@ class HiRadixCache(RadixCache):
  last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
  req_id
  ]
+
  completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
  operation
  )
  logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")

- min_completed_tokens = torch.tensor(completed_tokens, dtype=torch.int)
- if torch.distributed.get_world_size(group=self.tp_group) > 1:
+ min_completed_tokens = completed_tokens
+ if self.tp_world_size > 1:
  # synchrnoize TP workers to make the same update to hiradix cache
+ completed_tokens_tensor = torch.tensor(
+ min_completed_tokens, dtype=torch.int
+ )
  torch.distributed.all_reduce(
- min_completed_tokens,
+ completed_tokens_tensor,
  op=torch.distributed.ReduceOp.MIN,
  group=self.tp_group,
  )
- min_completed_tokens = min_completed_tokens.item()
+ min_completed_tokens = completed_tokens_tensor.item()
  fetched_token_ids = token_ids[:min_completed_tokens]
  written_indices = host_indices[:min_completed_tokens]
  matched_length = self._insert_helper_host(
@@ -465,16 +479,19 @@ class HiRadixCache(RadixCache):
  new_input_tokens: List[int],
  last_hash: Optional[str] = None,
  ):
- if not self.enable_storage or len(new_input_tokens) < self.prefetch_threshold:
+ # align the number of fetching tokens to the page size
+ prefetch_length = len(new_input_tokens) - (
+ len(new_input_tokens) % self.page_size
+ )
+ new_input_tokens = new_input_tokens[:prefetch_length]
+ if not self.enable_storage or prefetch_length < self.prefetch_threshold:
  return

  last_host_node.protect_host()
- host_indices = self.cache_controller.mem_pool_host.alloc(len(new_input_tokens))
+ host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
  if host_indices is None:
- self.evict_host(len(new_input_tokens))
- host_indices = self.cache_controller.mem_pool_host.alloc(
- len(new_input_tokens)
- )
+ self.evict_host(prefetch_length)
+ host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
  if host_indices is None:
  last_host_node.release_host()
  # no sufficient host memory to prefetch
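The recurring pattern in these hunks is to cache the TP world size once at construction and then combine per-rank progress with an all-reduce (ReduceOp.MIN for completed token counts) so every TP worker applies an identical update to the cache. A minimal sketch of that synchronization step, with a hypothetical helper name and assuming torch.distributed is already initialized:

    import torch
    import torch.distributed as dist

    def agree_on_completed_tokens(local_completed: int, tp_group, tp_world_size: int) -> int:
        # Each TP rank may have finished a different number of tokens; taking the
        # MIN guarantees all ranks insert the same prefix into the radix cache.
        if tp_world_size <= 1:
            return local_completed
        completed = torch.tensor(local_completed, dtype=torch.int)
        dist.all_reduce(completed, op=dist.ReduceOp.MIN, group=tp_group)
        return int(completed.item())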
sglang/srt/mem_cache/memory_pool_host.py

@@ -126,6 +126,9 @@ class HostKVCache(abc.ABC):

  @synchronized()
  def alloc(self, need_size: int) -> torch.Tensor:
+ assert (
+ need_size % self.page_size == 0
+ ), "The requested size should be a multiple of the page size."
  if need_size > self.available_size():
  return None

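This assertion pairs with the page-size alignment added on the prefetch path in hiradix_cache.py: callers are expected to round a requested length down to a multiple of the page size before allocating. A minimal sketch of that rounding rule (hypothetical helper, assuming page_size > 0):

    def align_down_to_page(num_tokens: int, page_size: int) -> int:
        # Drop the trailing partial page so host allocations stay page-aligned.
        return num_tokens - (num_tokens % page_size)

    assert align_down_to_page(1000, 64) == 960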
sglang/srt/model_executor/cuda_graph_runner.py

@@ -29,9 +29,9 @@ from torch.profiler import ProfilerActivity, profile
  from sglang.srt.custom_op import CustomOp
  from sglang.srt.distributed import get_tensor_model_parallel_rank
  from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
+ from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.layers.torchao_utils import save_gemlite_cache
- from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import (
  CaptureHiddenMode,
  ForwardBatch,
@@ -167,8 +167,15 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
  # is very small. We add more values here to make sure we capture the maximum bs.
  capture_bs += [model_runner.req_to_token_pool.size]

+ mul_base = 1
+
  if server_args.enable_two_batch_overlap:
- capture_bs = [bs for bs in capture_bs if bs % 2 == 0]
+ mul_base *= 2
+
+ if require_gathered_buffer(server_args):
+ mul_base *= get_attention_tp_size()
+
+ capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]

  if server_args.cuda_graph_max_bs:
  capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
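The capture list is now filtered by a single combined multiple: 2 when two-batch overlap is enabled, times the attention TP size when a gathered buffer is required. A standalone sketch of the filtering (hypothetical function; the real code reads the multipliers from server_args and require_gathered_buffer):

    from typing import List, Optional

    def filter_capture_batch_sizes(
        capture_bs: List[int],
        enable_two_batch_overlap: bool,
        attn_tp_size_if_gathered: Optional[int],
    ) -> List[int]:
        # Every constraint the CUDA graph capture path imposes is multiplied
        # into one base, and only batch sizes divisible by it are kept.
        mul_base = 1
        if enable_two_batch_overlap:
            mul_base *= 2
        if attn_tp_size_if_gathered is not None:
            mul_base *= attn_tp_size_if_gathered
        return [bs for bs in capture_bs if bs % mul_base == 0]

    # With overlap enabled and attention TP size 4, only multiples of 8 survive:
    assert filter_capture_batch_sizes([1, 2, 4, 8, 16, 24], True, 4) == [8, 16, 24]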
@@ -306,20 +313,37 @@ class CudaGraphRunner:
  self.encoder_lens = None

  if self.require_gathered_buffer:
- self.gathered_buffer = torch.zeros(
- (
- self.max_num_token,
- self.model_runner.model_config.hidden_size,
- ),
- dtype=self.model_runner.dtype,
- )
  if self.require_mlp_tp_gather:
  self.global_num_tokens_gpu = torch.zeros(
  (self.dp_size,), dtype=torch.int32
  )
+ self.global_num_tokens_for_logprob_gpu = torch.zeros(
+ (self.dp_size,), dtype=torch.int32
+ )
+ self.gathered_buffer = torch.zeros(
+ (
+ self.max_num_token * self.dp_size,
+ self.model_runner.model_config.hidden_size,
+ ),
+ dtype=self.model_runner.dtype,
+ )
  else:
  assert self.require_attn_tp_gather
  self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+ self.global_num_tokens_for_logprob_gpu = torch.zeros(
+ (1,), dtype=torch.int32
+ )
+ self.gathered_buffer = torch.zeros(
+ (
+ self.max_num_token,
+ self.model_runner.model_config.hidden_size,
+ ),
+ dtype=self.model_runner.dtype,
+ )
+ else:
+ self.global_num_tokens_gpu = None
+ self.global_num_tokens_for_logprob_gpu = None
+ self.gathered_buffer = None

  self.custom_mask = torch.ones(
  (
@@ -342,9 +366,9 @@ class CudaGraphRunner:
  def can_run(self, forward_batch: ForwardBatch):
  if self.require_mlp_tp_gather:
  cuda_graph_bs = (
- sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+ max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
  if self.model_runner.spec_algorithm.is_eagle()
- else sum(forward_batch.global_num_tokens_cpu)
+ else max(forward_batch.global_num_tokens_cpu)
  )
  else:
  cuda_graph_bs = forward_batch.batch_size
@@ -480,16 +504,19 @@ class CudaGraphRunner:
  if self.require_mlp_tp_gather:
  self.global_num_tokens_gpu.copy_(
  torch.tensor(
- [
- num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
- for i in range(self.dp_size)
- ],
+ [num_tokens] * self.dp_size,
  dtype=torch.int32,
  device=input_ids.device,
  )
  )
- global_num_tokens = self.global_num_tokens_gpu
- gathered_buffer = self.gathered_buffer[:num_tokens]
+ self.global_num_tokens_for_logprob_gpu.copy_(
+ torch.tensor(
+ [num_tokens] * self.dp_size,
+ dtype=torch.int32,
+ device=input_ids.device,
+ )
+ )
+ gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
  elif self.require_attn_tp_gather:
  self.global_num_tokens_gpu.copy_(
  torch.tensor(
@@ -498,10 +525,15 @@ class CudaGraphRunner:
  device=input_ids.device,
  )
  )
- global_num_tokens = self.global_num_tokens_gpu
+ self.global_num_tokens_for_logprob_gpu.copy_(
+ torch.tensor(
+ [num_tokens],
+ dtype=torch.int32,
+ device=input_ids.device,
+ )
+ )
  gathered_buffer = self.gathered_buffer[:num_tokens]
  else:
- global_num_tokens = None
  gathered_buffer = None

  spec_info = self.get_spec_info(num_tokens)
@@ -531,7 +563,9 @@ class CudaGraphRunner:
  encoder_lens=encoder_lens,
  return_logprob=False,
  positions=positions,
- global_num_tokens_gpu=global_num_tokens,
+ global_num_tokens_gpu=self.global_num_tokens_gpu,
+ global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
+ dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
  gathered_buffer=gathered_buffer,
  mrope_positions=mrope_positions,
  spec_algorithm=self.model_runner.spec_algorithm,
@@ -635,12 +669,13 @@ class CudaGraphRunner:

  # Pad
  if self.require_mlp_tp_gather:
- total_batch_size = (
- sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs
+ max_num_tokens = max(forward_batch.global_num_tokens_cpu)
+ max_batch_size = (
+ max_num_tokens / self.num_tokens_per_bs
  if self.model_runner.spec_algorithm.is_eagle()
- else sum(forward_batch.global_num_tokens_cpu)
+ else max_num_tokens
  )
- index = bisect.bisect_left(self.capture_bs, total_batch_size)
+ index = bisect.bisect_left(self.capture_bs, max_batch_size)
  else:
  index = bisect.bisect_left(self.capture_bs, raw_bs)
  bs = self.capture_bs[index]
@@ -670,7 +705,8 @@ class CudaGraphRunner:
  if forward_batch.mrope_positions is not None:
  self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
  if self.require_gathered_buffer:
- self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
+ self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
+ self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
  if enable_num_token_non_padded(self.model_runner.server_args):
  self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
  if self.enable_two_batch_overlap:
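The gathered buffer is now sized for the worst case under max-length padding: with MLP TP gather every DP rank contributes max_num_token rows, so the buffer holds max_num_token * dp_size hidden states rather than a per-rank sum. A rough sketch of the two sizing rules (hypothetical function, not the sglang API):

    from typing import List

    def gathered_buffer_rows(num_tokens_per_rank: List[int], use_max_len_padding: bool) -> int:
        # Max-length padding: all_gather_into_tensor needs equal-sized chunks,
        # so every rank is padded up to the longest one.
        if use_max_len_padding:
            return max(num_tokens_per_rank) * len(num_tokens_per_rank)
        # Otherwise ranks keep their own lengths and are concatenated.
        return sum(num_tokens_per_rank)

    # e.g. [3, 7, 5] -> 21 rows with max-length padding, 15 rows otherwise.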
sglang/srt/model_executor/forward_batch_info.py

@@ -38,6 +38,11 @@ import torch
  import triton
  import triton.language as tl

+ from sglang.srt.layers.dp_attention import (
+ DPPaddingMode,
+ get_attention_dp_rank,
+ get_attention_tp_size,
+ )
  from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
  from sglang.srt.utils import (
  flatten_nested_list,
@@ -48,6 +53,7 @@ from sglang.srt.utils import (

  if TYPE_CHECKING:
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.managers.schedule_batch import ModelWorkerBatch, MultimodalInputs
  from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
  from sglang.srt.model_executor.model_runner import ModelRunner
@@ -68,8 +74,6 @@ class ForwardMode(IntEnum):
  MIXED = auto()
  # No sequence to forward. For data parallel attention, some workers will be IDLE if no sequence are allocated.
  IDLE = auto()
- # Split Prefill for PD multiplexing
- SPLIT_PREFILL = auto()

  # Used in speculative decoding: verify a batch in the target model.
  TARGET_VERIFY = auto()
@@ -80,6 +84,9 @@ class ForwardMode(IntEnum):
  # It is now used for triggering the sampling_info_done event for the first prefill batch.
  DUMMY_FIRST = auto()

+ # Split Prefill for PD multiplexing
+ SPLIT_PREFILL = auto()
+
  def is_prefill(self):
  return self.is_extend()

@@ -97,12 +104,12 @@ class ForwardMode(IntEnum):
  def is_mixed(self):
  return self == ForwardMode.MIXED

- def is_split_prefill(self):
- return self == ForwardMode.SPLIT_PREFILL
-
  def is_idle(self):
  return self == ForwardMode.IDLE

+ def is_decode_or_idle(self):
+ return self == ForwardMode.DECODE or self == ForwardMode.IDLE
+
  def is_target_verify(self):
  return self == ForwardMode.TARGET_VERIFY

@@ -126,8 +133,8 @@ class ForwardMode(IntEnum):
  def is_dummy_first(self):
  return self == ForwardMode.DUMMY_FIRST

- def is_decode_or_idle(self):
- return self == ForwardMode.DECODE or self == ForwardMode.IDLE
+ def is_split_prefill(self):
+ return self == ForwardMode.SPLIT_PREFILL


  @total_ordering
@@ -242,7 +249,7 @@ class ForwardBatch:
  lora_paths: Optional[List[str]] = None

  # For input embeddings
- input_embeds: Optional[torch.tensor] = None
+ input_embeds: Optional[torch.Tensor] = None

  # For cross-encoder model
  token_type_ids: Optional[torch.Tensor] = None
@@ -261,6 +268,8 @@ class ForwardBatch:
  # Has to be None when cuda graph is captured.
  global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
  global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
+ # The padding mode for DP attention
+ dp_padding_mode: Optional[DPPaddingMode] = None
  # for extend, local start pos and num tokens is different in logits processor
  # this will be computed in get_dp_local_info
  # this will be recomputed in LogitsMetadata.from_forward_batch
@@ -286,7 +295,7 @@ class ForwardBatch:
  # For two-batch overlap
  tbo_split_seq_index: Optional[int] = None
  tbo_parent_token_range: Optional[Tuple[int, int]] = None
- tbo_children: Optional[List["ForwardBatch"]] = None
+ tbo_children: Optional[List[ForwardBatch]] = None

  @classmethod
  def init_new(
@@ -340,20 +349,38 @@ class ForwardBatch:
  len(batch.input_ids), dtype=torch.int32
  ).to(device, non_blocking=True)

- # For DP attention
+ # For MLP sync
  if batch.global_num_tokens is not None:
-
- spec_num_draft_tokens = (
- batch.spec_num_draft_tokens
- if batch.spec_num_draft_tokens is not None
- else 1
+ from sglang.srt.speculative.eagle_utils import (
+ EagleDraftInput,
+ EagleVerifyInput,
  )
- global_num_tokens = [
- x * spec_num_draft_tokens for x in batch.global_num_tokens
- ]
- global_num_tokens_for_logprob = [
- x * spec_num_draft_tokens for x in batch.global_num_tokens_for_logprob
- ]
+
+ assert batch.global_num_tokens_for_logprob is not None
+ # process global_num_tokens and global_num_tokens_for_logprob
+ if batch.spec_info is not None:
+ if isinstance(batch.spec_info, EagleDraftInput):
+ global_num_tokens = [
+ x * batch.spec_info.num_tokens_per_batch
+ for x in batch.global_num_tokens
+ ]
+ global_num_tokens_for_logprob = [
+ x * batch.spec_info.num_tokens_for_logprob_per_batch
+ for x in batch.global_num_tokens_for_logprob
+ ]
+ else:
+ assert isinstance(batch.spec_info, EagleVerifyInput)
+ global_num_tokens = [
+ x * batch.spec_info.draft_token_num
+ for x in batch.global_num_tokens
+ ]
+ global_num_tokens_for_logprob = [
+ x * batch.spec_info.draft_token_num
+ for x in batch.global_num_tokens_for_logprob
+ ]
+ else:
+ global_num_tokens = batch.global_num_tokens
+ global_num_tokens_for_logprob = batch.global_num_tokens_for_logprob

  ret.global_num_tokens_cpu = global_num_tokens
  ret.global_num_tokens_gpu = torch.tensor(
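The per-rank counts in global_num_tokens are scaled by the speculative batch shape: num_tokens_per_batch for an Eagle draft batch, draft_token_num for a verify batch. A worked example of the scaling (numbers are illustrative only):

    # Two DP ranks report 3 and 5 requests in a verify batch, with 4 draft
    # tokens per request.
    global_num_tokens = [3, 5]
    draft_token_num = 4
    scaled = [x * draft_token_num for x in global_num_tokens]
    assert scaled == [12, 20]  # tokens entering the forward pass on each rank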
@@ -365,15 +392,8 @@ class ForwardBatch:
  global_num_tokens_for_logprob, dtype=torch.int64
  ).to(device, non_blocking=True)

- sum_len = sum(global_num_tokens)
- ret.gathered_buffer = torch.zeros(
- (sum_len, model_runner.model_config.hidden_size),
- dtype=model_runner.dtype,
- device=device,
- )
-
  if ret.forward_mode.is_idle():
- ret.positions = torch.empty((0,), device=device)
+ ret.positions = torch.empty((0,), dtype=torch.int64, device=device)
  TboForwardBatchPreparer.prepare(
  ret, is_draft_worker=model_runner.is_draft_worker
  )
@@ -573,6 +593,158 @@ class ForwardBatch:
  )
  self.prefix_chunk_kv_indices.append(chunk_kv_indices)

+ def _pad_tensor_to_size(self, tensor: torch.Tensor, size: int, *, value: int = 0):
+ if value == 0:
+ return torch.cat(
+ [tensor, tensor.new_zeros(size - tensor.shape[0], *tensor.shape[1:])],
+ dim=0,
+ )
+ else:
+ return torch.cat(
+ [
+ tensor,
+ tensor.new_full((size - tensor.shape[0], *tensor.shape[1:]), value),
+ ],
+ dim=0,
+ )
+
+ def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
+
+ from sglang.srt.speculative.eagle_utils import EagleDraftInput
+
+ assert self.global_num_tokens_cpu is not None
+ assert self.global_num_tokens_for_logprob_cpu is not None
+
+ global_num_tokens = self.global_num_tokens_cpu
+ sync_group_size = len(global_num_tokens)
+ attn_tp_size = get_attention_tp_size()
+
+ for i in range(sync_group_size):
+ # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
+ # there is no reduce-scatter in LM logprob, so we do not need to adjust the padded length for logprob
+ global_num_tokens[i] = (
+ (global_num_tokens[i] - 1) // attn_tp_size + 1
+ ) * attn_tp_size
+
+ dp_padding_mode = DPPaddingMode.get_dp_padding_mode(global_num_tokens)
+ self.dp_padding_mode = dp_padding_mode
+
+ if dp_padding_mode.is_max_len():
+ # when DP gather mode is all gather, we will use all_gather_into_tensor to gather hidden states,
+ # where transferred tokens should be padded to the same length.
+ max_num_tokens = max(global_num_tokens)
+ global_num_tokens = [max_num_tokens] * sync_group_size
+ buffer_len = max_num_tokens * sync_group_size
+ else:
+ buffer_len = sum(global_num_tokens)
+
+ self.gathered_buffer = torch.zeros(
+ (buffer_len, model_runner.model_config.hidden_size),
+ dtype=model_runner.dtype,
+ device=model_runner.device,
+ )
+
+ bs = self.batch_size
+ if len(global_num_tokens) > 1:
+ num_tokens = global_num_tokens[get_attention_dp_rank()]
+ else:
+ num_tokens = global_num_tokens[0]
+
+ # padding
+ self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
+ self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
+
+ seq_len_fill_value = (
+ model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
+ )
+ self.seq_lens = self._pad_tensor_to_size(
+ self.seq_lens, bs, value=seq_len_fill_value
+ )
+ if self.seq_lens_cpu is not None:
+ self.seq_lens_cpu = self._pad_tensor_to_size(
+ self.seq_lens_cpu, bs, value=seq_len_fill_value
+ )
+
+ self.out_cache_loc = self._pad_tensor_to_size(self.out_cache_loc, num_tokens)
+ if self.encoder_lens is not None:
+ self.encoder_lens = self._pad_tensor_to_size(self.encoder_lens, bs)
+ self.positions = self._pad_tensor_to_size(self.positions, num_tokens)
+ self.global_num_tokens_cpu = global_num_tokens
+ self.global_num_tokens_gpu = self.global_num_tokens_gpu.new_tensor(
+ global_num_tokens
+ )
+
+ if self.mrope_positions is not None:
+ self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs)
+
+ if self.extend_seq_lens is not None:
+ self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
+
+ if self.spec_info is not None and isinstance(self.spec_info, EagleDraftInput):
+ spec_info = self.spec_info
+ self.output_cache_loc_backup = self.out_cache_loc
+ self.hidden_states_backup = spec_info.hidden_states
+ if spec_info.topk_p is not None:
+ spec_info.topk_p = self._pad_tensor_to_size(spec_info.topk_p, bs)
+ if spec_info.topk_index is not None:
+ spec_info.topk_index = self._pad_tensor_to_size(
+ spec_info.topk_index, bs
+ )
+ if spec_info.accept_length is not None:
+ spec_info.accept_length = self._pad_tensor_to_size(
+ spec_info.accept_length, bs
+ )
+ spec_info.hidden_states = self._pad_tensor_to_size(
+ spec_info.hidden_states, num_tokens
+ )
+
+ def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
+
+ bs = self.batch_size
+
+ if self.spec_info is not None:
+ if self.forward_mode.is_decode(): # draft
+ num_tokens = self.hidden_states_backup.shape[0]
+ self.positions = self.positions[:num_tokens]
+ self.seq_lens = self.seq_lens[:bs]
+ self.req_pool_indices = self.req_pool_indices[:bs]
+ if self.seq_lens_cpu is not None:
+ self.seq_lens_cpu = self.seq_lens_cpu[:bs]
+ logits_output.next_token_logits = logits_output.next_token_logits[
+ :num_tokens
+ ]
+ logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
+ elif self.forward_mode.is_target_verify(): # verify
+ num_tokens = bs * self.spec_info.draft_token_num
+ logits_output.next_token_logits = logits_output.next_token_logits[
+ :num_tokens
+ ]
+ logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
+ elif self.forward_mode.is_draft_extend(): # draft extend
+ self.spec_info.accept_length = self.spec_info.accept_length[:bs]
+ logits_output.next_token_logits = logits_output.next_token_logits[:bs]
+ logits_output.hidden_states = logits_output.hidden_states[:bs]
+ elif self.forward_mode.is_extend() or self.forward_mode.is_idle():
+ logits_output.next_token_logits = logits_output.next_token_logits[:bs]
+ logits_output.hidden_states = logits_output.hidden_states[:bs]
+
+ if hasattr(self, "hidden_states_backup"):
+ self.spec_info.hidden_states = self.hidden_states_backup
+ if hasattr(self, "output_cache_loc_backup"):
+ self.out_cache_loc = self.output_cache_loc_backup
+
+ elif self.forward_mode.is_decode() or self.forward_mode.is_idle():
+ logits_output.next_token_logits = logits_output.next_token_logits[:bs]
+ if logits_output.hidden_states is not None:
+ logits_output.hidden_states = logits_output.hidden_states[:bs]
+ elif self.forward_mode.is_extend():
+ num_tokens = self.seq_lens_sum
+ logits_output.next_token_logits = logits_output.next_token_logits[
+ :num_tokens
+ ]
+ if logits_output.hidden_states is not None:
+ logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
+
  # Here we suppose the length of each chunk is equal
  # For example, if we have 4 sequences with prefix length [256, 512, 768, 1024], prefix_chunk_len = 256
  # num_prefix_chunks = cdiv(1024, 256) = 4
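prepare_mlp_sync_batch pads the per-rank tensors up to the agreed lengths before the forward pass, and post_forward_mlp_sync_batch trims the logits back down to the real batch afterwards. A minimal standalone sketch of that pad-then-trim round trip (a single simplified helper, without the zero-fill fast path or the spec-decoding bookkeeping above):

    import torch

    def pad_to_size(t: torch.Tensor, size: int, value: int = 0) -> torch.Tensor:
        # Append value-filled rows until the tensor reaches `size` along dim 0.
        pad = t.new_full((size - t.shape[0], *t.shape[1:]), value)
        return torch.cat([t, pad], dim=0)

    input_ids = torch.tensor([11, 12, 13])       # 3 real tokens on this rank
    padded = pad_to_size(input_ids, 8)           # padded to the synced length
    assert padded.shape[0] == 8
    logits = torch.randn(8, 32)                  # model output over padded tokens
    logits = logits[: input_ids.shape[0]]        # trim back to the real tokens
    assert logits.shape[0] == 3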