sglang 0.4.5.post3__py3-none-any.whl → 0.4.6.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. sglang/bench_one_batch.py +19 -3
  2. sglang/bench_serving.py +8 -9
  3. sglang/compile_deep_gemm.py +45 -4
  4. sglang/srt/code_completion_parser.py +1 -1
  5. sglang/srt/configs/deepseekvl2.py +1 -1
  6. sglang/srt/configs/model_config.py +9 -3
  7. sglang/srt/constrained/llguidance_backend.py +78 -61
  8. sglang/srt/conversation.py +34 -1
  9. sglang/srt/disaggregation/decode.py +67 -13
  10. sglang/srt/disaggregation/fake/__init__.py +1 -0
  11. sglang/srt/disaggregation/fake/conn.py +88 -0
  12. sglang/srt/disaggregation/mini_lb.py +45 -8
  13. sglang/srt/disaggregation/mooncake/conn.py +198 -31
  14. sglang/srt/disaggregation/prefill.py +36 -12
  15. sglang/srt/disaggregation/utils.py +16 -2
  16. sglang/srt/entrypoints/engine.py +9 -0
  17. sglang/srt/entrypoints/http_server.py +35 -4
  18. sglang/srt/function_call_parser.py +77 -5
  19. sglang/srt/layers/attention/base_attn_backend.py +3 -0
  20. sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
  21. sglang/srt/layers/attention/flashattention_backend.py +28 -10
  22. sglang/srt/layers/attention/flashmla_backend.py +8 -11
  23. sglang/srt/layers/attention/utils.py +1 -1
  24. sglang/srt/layers/attention/vision.py +2 -0
  25. sglang/srt/layers/layernorm.py +38 -16
  26. sglang/srt/layers/logits_processor.py +2 -2
  27. sglang/srt/layers/moe/fused_moe_native.py +2 -4
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +20 -17
  43. sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
  44. sglang/srt/layers/pooler.py +6 -0
  45. sglang/srt/layers/quantization/awq.py +5 -1
  46. sglang/srt/layers/quantization/deep_gemm.py +17 -10
  47. sglang/srt/layers/quantization/fp8.py +20 -22
  48. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  49. sglang/srt/layers/quantization/int8_kernel.py +32 -1
  50. sglang/srt/layers/radix_attention.py +13 -3
  51. sglang/srt/layers/rotary_embedding.py +170 -126
  52. sglang/srt/managers/data_parallel_controller.py +10 -3
  53. sglang/srt/managers/io_struct.py +7 -0
  54. sglang/srt/managers/mm_utils.py +85 -28
  55. sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
  56. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
  57. sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
  58. sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
  59. sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
  60. sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
  61. sglang/srt/managers/schedule_batch.py +38 -12
  62. sglang/srt/managers/scheduler.py +41 -28
  63. sglang/srt/managers/scheduler_output_processor_mixin.py +25 -9
  64. sglang/srt/managers/tokenizer_manager.py +5 -1
  65. sglang/srt/managers/tp_worker.py +3 -3
  66. sglang/srt/managers/tp_worker_overlap_thread.py +9 -4
  67. sglang/srt/mem_cache/memory_pool.py +87 -0
  68. sglang/srt/model_executor/cuda_graph_runner.py +4 -3
  69. sglang/srt/model_executor/forward_batch_info.py +51 -95
  70. sglang/srt/model_executor/model_runner.py +19 -25
  71. sglang/srt/models/deepseek.py +12 -2
  72. sglang/srt/models/deepseek_nextn.py +101 -6
  73. sglang/srt/models/deepseek_v2.py +144 -70
  74. sglang/srt/models/deepseek_vl2.py +9 -4
  75. sglang/srt/models/gemma3_causal.py +1 -1
  76. sglang/srt/models/llama4.py +0 -1
  77. sglang/srt/models/minicpmo.py +5 -1
  78. sglang/srt/models/mllama4.py +2 -2
  79. sglang/srt/models/qwen2_5_vl.py +3 -6
  80. sglang/srt/models/qwen2_vl.py +3 -7
  81. sglang/srt/models/roberta.py +178 -0
  82. sglang/srt/openai_api/adapter.py +50 -11
  83. sglang/srt/openai_api/protocol.py +2 -0
  84. sglang/srt/reasoning_parser.py +25 -1
  85. sglang/srt/server_args.py +31 -24
  86. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  87. sglang/srt/torch_memory_saver_adapter.py +10 -1
  88. sglang/srt/utils.py +5 -1
  89. sglang/test/runners.py +6 -13
  90. sglang/test/send_one.py +84 -28
  91. sglang/test/test_utils.py +74 -18
  92. sglang/version.py +1 -1
  93. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/METADATA +5 -6
  94. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/RECORD +97 -80
  95. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/WHEEL +1 -1
  96. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/licenses/LICENSE +0 -0
  97. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/top_level.txt +0 -0
@@ -35,6 +35,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
  import copy
  import dataclasses
  import logging
+ import threading
  from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

  import numpy as np
@@ -285,6 +286,7 @@ class MultimodalInputs:
  num_image_tokens: Optional[int] = None

  # QWen2-VL related
+ mrope_positions: Optional[torch.Tensor] = None
  mrope_position_delta: Optional[torch.Tensor] = None

  # image
@@ -310,16 +312,12 @@ class MultimodalInputs:
  assert isinstance(ret.mm_items, list)
  ret.mm_items = [item for item in ret.mm_items if item.is_valid()]

- assert len(ret.mm_items) != 0
-
- # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
- # Please note that if the `input_ids` is later used in the model forward,
- # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
- # errors in cuda kernels. See also llava.py for example.
  for item in ret.mm_items:
      item.set_pad_value()

  optional_args = [
+     "mrope_positions",
+     "mrope_position_delta",
      "im_token_id",
      "im_start_id",
      "im_end_id",
@@ -350,11 +348,6 @@ class MultimodalInputs:
  merge image inputs when requests are being merged
  """

- # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
- # Please note that if the `input_ids` is later used in the model forward,
- # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
- # errors in cuda kernels. See also llava.py for example.
-
  # args needed to be merged
  optional_args = [
      "mm_items",
@@ -364,6 +357,30 @@ class MultimodalInputs:
  self_arg = getattr(self, arg, None)
  if self_arg is not None:
      setattr(self, arg, self_arg + getattr(other, arg))
+
+ mrope_positions = self.mrope_positions
+ if mrope_positions is not None:
+     if other.mrope_positions is None:
+         self.mrope_positions = mrope_positions
+     else:
+         self.mrope_positions = torch.cat(
+             [self.mrope_positions, other.mrope_positions], dim=1
+         )
+
+ mrope_position_delta = self.mrope_position_delta
+ if mrope_position_delta is not None:
+     if other.mrope_position_delta is None:
+         self.mrope_position_delta = mrope_position_delta
+     else:
+         self.mrope_position_delta = torch.cat(
+             [self.mrope_position_delta, other.mrope_position_delta], dim=0
+         )
+
+ for key, val in other.__dict__.items():
+     if "_id" in key:
+         # set token_ids
+         if getattr(self, key, None) is None:
+             setattr(self, key, getattr(other, key, None))
  # other args would be kept intact


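For illustration only (not part of the package diff): a minimal sketch of the merge semantics added above, assuming the Qwen2-VL convention that mrope_positions has shape [3, seq_len] (temporal/height/width rows, hence concatenation along dim=1) while per-request mrope_position_delta entries stack along dim=0.

    import torch

    # Hypothetical standalone illustration of merging the mrope fields of two requests.
    a_pos = torch.arange(12).reshape(3, 4)   # request A: 4 tokens, 3 mrope rows
    b_pos = torch.arange(6).reshape(3, 2)    # request B: 2 tokens
    merged_pos = torch.cat([a_pos, b_pos], dim=1)        # shape [3, 6]

    a_delta = torch.tensor([[0]])
    b_delta = torch.tensor([[3]])
    merged_delta = torch.cat([a_delta, b_delta], dim=0)  # shape [2, 1]

    print(merged_pos.shape, merged_delta.shape)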
@@ -388,6 +405,7 @@ class Req:
  return_hidden_states: bool = False,
  eos_token_ids: Optional[Set[int]] = None,
  bootstrap_host: Optional[str] = None,
+ bootstrap_port: Optional[int] = None,
  bootstrap_room: Optional[int] = None,
  ):
  # Input and output info
@@ -523,6 +541,7 @@

  # For disaggregation
  self.bootstrap_host: str = bootstrap_host
+ self.bootstrap_port: Optional[int] = bootstrap_port
  self.bootstrap_room: Optional[int] = bootstrap_room
  self.disagg_kv_sender: Optional[BaseKVSender] = None

@@ -706,6 +725,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  # This is an optimization to reduce the overhead of the prefill check.
  batch_is_full: bool = False

+ # Events
+ launch_done: Optional[threading.Event] = None
+
  # Sampling info
  sampling_info: SamplingBatchInfo = None
  next_batch_sampling_info: SamplingBatchInfo = None
@@ -1450,7 +1472,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  if self.model_config.is_encoder_decoder:
      self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
      self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
-
  self.req_pool_indices = torch.cat(
      [self.req_pool_indices, other.req_pool_indices]
  )
@@ -1494,6 +1515,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  )
  or global_server_args_dict["attention_backend"] == "flashmla"
  or global_server_args_dict["attention_backend"] == "fa3"
+ or global_server_args_dict["attention_backend"] == "cutlass_mla"
  ):
      seq_lens_cpu = self.seq_lens.cpu()
  else:
@@ -1548,6 +1570,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  )
  ),
  extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
+ launch_done=self.launch_done,
  )

  def copy(self):
@@ -1630,6 +1653,9 @@ class ModelWorkerBatch:
  # If set, the output of the batch contains the hidden states of the run.
  capture_hidden_mode: CaptureHiddenMode = None

+ # Overlap event
+ launch_done: Optional[threading.Event] = None
+

  @triton.jit
  def write_req_to_token_pool_triton(
@@ -248,9 +248,6 @@ class Scheduler(
  if not self.is_generation:
      self.enable_overlap = False
      logger.info("Overlap scheduler is disabled for embedding models.")
- if self.model_config.is_multimodal:
-     self.enable_overlap = False
-     logger.info("Overlap scheduler is disabled for multimodal models.")

  # Launch a tensor parallel worker
  if self.enable_overlap:
@@ -578,6 +575,10 @@
  bootstrap_port=self.server_args.disaggregation_bootstrap_port,
  transfer_backend=self.transfer_backend,
  )
+
+ # Metric for pre-allocation
+ self.num_tokens_pre_allocated = 0
+
  elif self.disaggregation_mode == DisaggregationMode.PREFILL:
      # *2 for the headroom.
      buffer_size = self.max_running_requests * 2
@@ -593,7 +594,7 @@
  )
  metadata_buffers = [output_id_buffer]

- self.disagg_prefill_pending_queue = PrefillBootstrapQueue(
+ self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
      token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
      req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
      metadata_buffers=metadata_buffers,
@@ -641,6 +642,7 @@
  self.cur_batch = batch

  if batch:
+     batch.launch_done = threading.Event()
      result = self.run_batch(batch)
      self.result_queue.append((batch.copy(), result))

@@ -652,7 +654,7 @@
  forward_mode=ForwardMode.DUMMY_FIRST,
  next_batch_sampling_info=self.tp_worker.cur_sampling_info,
  )
- self.process_batch_result(tmp_batch, None)
+ self.process_batch_result(tmp_batch, None, batch.launch_done)

  if self.last_batch:
      # Process the results of the last batch
@@ -660,7 +662,10 @@
  tmp_batch.next_batch_sampling_info = (
      self.tp_worker.cur_sampling_info if batch else None
  )
- self.process_batch_result(tmp_batch, tmp_result)
+ # NOTE: we should use current launched batch's launch_done event Instead of the last batch's
+ self.process_batch_result(
+     tmp_batch, tmp_result, batch.launch_done if batch else None
+ )
  elif batch is None:
      # When the server is idle, do self-check and re-init some states
      self.check_memory()
@@ -787,6 +792,7 @@
  return_hidden_states=recv_req.return_hidden_states,
  eos_token_ids=self.model_config.hf_eos_token_id,
  bootstrap_host=recv_req.bootstrap_host,
+ bootstrap_port=recv_req.bootstrap_port,
  bootstrap_room=recv_req.bootstrap_room,
  )
  req.tokenizer = self.tokenizer
@@ -901,7 +907,7 @@
  def _add_request_to_queue(self, req: Req):
      req.queue_time_start = time.time()
      if self.disaggregation_mode == DisaggregationMode.PREFILL:
-         self.disagg_prefill_pending_queue.add(req)
+         self.disagg_prefill_bootstrap_queue.add(req)
      elif self.disaggregation_mode == DisaggregationMode.DECODE:
          self.disagg_decode_prealloc_queue.add(req)
      else:
@@ -991,8 +997,15 @@
  f"#cached-token: {adder.log_hit_tokens}, "
  f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
  f"#running-req: {running_bs}, "
- f"#queue-req: {len(self.waiting_queue)}, "
  )
+
+ if self.disaggregation_mode == DisaggregationMode.PREFILL:
+     f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
+     f += f"#queue-req: {len(self.waiting_queue)}, "
+     f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)} "
+ else:
+     f += f"#queue-req: {len(self.waiting_queue)}"
+
  logger.info(f)

  if self.enable_metrics:
@@ -1028,15 +1041,14 @@
  gap_latency / self.server_args.decode_log_interval
  )

+ msg = (
+     f"Decode batch. "
+     f"#running-req: {num_running_reqs}, "
+     f"#token: {num_used}, "
+     f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+ )
+
  if self.spec_algorithm.is_none():
-     msg = (
-         f"Decode batch. "
-         f"#running-req: {num_running_reqs}, "
-         f"#token: {num_used}, "
-         f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-         f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-         f"#queue-req: {len(self.waiting_queue)}, "
-     )
      spec_accept_length = 0
  else:
      spec_accept_length = (
@@ -1045,15 +1057,15 @@
  self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
  self.cum_spec_accept_count += self.spec_num_total_forward_ct
  self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
- msg = (
-     f"Decode batch. "
-     f"#running-req: {num_running_reqs}, "
-     f"#token: {num_used}, "
-     f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-     f"accept len: {spec_accept_length:.2f}, "
-     f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-     f"#queue-req: {len(self.waiting_queue)}, "
- )
+ msg += f"accept len: {spec_accept_length:.2f}, "
+
+ if self.disaggregation_mode == DisaggregationMode.DECODE:
+     msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+
+ msg += (
+     f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
+     f"#queue-req: {len(self.waiting_queue)}"
+ )

  logger.info(msg)
  if self.enable_metrics:
@@ -1406,14 +1418,15 @@
  self,
  batch: ScheduleBatch,
  result: Union[GenerationBatchResult, EmbeddingBatchResult],
+ launch_done: Optional[threading.Event] = None,
  ):
  if batch.forward_mode.is_decode():
-     self.process_batch_result_decode(batch, result)
+     self.process_batch_result_decode(batch, result, launch_done)
  elif batch.forward_mode.is_extend():
-     self.process_batch_result_prefill(batch, result)
+     self.process_batch_result_prefill(batch, result, launch_done)
  elif batch.forward_mode.is_idle():
      if self.enable_overlap:
-         self.tp_worker.resolve_batch_result(result.bid)
+         self.tp_worker.resolve_last_batch_result(launch_done)
          if batch.next_batch_sampling_info:
              batch.next_batch_sampling_info.update_regex_vocab_mask()
              self.current_stream.synchronize()
@@ -1,5 +1,6 @@
  from __future__ import annotations

+ import threading
  from typing import TYPE_CHECKING, List, Optional, Tuple, Union

  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -11,6 +12,7 @@ if TYPE_CHECKING:
  EmbeddingBatchResult,
  GenerationBatchResult,
  ScheduleBatch,
+ Scheduler,
  )


@@ -21,9 +23,10 @@ class SchedulerOutputProcessorMixin:
  """

  def process_batch_result_prefill(
-     self,
+     self: Scheduler,
      batch: ScheduleBatch,
      result: Union[GenerationBatchResult, EmbeddingBatchResult],
+     launch_done: Optional[threading.Event] = None,
  ):
      skip_stream_req = None

@@ -43,7 +46,11 @@
  )

  if self.enable_overlap:
-     logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+     logits_output, next_token_ids = (
+         self.tp_worker.resolve_last_batch_result(
+             launch_done,
+         )
+     )
  else:
      # Move next_token_ids and logprobs to cpu
      next_token_ids = next_token_ids.tolist()
@@ -175,9 +182,10 @@
  self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

  def process_batch_result_decode(
-     self,
+     self: Scheduler,
      batch: ScheduleBatch,
      result: GenerationBatchResult,
+     launch_done: Optional[threading.Event] = None,
  ):
      logits_output, next_token_ids, bid = (
          result.logits_output,
@@ -187,7 +195,9 @@
  self.num_generated_tokens += len(batch.reqs)

  if self.enable_overlap:
-     logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+     logits_output, next_token_ids = self.tp_worker.resolve_last_batch_result(
+         launch_done
+     )
      next_token_logprobs = logits_output.next_token_logprobs
  elif batch.spec_algorithm.is_none():
      # spec decoding handles output logprobs inside verify process.
@@ -271,7 +281,7 @@
  self.log_decode_stats()

  def add_input_logprob_return_values(
-     self,
+     self: Scheduler,
      i: int,
      req: Req,
      output: LogitsProcessorOutput,
@@ -405,7 +415,7 @@
  assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len

  def add_logprob_return_values(
-     self,
+     self: Scheduler,
      i: int,
      req: Req,
      pt: int,
@@ -436,7 +446,10 @@
  return num_input_logprobs

  def stream_output(
-     self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+     self: Scheduler,
+     reqs: List[Req],
+     return_logprob: bool,
+     skip_req: Optional[Req] = None,
  ):
      """Stream the output to detokenizer."""
      if self.is_generation:
@@ -445,7 +458,10 @@
  self.stream_output_embedding(reqs)

  def stream_output_generation(
-     self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+     self: Scheduler,
+     reqs: List[Req],
+     return_logprob: bool,
+     skip_req: Optional[Req] = None,
  ):
      rids = []
      finished_reasons: List[BaseFinishReason] = []
@@ -593,7 +609,7 @@
  )
  )

- def stream_output_embedding(self, reqs: List[Req]):
+ def stream_output_embedding(self: Scheduler, reqs: List[Req]):
      rids = []
      finished_reasons: List[BaseFinishReason] = []

@@ -419,7 +419,10 @@ class TokenizerManager:
  input_ids = self.tokenizer.encode(input_text)

  image_inputs: Dict = await self.mm_processor.process_mm_data_async(
-     obj.image_data, input_text or input_ids, obj, self.max_req_input_len
+     image_data=obj.image_data,
+     input_text=input_text or input_ids,
+     request_obj=obj,
+     max_req_input_len=self.max_req_input_len,
  )
  if image_inputs and "input_ids" in image_inputs:
      input_ids = image_inputs["input_ids"]
@@ -495,6 +498,7 @@
  token_ids_logprob,
  obj.stream,
  bootstrap_host=obj.bootstrap_host,
+ bootstrap_port=obj.bootstrap_port,
  bootstrap_room=obj.bootstrap_room,
  lora_path=obj.lora_path,
  input_embeds=input_embeds,
@@ -170,13 +170,13 @@ class TpModelWorker:
  def forward_batch_generation(
      self,
      model_worker_batch: ModelWorkerBatch,
-     launch_done: Optional[threading.Event] = None,
      skip_sample: bool = False,
  ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
      forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
      logits_output = self.model_runner.forward(forward_batch)
-     if launch_done:
-         launch_done.set()
+
+     if model_worker_batch.launch_done is not None:
+         model_worker_batch.launch_done.set()

      if skip_sample:
          next_token_ids = None
@@ -132,7 +132,6 @@ class TpModelWorkerClient:
  batch_pt += 1

  # Create event
- self.launch_done = threading.Event()
  copy_done = torch.get_device_module(self.device).Event()

  # Resolve future tokens in the input
@@ -141,7 +140,7 @@

  # Run forward
  logits_output, next_token_ids = self.worker.forward_batch_generation(
-     model_worker_batch, self.launch_done
+     model_worker_batch
  )

  # Update the future token ids map
@@ -168,10 +167,16 @@

  self.output_queue.put((copy_done, logits_output, next_token_ids))

- def resolve_batch_result(self, bid: int):
+ def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
+     """
+     This function is called to resolve the last batch result and
+     wait for the current batch to be launched. Used in overlap mode.
+     """
      copy_done, logits_output, next_token_ids = self.output_queue.get()
+
+     if launch_done is not None:
+         launch_done.wait()
      copy_done.synchronize()
-     self.launch_done.wait()

      if logits_output.next_token_logprobs is not None:
          logits_output.next_token_logprobs = (
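For illustration only (not part of the package diff): the change above replaces the worker-owned self.launch_done with a per-batch threading.Event that the scheduler attaches to the batch, the worker sets once the forward pass has been issued, and resolve_last_batch_result waits on before reading results. A minimal standalone sketch of that handshake, with hypothetical names:

    import queue
    import threading

    out_q = queue.Queue()

    def worker(batch_id, launch_done):
        # Pretend to launch the forward pass, then signal that the launch finished.
        launch_done.set()
        out_q.put(f"tokens-for-batch-{batch_id}")

    launch_done = threading.Event()  # created per batch, like batch.launch_done
    threading.Thread(target=worker, args=(0, launch_done)).start()

    launch_done.wait()   # roughly what resolve_last_batch_result(launch_done) waits on
    print(out_q.get())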
@@ -34,6 +34,8 @@ from typing import List, Optional, Tuple, Union
  import numpy as np
  import psutil
  import torch
+ import triton
+ import triton.language as tl

  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.utils import debug_timing, get_compiler_backend
@@ -405,6 +407,72 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
  dst_2[loc] = src_2.to(dtype).view(store_dtype)


+ @triton.jit
+ def set_mla_kv_buffer_kernel(
+     kv_buffer_ptr,
+     cache_k_nope_ptr,
+     cache_k_rope_ptr,
+     loc_ptr,
+     buffer_stride: tl.constexpr,
+     nope_stride: tl.constexpr,
+     rope_stride: tl.constexpr,
+     nope_dim: tl.constexpr,
+     rope_dim: tl.constexpr,
+     BLOCK: tl.constexpr,
+ ):
+     pid_loc = tl.program_id(0)
+     pid_blk = tl.program_id(1)
+
+     base = pid_blk * BLOCK
+     offs = base + tl.arange(0, BLOCK)
+     total_dim = nope_dim + rope_dim
+     mask = offs < total_dim
+
+     loc = tl.load(loc_ptr + pid_loc)
+     dst_ptr = kv_buffer_ptr + loc * buffer_stride + offs
+
+     if base + BLOCK <= nope_dim:
+         src = tl.load(
+             cache_k_nope_ptr + pid_loc * nope_stride + offs,
+             mask=mask,
+         )
+     else:
+         offs_rope = offs - nope_dim
+         src = tl.load(
+             cache_k_rope_ptr + pid_loc * rope_stride + offs_rope,
+             mask=mask,
+         )
+
+     tl.store(dst_ptr, src, mask=mask)
+
+
+ def set_mla_kv_buffer_triton(
+     kv_buffer: torch.Tensor,
+     loc: torch.Tensor,
+     cache_k_nope: torch.Tensor,
+     cache_k_rope: torch.Tensor,
+ ):
+     nope_dim = cache_k_nope.shape[-1]
+     rope_dim = cache_k_rope.shape[-1]
+     total_dim = nope_dim + rope_dim
+     BLOCK = 128
+     n_loc = loc.numel()
+     grid = (n_loc, triton.cdiv(total_dim, BLOCK))
+
+     set_mla_kv_buffer_kernel[grid](
+         kv_buffer,
+         cache_k_nope,
+         cache_k_rope,
+         loc,
+         kv_buffer.stride(0),
+         cache_k_nope.stride(0),
+         cache_k_rope.stride(0),
+         nope_dim,
+         rope_dim,
+         BLOCK=BLOCK,
+     )
+
+
  class MLATokenToKVPool(KVCache):
      def __init__(
          self,
@@ -504,6 +572,25 @@ class MLATokenToKVPool(KVCache):
  else:
      self.kv_buffer[layer_id][loc] = cache_k

+ def set_mla_kv_buffer(
+     self,
+     layer: RadixAttention,
+     loc: torch.Tensor,
+     cache_k_nope: torch.Tensor,
+     cache_k_rope: torch.Tensor,
+ ):
+     layer_id = layer.layer_id
+     if cache_k_nope.dtype != self.dtype:
+         cache_k_nope = cache_k_nope.to(self.dtype)
+         cache_k_rope = cache_k_rope.to(self.dtype)
+     if self.store_dtype != self.dtype:
+         cache_k_nope = cache_k_nope.view(self.store_dtype)
+         cache_k_rope = cache_k_rope.view(self.store_dtype)
+
+     set_mla_kv_buffer_triton(
+         self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
+     )
+
  def get_flat_data(self, indices):
      # prepare a large chunk of contiguous data for efficient transfer
      return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
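For orientation only (not part of the package diff): set_mla_kv_buffer_triton scatters each token's concatenated [nope | rope] key row into the KV-buffer slot given by loc. A hedged usage sketch with made-up sizes (needs a CUDA device with Triton installed; the import path follows the file this hunk modifies):

    import torch
    from sglang.srt.mem_cache.memory_pool import set_mla_kv_buffer_triton

    num_slots, nope_dim, rope_dim, n_tokens = 1024, 512, 64, 8
    kv_buffer = torch.zeros(num_slots, nope_dim + rope_dim, dtype=torch.bfloat16, device="cuda")
    loc = torch.randperm(num_slots, device="cuda")[:n_tokens]        # distinct target slots
    cache_k_nope = torch.randn(n_tokens, nope_dim, dtype=torch.bfloat16, device="cuda")
    cache_k_rope = torch.randn(n_tokens, rope_dim, dtype=torch.bfloat16, device="cuda")

    # Write each token's nope and rope halves into one contiguous buffer row.
    set_mla_kv_buffer_triton(kv_buffer, loc, cache_k_nope, cache_k_rope)
    assert torch.equal(kv_buffer[loc, :nope_dim], cache_k_nope)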
@@ -134,7 +134,8 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
  )

  gpu_mem = get_device_memory_capacity()
- if gpu_mem is not None and gpu_mem > 81920:
+ # Batch size of each rank will not become so large when DP is on
+ if gpu_mem is not None and gpu_mem > 81920 and server_args.dp_size == 1:
      capture_bs += list(range(160, 257, 8))

  if max(capture_bs) > model_runner.req_to_token_pool.size:
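Not part of the package diff: a tiny standalone sketch of the gating above, assuming gpu_mem is reported in MiB; the larger CUDA-graph capture sizes (160 through 256 in steps of 8) are only appended on >80 GB GPUs when data parallelism is off, since each DP rank then handles only a slice of the traffic.

    def extra_capture_batch_sizes(gpu_mem_mib, dp_size):
        # Hypothetical helper mirroring the condition in the hunk above.
        if gpu_mem_mib is not None and gpu_mem_mib > 81920 and dp_size == 1:
            return list(range(160, 257, 8))
        return []

    print(extra_capture_batch_sizes(98304, 1))  # large GPU, no DP -> [160, 168, ..., 256]
    print(extra_capture_batch_sizes(98304, 2))  # DP on -> []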
@@ -278,9 +279,9 @@ class CudaGraphRunner:
  f"Capture cuda graph failed: {e}\n"
  "Possible solutions:\n"
  "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
- "2. set --cuda-graph-max-bs to a smaller value (e.g., 32)\n"
+ "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
  "3. disable torch compile by not using --enable-torch-compile\n"
- "4. disable cuda graph by --disable-cuda-graph\n"
+ "4. disable cuda graph by --disable-cuda-graph. (Not recommonded. Huge perf loss)\n"
  "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
  )