sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (61)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/check_env.py +3 -3
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/kimi_vl.py +38 -0
  5. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  6. sglang/srt/configs/model_config.py +15 -0
  7. sglang/srt/conversation.py +122 -1
  8. sglang/srt/entrypoints/engine.py +44 -22
  9. sglang/srt/function_call_parser.py +97 -0
  10. sglang/srt/hf_transformers_utils.py +2 -0
  11. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
  12. sglang/srt/layers/attention/flashinfer_backend.py +107 -82
  13. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
  14. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  15. sglang/srt/layers/dp_attention.py +5 -2
  16. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -6
  22. sglang/srt/layers/quantization/__init__.py +2 -2
  23. sglang/srt/layers/quantization/deep_gemm.py +1 -1
  24. sglang/srt/layers/utils.py +35 -0
  25. sglang/srt/lora/layers.py +35 -9
  26. sglang/srt/lora/lora_manager.py +84 -35
  27. sglang/srt/managers/data_parallel_controller.py +52 -34
  28. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  29. sglang/srt/managers/schedule_batch.py +25 -15
  30. sglang/srt/managers/scheduler.py +263 -59
  31. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
  32. sglang/srt/managers/tp_worker.py +51 -16
  33. sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
  34. sglang/srt/mem_cache/memory_pool.py +70 -36
  35. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  36. sglang/srt/model_executor/forward_batch_info.py +31 -1
  37. sglang/srt/model_executor/model_runner.py +115 -57
  38. sglang/srt/models/deepseek_nextn.py +1 -257
  39. sglang/srt/models/deepseek_v2.py +78 -18
  40. sglang/srt/models/kimi_vl.py +308 -0
  41. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  42. sglang/srt/models/llama.py +92 -30
  43. sglang/srt/models/llama4.py +2 -1
  44. sglang/srt/models/llama_eagle.py +4 -1
  45. sglang/srt/models/llama_eagle3.py +4 -1
  46. sglang/srt/models/qwen2_moe.py +8 -3
  47. sglang/srt/models/qwen2_vl.py +0 -12
  48. sglang/srt/models/qwen3_moe.py +8 -3
  49. sglang/srt/openai_api/adapter.py +34 -22
  50. sglang/srt/openai_api/protocol.py +11 -1
  51. sglang/srt/server_args.py +67 -22
  52. sglang/srt/speculative/eagle_worker.py +3 -2
  53. sglang/srt/utils.py +88 -9
  54. sglang/test/runners.py +4 -0
  55. sglang/test/test_utils.py +29 -0
  56. sglang/version.py +1 -1
  57. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
  58. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +61 -51
  59. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
  60. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
  61. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py

@@ -15,11 +15,12 @@
 
 import logging
 import threading
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -31,7 +32,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
@@ -47,6 +48,7 @@ class TpModelWorker:
         server_args: ServerArgs,
         gpu_id: int,
         tp_rank: int,
+        pp_rank: int,
         dp_rank: Optional[int],
         nccl_port: int,
         is_draft_worker: bool = False,
@@ -54,7 +56,9 @@
         token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
         # Parse args
+        self.tp_size = server_args.tp_size
         self.tp_rank = tp_rank
+        self.pp_rank = pp_rank
 
         # Init model and tokenizer
         self.model_config = ModelConfig(
@@ -71,13 +75,17 @@
             enable_multimodal=server_args.enable_multimodal,
             dtype=server_args.dtype,
             quantization=server_args.quantization,
+            is_draft_model=is_draft_worker,
         )
+
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
+            pp_rank=pp_rank,
+            pp_size=server_args.pp_size,
             nccl_port=nccl_port,
             server_args=server_args,
             is_draft_worker=is_draft_worker,
@@ -104,6 +112,10 @@
         )
         self.device = self.model_runner.device
 
+        # Init nccl groups
+        self.pp_group = get_pp_group()
+        self.world_group = get_world_group()
+
         # Profile number of tokens
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
         self.max_prefill_tokens = server_args.max_prefill_tokens
@@ -129,8 +141,9 @@
         # Sync random seed across TP workers
         self.random_seed = broadcast_pyobj(
             [server_args.random_seed],
-            self.tp_rank,
-            self.model_runner.tp_group.cpu_group,
+            self.tp_size * self.pp_rank + tp_rank,
+            self.world_group.cpu_group,
+            src=self.world_group.ranks[0],
         )[0]
         set_random_seed(self.random_seed)
 
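
The seed broadcast above now spans the whole world group instead of a single TP group, with each worker identified by its flattened world rank. A small worked example of the new rank arithmetic (the concrete sizes are illustrative; sglang reads them from server_args):

# Worked example of the flattened rank used in the broadcast_pyobj call above.
tp_size, pp_rank, tp_rank = 4, 1, 2
global_rank = tp_size * pp_rank + tp_rank  # rank within the world group
assert global_rank == 6  # the worker at (pp_rank=1, tp_rank=2)
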
@@ -155,11 +168,14 @@ class TpModelWorker:
155
168
  def get_pad_input_ids_func(self):
156
169
  return getattr(self.model_runner.model, "pad_input_ids", None)
157
170
 
158
- def get_tp_cpu_group(self):
159
- return self.model_runner.tp_group.cpu_group
171
+ def get_tp_group(self):
172
+ return self.model_runner.tp_group
173
+
174
+ def get_attention_tp_group(self):
175
+ return self.model_runner.attention_tp_group
160
176
 
161
177
  def get_attention_tp_cpu_group(self):
162
- return self.model_runner.attention_tp_group.cpu_group
178
+ return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
163
179
 
164
180
  def get_memory_pool(self):
165
181
  return (
@@ -171,19 +187,38 @@ class TpModelWorker:
171
187
  self,
172
188
  model_worker_batch: ModelWorkerBatch,
173
189
  skip_sample: bool = False,
174
- ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
190
+ ) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor]]:
175
191
  forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
176
- logits_output = self.model_runner.forward(forward_batch)
177
192
 
178
- if model_worker_batch.launch_done is not None:
179
- model_worker_batch.launch_done.set()
193
+ pp_proxy_tensors = None
194
+ if not self.pp_group.is_first_rank:
195
+ pp_proxy_tensors = PPProxyTensors(
196
+ self.pp_group.recv_tensor_dict(
197
+ all_gather_group=self.get_attention_tp_group()
198
+ )
199
+ )
200
+
201
+ if self.pp_group.is_last_rank:
202
+ logits_output = self.model_runner.forward(
203
+ forward_batch, pp_proxy_tensors=pp_proxy_tensors
204
+ )
205
+ if model_worker_batch.launch_done is not None:
206
+ model_worker_batch.launch_done.set()
180
207
 
181
- if skip_sample:
182
- next_token_ids = None
183
- else:
184
- next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
208
+ if skip_sample:
209
+ next_token_ids = None
210
+ else:
211
+ next_token_ids = self.model_runner.sample(
212
+ logits_output, model_worker_batch
213
+ )
185
214
 
186
- return logits_output, next_token_ids
215
+ return logits_output, next_token_ids
216
+ else:
217
+ pp_proxy_tensors = self.model_runner.forward(
218
+ forward_batch,
219
+ pp_proxy_tensors=pp_proxy_tensors,
220
+ )
221
+ return pp_proxy_tensors.tensors, None
187
222
 
188
223
  def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
189
224
  forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
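
The new forward_batch_generation branches on pipeline position: every rank except the first consumes PPProxyTensors from the previous stage, the last rank computes logits and samples, and intermediate ranks hand their hidden states back for the next stage. Below is a minimal, self-contained sketch of that control flow; `Stage` and the dict hand-off are illustrative stand-ins for sglang's pp_group send/recv of PPProxyTensors, not its real API.

# Hedged sketch of the rank-dependent control flow introduced above.
from dataclasses import dataclass
from typing import Dict, Optional

import torch


@dataclass
class Stage:
    rank: int
    size: int

    @property
    def is_first_rank(self) -> bool:
        return self.rank == 0

    @property
    def is_last_rank(self) -> bool:
        return self.rank == self.size - 1


def forward_stage(stage: Stage, proxy: Optional[Dict[str, torch.Tensor]]):
    if stage.is_first_rank:
        # The first stage starts from token embeddings.
        proxy = {"hidden_states": torch.zeros(4, 8), "residual": torch.zeros(4, 8)}
    hidden = {k: v + 1 for k, v in proxy.items()}  # stand-in for the model forward
    if stage.is_last_rank:
        return hidden["hidden_states"].sum(dim=-1)  # stand-in for logits + sampling
    return hidden  # handed to the next stage, like `pp_proxy_tensors.tensors`


out = None
for rank in range(2):  # drive both stages of a pp_size=2 pipeline in sequence
    out = forward_stage(Stage(rank, 2), out)
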
sglang/srt/managers/tp_worker_overlap_thread.py

@@ -56,11 +56,14 @@ class TpModelWorkerClient:
         server_args: ServerArgs,
         gpu_id: int,
         tp_rank: int,
+        pp_rank: int,
         dp_rank: Optional[int],
         nccl_port: int,
     ):
         # Load the model
-        self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
+        self.worker = TpModelWorker(
+            server_args, gpu_id, tp_rank, pp_rank, dp_rank, nccl_port
+        )
         self.max_running_requests = self.worker.max_running_requests
         self.device = self.worker.device
         self.gpu_id = gpu_id
@@ -91,8 +94,11 @@
     def get_pad_input_ids_func(self):
         return self.worker.get_pad_input_ids_func()
 
-    def get_tp_cpu_group(self):
-        return self.worker.get_tp_cpu_group()
+    def get_tp_group(self):
+        return self.worker.get_tp_group()
+
+    def get_attention_tp_group(self):
+        return self.worker.get_attention_tp_group()
 
     def get_attention_tp_cpu_group(self):
         return self.worker.get_attention_tp_cpu_group()
sglang/srt/mem_cache/memory_pool.py

@@ -214,6 +214,8 @@ class MHATokenToKVPool(KVCache):
         layer_num: int,
         device: str,
         enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
     ):
         self.size = size
         self.page_size = page_size
@@ -232,6 +234,8 @@ class MHATokenToKVPool(KVCache):
         self.head_dim = head_dim
         self.layer_num = layer_num
         self._create_buffers()
+        self.start_layer = start_layer or 0
+        self.end_layer = end_layer or layer_num - 1
 
         self.layer_transfer_counter = None
         self.capture_mode = False
@@ -281,6 +285,8 @@ class MHATokenToKVPool(KVCache):
 
     # for disagg
     def get_contiguous_buf_infos(self):
+        # layer_num x [seq_len, head_num, head_dim]
+        # layer_num x [page_num, page_size, head_num, head_dim]
         kv_data_ptrs = [
             self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
         ] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
@@ -320,24 +326,24 @@ class MHATokenToKVPool(KVCache):
         # transfer prepared data from host to device
         flat_data = flat_data.to(device=self.device, non_blocking=False)
         k_data, v_data = flat_data[0], flat_data[1]
-        self.k_buffer[layer_id][indices] = k_data
-        self.v_buffer[layer_id][indices] = v_data
+        self.k_buffer[layer_id - self.start_layer][indices] = k_data
+        self.v_buffer[layer_id - self.start_layer][indices] = v_data
 
     def get_key_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id)
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
 
         if self.store_dtype != self.dtype:
-            return self.k_buffer[layer_id].view(self.dtype)
-        return self.k_buffer[layer_id]
+            return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.k_buffer[layer_id - self.start_layer]
 
     def get_value_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id)
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
 
         if self.store_dtype != self.dtype:
-            return self.v_buffer[layer_id].view(self.dtype)
-        return self.v_buffer[layer_id]
+            return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.v_buffer[layer_id - self.start_layer]
 
     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
@@ -369,12 +375,12 @@ class MHATokenToKVPool(KVCache):
             current_stream = self.device_module.current_stream()
             self.alt_stream.wait_stream(current_stream)
             with self.device_module.stream(self.alt_stream):
-                self.k_buffer[layer_id][loc] = cache_k
-                self.v_buffer[layer_id][loc] = cache_v
+                self.k_buffer[layer_id - self.start_layer][loc] = cache_k
+                self.v_buffer[layer_id - self.start_layer][loc] = cache_v
             current_stream.wait_stream(self.alt_stream)
         else:
-            self.k_buffer[layer_id][loc] = cache_k
-            self.v_buffer[layer_id][loc] = cache_v
+            self.k_buffer[layer_id - self.start_layer][loc] = cache_k
+            self.v_buffer[layer_id - self.start_layer][loc] = cache_v
 
 
 @torch.compile
@@ -484,6 +490,8 @@ class MLATokenToKVPool(KVCache):
         layer_num: int,
         device: str,
         enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
     ):
         self.size = size
         self.page_size = page_size
@@ -497,6 +505,8 @@ class MLATokenToKVPool(KVCache):
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
         self.layer_num = layer_num
+        self.start_layer = start_layer or 0
+        self.end_layer = end_layer or layer_num - 1
 
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
@@ -540,19 +550,21 @@
 
     def get_key_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id)
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
 
         if self.store_dtype != self.dtype:
-            return self.kv_buffer[layer_id].view(self.dtype)
-        return self.kv_buffer[layer_id]
+            return self.kv_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.kv_buffer[layer_id - self.start_layer]
 
     def get_value_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id)
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
 
         if self.store_dtype != self.dtype:
-            return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
-        return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
+            return self.kv_buffer[layer_id - self.start_layer][
+                ..., : self.kv_lora_rank
+            ].view(self.dtype)
+        return self.kv_buffer[layer_id - self.start_layer][..., : self.kv_lora_rank]
 
     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
@@ -568,9 +580,11 @@
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
         if self.store_dtype != self.dtype:
-            self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+            self.kv_buffer[layer_id - self.start_layer][loc] = cache_k.view(
+                self.store_dtype
+            )
         else:
-            self.kv_buffer[layer_id][loc] = cache_k
+            self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
 
     def set_mla_kv_buffer(
         self,
@@ -605,7 +619,7 @@
     def transfer_per_layer(self, indices, flat_data, layer_id):
         # transfer prepared data from host to device
        flat_data = flat_data.to(device=self.device, non_blocking=False)
-        self.kv_buffer[layer_id][indices] = flat_data
+        self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
 
 
 class DoubleSparseTokenToKVPool(KVCache):
@@ -620,6 +634,8 @@ class DoubleSparseTokenToKVPool(KVCache):
         device: str,
         heavy_channel_num: int,
         enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
     ):
         self.size = size
         self.page_size = page_size
@@ -657,17 +673,23 @@
             for _ in range(layer_num)
         ]
 
+        self.start_layer = start_layer or 0
+        self.end_layer = end_layer or layer_num - 1
+
     def get_key_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id]
+        return self.k_buffer[layer_id - self.start_layer]
 
     def get_value_buffer(self, layer_id: int):
-        return self.v_buffer[layer_id]
+        return self.v_buffer[layer_id - self.start_layer]
 
     def get_label_buffer(self, layer_id: int):
-        return self.label_buffer[layer_id]
+        return self.label_buffer[layer_id - self.start_layer]
 
     def get_kv_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+        return (
+            self.k_buffer[layer_id - self.start_layer],
+            self.v_buffer[layer_id - self.start_layer],
+        )
 
     def set_kv_buffer(
         self,
@@ -679,9 +701,9 @@
     ):
         # NOTE(Andy): ignore the dtype check
         layer_id = layer.layer_id
-        self.k_buffer[layer_id][loc] = cache_k
-        self.v_buffer[layer_id][loc] = cache_v
-        self.label_buffer[layer_id][loc] = cache_label
+        self.k_buffer[layer_id - self.start_layer][loc] = cache_k
+        self.v_buffer[layer_id - self.start_layer][loc] = cache_v
+        self.label_buffer[layer_id - self.start_layer][loc] = cache_label
 
     def get_flat_data(self, indices):
         pass
@@ -930,7 +952,7 @@ class MHATokenToKVPoolHost(HostKVCache):
         return self.kv_buffer[:, :, indices]
 
     def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[:, layer_id, indices]
+        return self.kv_buffer[:, layer_id - self.start_layer, indices]
 
     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, :, indices] = flat_data
@@ -955,12 +977,20 @@
         for i in range(len(device_indices_cpu)):
             h_index = host_indices[i * self.page_size]
             d_index = device_indices_cpu[i]
-            device_pool.k_buffer[layer_id][d_index : d_index + self.page_size].copy_(
-                self.kv_buffer[0, layer_id, h_index : h_index + self.page_size],
+            device_pool.k_buffer[layer_id - self.start_layer][
+                d_index : d_index + self.page_size
+            ].copy_(
+                self.kv_buffer[
+                    0, layer_id - self.start_layer, h_index : h_index + self.page_size
+                ],
                 non_blocking=True,
             )
-            device_pool.v_buffer[layer_id][d_index : d_index + self.page_size].copy_(
-                self.kv_buffer[1, layer_id, h_index : h_index + self.page_size],
+            device_pool.v_buffer[layer_id - self.start_layer][
+                d_index : d_index + self.page_size
+            ].copy_(
+                self.kv_buffer[
+                    1, layer_id - self.start_layer, h_index : h_index + self.page_size
+                ],
                 non_blocking=True,
            )
 
@@ -1015,7 +1045,7 @@ class MLATokenToKVPoolHost(HostKVCache):
         return self.kv_buffer[:, indices]
 
     def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[layer_id, indices]
+        return self.kv_buffer[layer_id - self.start_layer, indices]
 
     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, indices] = flat_data
@@ -1036,7 +1066,11 @@
         for i in range(len(device_indices_cpu)):
             h_index = host_indices[i * self.page_size]
             d_index = device_indices_cpu[i]
-            device_pool.kv_buffer[layer_id][d_index : d_index + self.page_size].copy_(
-                self.kv_buffer[layer_id, h_index : h_index + self.page_size],
+            device_pool.kv_buffer[layer_id - self.start_layer][
+                d_index : d_index + self.page_size
+            ].copy_(
+                self.kv_buffer[
+                    layer_id - self.start_layer, h_index : h_index + self.page_size
+                ],
                 non_blocking=True,
             )
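
Every pool above now subtracts self.start_layer before indexing, because under pipeline parallelism each rank allocates KV buffers only for its own contiguous slice of layers while attention layers keep their global layer_id. A self-contained sketch of the convention; the layer counts and shapes here are made up for illustration:

# Local-layer indexing: global layer_id maps into a per-rank buffer list.
import torch

start_layer, end_layer = 16, 31  # this PP rank owns global layers 16..31
local_layer_num = end_layer - start_layer + 1
k_buffer = [torch.zeros(2, 1, 4) for _ in range(local_layer_num)]  # one per local layer

def get_key_buffer(layer_id: int) -> torch.Tensor:
    # layer_id stays global (what the attention layer knows); the buffer
    # list is local to this rank, hence the `- start_layer` shift.
    return k_buffer[layer_id - start_layer]

assert get_key_buffer(20) is k_buffer[4]
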
sglang/srt/model_executor/cuda_graph_runner.py

@@ -16,6 +16,7 @@
 from __future__ import annotations
 
 import bisect
+import inspect
 import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable
@@ -33,12 +34,14 @@ from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
     ForwardMode,
+    PPProxyTensors,
 )
 from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.utils import (
     get_available_gpu_memory,
     get_device_memory_capacity,
     is_hip,
+    rank0_log,
 )
 
 if TYPE_CHECKING:
@@ -135,7 +138,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
 
     gpu_mem = get_device_memory_capacity()
     # Batch size of each rank will not become so large when DP is on
-    if gpu_mem is not None and gpu_mem > 81920 and server_args.dp_size == 1:
+    if gpu_mem is not None and gpu_mem > 96 * 1024:
         capture_bs += list(range(160, 257, 8))
 
     if max(capture_bs) > model_runner.req_to_token_pool.size:
@@ -188,10 +191,11 @@ class CudaGraphRunner:
         self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
+        self.pp_size = model_runner.server_args.pp_size
 
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
-
+        rank0_log(f"Capture cuda graph bs {self.capture_bs}")
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1
@@ -220,6 +224,9 @@
         if self.enable_torch_compile:
             set_torch_compile_config()
 
+        if self.model_runner.server_args.lora_paths is not None:
+            self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
+
         # Graph inputs
         with torch.device("cuda"):
             self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
@@ -231,6 +238,19 @@
             self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
 
+            # pipeline parallelism
+            if self.pp_size > 1:
+                self.pp_proxy_tensors = {
+                    "hidden_states": torch.zeros(
+                        (self.max_bs, self.model_runner.model_config.hidden_size),
+                        dtype=torch.bfloat16,
+                    ),
+                    "residual": torch.zeros(
+                        (self.max_bs, self.model_runner.model_config.hidden_size),
+                        dtype=torch.bfloat16,
+                    ),
+                }
+
             # Speculative_inference
             if (
                 model_runner.spec_algorithm.is_eagle3()
@@ -381,6 +401,12 @@
             encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]
 
+        # pipeline parallelism
+        if self.pp_size > 1:
+            pp_proxy_tensors = PPProxyTensors(
+                {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
+            )
+
         if self.enable_dp_attention or self.enable_sp_layernorm:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
@@ -403,6 +429,13 @@
         self.capture_hidden_mode = (
             spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
         )
+        if self.model_runner.server_args.lora_paths is not None:
+            # Currently, if the lora_path in `lora_paths` is None, the lora backend will use a
+            # different logic to handle lora, so we need to set `lora_paths` to a list of non-None
+            # values if lora is enabled.
+            lora_paths = [next(iter(self.model_runner.server_args.lora_paths))] * bs
+        else:
+            lora_paths = None
 
         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
@@ -424,8 +457,12 @@
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=self.capture_hidden_mode,
+            lora_paths=lora_paths,
         )
 
+        if lora_paths is not None:
+            self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs,
@@ -442,8 +479,20 @@
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
 
-            logits_output = forward(input_ids, forward_batch.positions, forward_batch)
-            return logits_output.next_token_logits, logits_output.hidden_states
+            kwargs = {}
+            if (
+                self.pp_size > 1
+                and "pp_proxy_tensors" in inspect.signature(forward).parameters
+            ):
+                kwargs["pp_proxy_tensors"] = pp_proxy_tensors
+
+            logits_output_or_pp_proxy_tensors = forward(
+                input_ids,
+                forward_batch.positions,
+                forward_batch,
+                **kwargs,
+            )
+            return logits_output_or_pp_proxy_tensors
 
         for _ in range(2):
             torch.cuda.synchronize()
@@ -476,7 +525,11 @@
             self.capture_hidden_mode = hidden_mode_from_spec_info
             self.capture()
 
-    def replay_prepare(self, forward_batch: ForwardBatch):
+    def replay_prepare(
+        self,
+        forward_batch: ForwardBatch,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ):
         self.recapture_if_needed(forward_batch)
 
         raw_bs = forward_batch.batch_size
@@ -505,6 +558,11 @@
             self.seq_lens_cpu.fill_(1)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
 
+        if pp_proxy_tensors:
+            for key in self.pp_proxy_tensors.keys():
+                dim = pp_proxy_tensors[key].shape[0]
+                self.pp_proxy_tensors[key][:dim].copy_(pp_proxy_tensors[key])
+
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
@@ -533,10 +591,13 @@
         self.bs = bs
 
     def replay(
-        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
-    ) -> LogitsProcessorOutput:
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
         if not skip_attn_backend_init:
-            self.replay_prepare(forward_batch)
+            self.replay_prepare(forward_batch, pp_proxy_tensors)
         else:
             # In speculative decoding, these two fields are still needed.
             self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
@@ -544,17 +605,19 @@
 
         # Replay
         self.graphs[self.bs].replay()
-        next_token_logits, hidden_states = self.output_buffers[self.bs]
-
-        logits_output = LogitsProcessorOutput(
-            next_token_logits=next_token_logits[: self.raw_num_token],
-            hidden_states=(
-                hidden_states[: self.raw_num_token]
-                if hidden_states is not None
-                else None
-            ),
-        )
-        return logits_output
+        output = self.output_buffers[self.bs]
+        if isinstance(output, LogitsProcessorOutput):
+            return LogitsProcessorOutput(
+                next_token_logits=output.next_token_logits[: self.raw_num_token],
+                hidden_states=(
+                    output.hidden_states[: self.raw_num_token]
+                    if output.hidden_states is not None
+                    else None
+                ),
+            )
+        else:
+            assert isinstance(output, PPProxyTensors)
+            return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})
 
     def get_spec_info(self, num_tokens: int):
         spec_info = None
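
Two behaviors in this file hinge on types and signatures: capture only passes pp_proxy_tensors to model forwards that declare the parameter (probed with inspect.signature), and replay now returns either a LogitsProcessorOutput or PPProxyTensors depending on the rank. A self-contained illustration of the signature-gating pattern, with made-up forward functions:

# Optional kwarg is forwarded only to callables that declare it, so
# non-pipeline models keep working unchanged.
import inspect

def forward_with_pp(input_ids, positions, batch, pp_proxy_tensors=None):
    return ("pp", pp_proxy_tensors)

def forward_without_pp(input_ids, positions, batch):
    return ("no_pp",)

def call(forward, proxy):
    kwargs = {}
    if "pp_proxy_tensors" in inspect.signature(forward).parameters:
        kwargs["pp_proxy_tensors"] = proxy
    return forward(None, None, None, **kwargs)

assert call(forward_with_pp, "proxy") == ("pp", "proxy")
assert call(forward_without_pp, "proxy") == ("no_pp",)
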
sglang/srt/model_executor/forward_batch_info.py

@@ -31,7 +31,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 import torch
 import triton
@@ -585,6 +585,36 @@
             self.prepare_chunked_kv_indices(device)
 
 
+class PPProxyTensors:
+    # adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103
+    tensors: Dict[str, torch.Tensor]
+
+    def __init__(self, tensors):
+        # manually define this function, so that
+        # Dynamo knows `IntermediateTensors()` comes from this file.
+        # Otherwise, dataclass will generate this function by evaluating
+        # a string, and we will lose the information about the source file.
+        self.tensors = tensors
+
+    def __getitem__(self, key: Union[str, slice]):
+        if isinstance(key, str):
+            return self.tensors[key]
+        elif isinstance(key, slice):
+            return self.__class__({k: v[key] for k, v in self.tensors.items()})
+
+    def __setitem__(self, key: str, value: torch.Tensor):
+        self.tensors[key] = value
+
+    def __len__(self):
+        return len(self.tensors)
+
+    def __eq__(self, other: object):
+        return isinstance(other, self.__class__) and self
+
+    def __repr__(self) -> str:
+        return f"PPProxyTensors(tensors={self.tensors})"
+
+
 def compute_position_triton(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
 ):
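
Given the class definition above, a short usage sketch of its container semantics (assumes PPProxyTensors from this diff is in scope): string keys return the named tensor, slices return a new PPProxyTensors with every tensor sliced the same way.

import torch

proxy = PPProxyTensors(
    {"hidden_states": torch.randn(8, 16), "residual": torch.randn(8, 16)}
)
h = proxy["hidden_states"]  # str key -> torch.Tensor of shape (8, 16)
head = proxy[:4]  # slice -> PPProxyTensors holding (4, 16) tensors
proxy["residual"] = torch.zeros(8, 16)
assert len(proxy) == 2 and head["residual"].shape[0] == 4
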