sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl
This diff compares the publicly released contents of the two package versions as published to a supported registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +2 -0
- sglang/check_env.py +3 -3
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/kimi_vl.py +38 -0
- sglang/srt/configs/kimi_vl_moonvit.py +32 -0
- sglang/srt/configs/model_config.py +15 -0
- sglang/srt/conversation.py +122 -1
- sglang/srt/entrypoints/engine.py +44 -22
- sglang/srt/function_call_parser.py +97 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +107 -82
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
- sglang/srt/layers/attention/flashmla_backend.py +3 -0
- sglang/srt/layers/dp_attention.py +5 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -6
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +1 -1
- sglang/srt/layers/utils.py +35 -0
- sglang/srt/lora/layers.py +35 -9
- sglang/srt/lora/lora_manager.py +84 -35
- sglang/srt/managers/data_parallel_controller.py +52 -34
- sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
- sglang/srt/managers/schedule_batch.py +25 -15
- sglang/srt/managers/scheduler.py +263 -59
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
- sglang/srt/managers/tp_worker.py +51 -16
- sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
- sglang/srt/mem_cache/memory_pool.py +70 -36
- sglang/srt/model_executor/cuda_graph_runner.py +82 -19
- sglang/srt/model_executor/forward_batch_info.py +31 -1
- sglang/srt/model_executor/model_runner.py +115 -57
- sglang/srt/models/deepseek_nextn.py +1 -257
- sglang/srt/models/deepseek_v2.py +78 -18
- sglang/srt/models/kimi_vl.py +308 -0
- sglang/srt/models/kimi_vl_moonvit.py +639 -0
- sglang/srt/models/llama.py +92 -30
- sglang/srt/models/llama4.py +2 -1
- sglang/srt/models/llama_eagle.py +4 -1
- sglang/srt/models/llama_eagle3.py +4 -1
- sglang/srt/models/qwen2_moe.py +8 -3
- sglang/srt/models/qwen2_vl.py +0 -12
- sglang/srt/models/qwen3_moe.py +8 -3
- sglang/srt/openai_api/adapter.py +34 -22
- sglang/srt/openai_api/protocol.py +11 -1
- sglang/srt/server_args.py +67 -22
- sglang/srt/speculative/eagle_worker.py +3 -2
- sglang/srt/utils.py +88 -9
- sglang/test/runners.py +4 -0
- sglang/test/test_utils.py +29 -0
- sglang/version.py +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +61 -51
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py
CHANGED
@@ -15,11 +15,12 @@
 
 import logging
 import threading
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -31,7 +32,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
@@ -47,6 +48,7 @@ class TpModelWorker:
         server_args: ServerArgs,
         gpu_id: int,
         tp_rank: int,
+        pp_rank: int,
         dp_rank: Optional[int],
         nccl_port: int,
         is_draft_worker: bool = False,
@@ -54,7 +56,9 @@ class TpModelWorker:
         token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
         # Parse args
+        self.tp_size = server_args.tp_size
         self.tp_rank = tp_rank
+        self.pp_rank = pp_rank
 
         # Init model and tokenizer
         self.model_config = ModelConfig(
@@ -71,13 +75,17 @@ class TpModelWorker:
             enable_multimodal=server_args.enable_multimodal,
             dtype=server_args.dtype,
             quantization=server_args.quantization,
+            is_draft_model=is_draft_worker,
         )
+
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
+            pp_rank=pp_rank,
+            pp_size=server_args.pp_size,
             nccl_port=nccl_port,
             server_args=server_args,
             is_draft_worker=is_draft_worker,
@@ -104,6 +112,10 @@ class TpModelWorker:
         )
         self.device = self.model_runner.device
 
+        # Init nccl groups
+        self.pp_group = get_pp_group()
+        self.world_group = get_world_group()
+
         # Profile number of tokens
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
         self.max_prefill_tokens = server_args.max_prefill_tokens
@@ -129,8 +141,9 @@ class TpModelWorker:
         # Sync random seed across TP workers
         self.random_seed = broadcast_pyobj(
             [server_args.random_seed],
-            self.tp_rank,
-            self.model_runner.tp_group.cpu_group,
+            self.tp_size * self.pp_rank + tp_rank,
+            self.world_group.cpu_group,
+            src=self.world_group.ranks[0],
         )[0]
         set_random_seed(self.random_seed)
 
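Note on the broadcast change above: the seed is now broadcast over the whole world group rather than one TP group, so each worker identifies itself by a flattened world rank computed as tp_size * pp_rank + tp_rank. A small illustration of that mapping (the parallel sizes are made up; only the formula comes from the diff):

# Enumerate flattened world ranks for a hypothetical 2-stage pipeline with TP=2.
tp_size, pp_size = 2, 2
for pp_rank in range(pp_size):
    for tp_rank in range(tp_size):
        world_rank = tp_size * pp_rank + tp_rank  # formula from the diff above
        print(f"pp_rank={pp_rank} tp_rank={tp_rank} -> world_rank={world_rank}")
# (0,0)->0, (0,1)->1, (1,0)->2, (1,1)->3; the broadcast source is
# world_group.ranks[0], so every TP/PP worker ends up with the same seed.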
@@ -155,11 +168,14 @@ class TpModelWorker:
     def get_pad_input_ids_func(self):
         return getattr(self.model_runner.model, "pad_input_ids", None)
 
-    def get_tp_cpu_group(self):
-        return self.model_runner.tp_group.cpu_group
+    def get_tp_group(self):
+        return self.model_runner.tp_group
+
+    def get_attention_tp_group(self):
+        return self.model_runner.attention_tp_group
 
     def get_attention_tp_cpu_group(self):
-        return self.model_runner.attention_tp_group.cpu_group
+        return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
 
     def get_memory_pool(self):
         return (
@@ -171,19 +187,38 @@ class TpModelWorker:
         self,
         model_worker_batch: ModelWorkerBatch,
         skip_sample: bool = False,
-    ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
+    ) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor]]:
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
 
-        if model_worker_batch.launch_done is not None:
-            model_worker_batch.launch_done.set()
+        pp_proxy_tensors = None
+        if not self.pp_group.is_first_rank:
+            pp_proxy_tensors = PPProxyTensors(
+                self.pp_group.recv_tensor_dict(
+                    all_gather_group=self.get_attention_tp_group()
+                )
+            )
+
+        if self.pp_group.is_last_rank:
+            logits_output = self.model_runner.forward(
+                forward_batch, pp_proxy_tensors=pp_proxy_tensors
+            )
+            if model_worker_batch.launch_done is not None:
+                model_worker_batch.launch_done.set()
 
-        if skip_sample:
-            next_token_ids = None
-        else:
-            next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
+            if skip_sample:
+                next_token_ids = None
+            else:
+                next_token_ids = self.model_runner.sample(
+                    logits_output, model_worker_batch
+                )
 
-        return logits_output, next_token_ids
+            return logits_output, next_token_ids
+        else:
+            pp_proxy_tensors = self.model_runner.forward(
+                forward_batch,
+                pp_proxy_tensors=pp_proxy_tensors,
+            )
+            return pp_proxy_tensors.tensors, None
 
     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
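The branching added to forward_batch_generation gives each pipeline stage a distinct role: non-first ranks receive the previous stage's tensors, the last rank computes logits and samples, and intermediate ranks return raw tensors for the caller to forward downstream. A toy, self-contained simulation of that control flow, with queues standing in for the pp_group send/recv calls (all names here are illustrative, not sglang APIs):

from queue import Queue

def run_stage(stage: int, num_stages: int, inbox, outbox, x: int):
    is_first, is_last = stage == 0, stage == num_stages - 1
    proxy = None if is_first else inbox.get()                 # recv_tensor_dict analogue
    hidden = (x if is_first else proxy["hidden_states"]) + 1  # stand-in for forward()
    if is_last:
        return {"logits": hidden}                             # last rank: real output
    outbox.put({"hidden_states": hidden})                     # proxy tensors move on
    return None

q01, q12 = Queue(), Queue()
run_stage(0, 3, None, q01, x=0)
run_stage(1, 3, q01, q12, x=0)
print(run_stage(2, 3, q12, None, x=0))                        # {'logits': 3}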
sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED
@@ -56,11 +56,14 @@ class TpModelWorkerClient:
         server_args: ServerArgs,
         gpu_id: int,
         tp_rank: int,
+        pp_rank: int,
         dp_rank: Optional[int],
         nccl_port: int,
     ):
         # Load the model
-        self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
+        self.worker = TpModelWorker(
+            server_args, gpu_id, tp_rank, pp_rank, dp_rank, nccl_port
+        )
         self.max_running_requests = self.worker.max_running_requests
         self.device = self.worker.device
         self.gpu_id = gpu_id
@@ -91,8 +94,11 @@ class TpModelWorkerClient:
     def get_pad_input_ids_func(self):
         return self.worker.get_pad_input_ids_func()
 
-    def get_tp_cpu_group(self):
-        return self.worker.get_tp_cpu_group()
+    def get_tp_group(self):
+        return self.worker.get_tp_group()
+
+    def get_attention_tp_group(self):
+        return self.worker.get_attention_tp_group()
 
     def get_attention_tp_cpu_group(self):
         return self.worker.get_attention_tp_cpu_group()
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -214,6 +214,8 @@ class MHATokenToKVPool(KVCache):
         layer_num: int,
         device: str,
         enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
     ):
         self.size = size
         self.page_size = page_size
@@ -232,6 +234,8 @@ class MHATokenToKVPool(KVCache):
         self.head_dim = head_dim
         self.layer_num = layer_num
         self._create_buffers()
+        self.start_layer = start_layer or 0
+        self.end_layer = end_layer or layer_num - 1
 
         self.layer_transfer_counter = None
         self.capture_mode = False
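The start_layer/end_layer fields introduced here support pipeline parallelism: each rank allocates KV buffers only for its own contiguous slice of layers, while callers keep passing global layer IDs, hence the recurring layer_id - self.start_layer translation in the methods below. A toy illustration of the indexing (not sglang code):

import torch

class SlicedPool:
    """Toy KV pool holding only global layers [start_layer, start_layer + layer_num)."""
    def __init__(self, layer_num: int, start_layer: int = 0):
        self.start_layer = start_layer
        self.buf = [torch.zeros(4) for _ in range(layer_num)]

    def get(self, layer_id: int) -> torch.Tensor:
        # Global layer_id -> local buffer index, as in the diff below.
        return self.buf[layer_id - self.start_layer]

# Second stage of an 8-layer model split 4/4: owns global layers 4..7.
pool = SlicedPool(layer_num=4, start_layer=4)
assert pool.get(4) is pool.buf[0] and pool.get(7) is pool.buf[3]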
@@ -281,6 +285,8 @@ class MHATokenToKVPool(KVCache):
 
     # for disagg
     def get_contiguous_buf_infos(self):
+        # layer_num x [seq_len, head_num, head_dim]
+        # layer_num x [page_num, page_size, head_num, head_dim]
         kv_data_ptrs = [
             self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
         ] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
@@ -320,24 +326,24 @@ class MHATokenToKVPool(KVCache):
         # transfer prepared data from host to device
         flat_data = flat_data.to(device=self.device, non_blocking=False)
         k_data, v_data = flat_data[0], flat_data[1]
-        self.k_buffer[layer_id][indices] = k_data
-        self.v_buffer[layer_id][indices] = v_data
+        self.k_buffer[layer_id - self.start_layer][indices] = k_data
+        self.v_buffer[layer_id - self.start_layer][indices] = v_data
 
     def get_key_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id)
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
 
         if self.store_dtype != self.dtype:
-            return self.k_buffer[layer_id].view(self.dtype)
-        return self.k_buffer[layer_id]
+            return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.k_buffer[layer_id - self.start_layer]
 
     def get_value_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id)
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
 
         if self.store_dtype != self.dtype:
-            return self.v_buffer[layer_id].view(self.dtype)
-        return self.v_buffer[layer_id]
+            return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.v_buffer[layer_id - self.start_layer]
 
     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
@@ -369,12 +375,12 @@ class MHATokenToKVPool(KVCache):
             current_stream = self.device_module.current_stream()
             self.alt_stream.wait_stream(current_stream)
             with self.device_module.stream(self.alt_stream):
-                self.k_buffer[layer_id][loc] = cache_k
-                self.v_buffer[layer_id][loc] = cache_v
+                self.k_buffer[layer_id - self.start_layer][loc] = cache_k
+                self.v_buffer[layer_id - self.start_layer][loc] = cache_v
             current_stream.wait_stream(self.alt_stream)
         else:
-            self.k_buffer[layer_id][loc] = cache_k
-            self.v_buffer[layer_id][loc] = cache_v
+            self.k_buffer[layer_id - self.start_layer][loc] = cache_k
+            self.v_buffer[layer_id - self.start_layer][loc] = cache_v
 
 
 @torch.compile
@@ -484,6 +490,8 @@ class MLATokenToKVPool(KVCache):
         layer_num: int,
         device: str,
         enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
     ):
         self.size = size
         self.page_size = page_size
@@ -497,6 +505,8 @@ class MLATokenToKVPool(KVCache):
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
         self.layer_num = layer_num
+        self.start_layer = start_layer or 0
+        self.end_layer = end_layer or layer_num - 1
 
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
@@ -540,19 +550,21 @@ class MLATokenToKVPool(KVCache):
 
     def get_key_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id)
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
 
         if self.store_dtype != self.dtype:
-            return self.kv_buffer[layer_id].view(self.dtype)
-        return self.kv_buffer[layer_id]
+            return self.kv_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.kv_buffer[layer_id - self.start_layer]
 
     def get_value_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id)
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
 
         if self.store_dtype != self.dtype:
-            return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
-        return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
+            return self.kv_buffer[layer_id - self.start_layer][
+                ..., : self.kv_lora_rank
+            ].view(self.dtype)
+        return self.kv_buffer[layer_id - self.start_layer][..., : self.kv_lora_rank]
 
     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
@@ -568,9 +580,11 @@ class MLATokenToKVPool(KVCache):
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
         if self.store_dtype != self.dtype:
-            self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+            self.kv_buffer[layer_id - self.start_layer][loc] = cache_k.view(
+                self.store_dtype
+            )
         else:
-            self.kv_buffer[layer_id][loc] = cache_k
+            self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
 
     def set_mla_kv_buffer(
         self,
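The store_dtype/dtype split used above lets values of one dtype live in a buffer declared with a different storage dtype (for example fp8 values in a uint8-typed buffer); writes and reads go through bit-preserving .view() calls instead of conversion copies. A minimal sketch of the pattern, assuming a torch build with fp8 support:

import torch

dtype, store_dtype = torch.float8_e4m3fn, torch.uint8
buf = torch.zeros(4, 2, dtype=store_dtype)      # storage-typed KV buffer

cache_k = torch.randn(2, 2).to(dtype)           # values in the compute dtype
buf[1:3] = cache_k.view(store_dtype)            # bit-preserving write
restored = buf[1:3].view(dtype)                 # bit-preserving read
assert torch.equal(restored.view(store_dtype), cache_k.view(store_dtype))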
@@ -605,7 +619,7 @@ class MLATokenToKVPool(KVCache):
     def transfer_per_layer(self, indices, flat_data, layer_id):
         # transfer prepared data from host to device
         flat_data = flat_data.to(device=self.device, non_blocking=False)
-        self.kv_buffer[layer_id][indices] = flat_data
+        self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
 
 
 class DoubleSparseTokenToKVPool(KVCache):
@@ -620,6 +634,8 @@ class DoubleSparseTokenToKVPool(KVCache):
         device: str,
         heavy_channel_num: int,
         enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
     ):
         self.size = size
         self.page_size = page_size
@@ -657,17 +673,23 @@ class DoubleSparseTokenToKVPool(KVCache):
             for _ in range(layer_num)
         ]
 
+        self.start_layer = start_layer or 0
+        self.end_layer = end_layer or layer_num - 1
+
     def get_key_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id]
+        return self.k_buffer[layer_id - self.start_layer]
 
     def get_value_buffer(self, layer_id: int):
-        return self.v_buffer[layer_id]
+        return self.v_buffer[layer_id - self.start_layer]
 
     def get_label_buffer(self, layer_id: int):
-        return self.label_buffer[layer_id]
+        return self.label_buffer[layer_id - self.start_layer]
 
     def get_kv_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+        return (
+            self.k_buffer[layer_id - self.start_layer],
+            self.v_buffer[layer_id - self.start_layer],
+        )
 
     def set_kv_buffer(
         self,
@@ -679,9 +701,9 @@ class DoubleSparseTokenToKVPool(KVCache):
     ):
         # NOTE(Andy): ignore the dtype check
         layer_id = layer.layer_id
-        self.k_buffer[layer_id][loc] = cache_k
-        self.v_buffer[layer_id][loc] = cache_v
-        self.label_buffer[layer_id][loc] = cache_label
+        self.k_buffer[layer_id - self.start_layer][loc] = cache_k
+        self.v_buffer[layer_id - self.start_layer][loc] = cache_v
+        self.label_buffer[layer_id - self.start_layer][loc] = cache_label
 
     def get_flat_data(self, indices):
         pass
@@ -930,7 +952,7 @@ class MHATokenToKVPoolHost(HostKVCache):
         return self.kv_buffer[:, :, indices]
 
     def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[:, layer_id, indices]
+        return self.kv_buffer[:, layer_id - self.start_layer, indices]
 
     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, :, indices] = flat_data
@@ -955,12 +977,20 @@ class MHATokenToKVPoolHost(HostKVCache):
         for i in range(len(device_indices_cpu)):
             h_index = host_indices[i * self.page_size]
             d_index = device_indices_cpu[i]
-            device_pool.k_buffer[layer_id][d_index : d_index + self.page_size].copy_(
-                self.kv_buffer[0, layer_id, h_index : h_index + self.page_size],
+            device_pool.k_buffer[layer_id - self.start_layer][
+                d_index : d_index + self.page_size
+            ].copy_(
+                self.kv_buffer[
+                    0, layer_id - self.start_layer, h_index : h_index + self.page_size
+                ],
                 non_blocking=True,
             )
-            device_pool.v_buffer[layer_id][d_index : d_index + self.page_size].copy_(
-                self.kv_buffer[1, layer_id, h_index : h_index + self.page_size],
+            device_pool.v_buffer[layer_id - self.start_layer][
+                d_index : d_index + self.page_size
+            ].copy_(
+                self.kv_buffer[
+                    1, layer_id - self.start_layer, h_index : h_index + self.page_size
+                ],
                 non_blocking=True,
             )
 
@@ -1015,7 +1045,7 @@ class MLATokenToKVPoolHost(HostKVCache):
         return self.kv_buffer[:, indices]
 
     def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[layer_id, indices]
+        return self.kv_buffer[layer_id - self.start_layer, indices]
 
     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, indices] = flat_data
@@ -1036,7 +1066,11 @@ class MLATokenToKVPoolHost(HostKVCache):
         for i in range(len(device_indices_cpu)):
             h_index = host_indices[i * self.page_size]
             d_index = device_indices_cpu[i]
-            device_pool.kv_buffer[layer_id][d_index : d_index + self.page_size].copy_(
-                self.kv_buffer[layer_id, h_index : h_index + self.page_size],
+            device_pool.kv_buffer[layer_id - self.start_layer][
+                d_index : d_index + self.page_size
+            ].copy_(
+                self.kv_buffer[
+                    layer_id - self.start_layer, h_index : h_index + self.page_size
+                ],
                 non_blocking=True,
             )
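Both host-cache load paths above copy whole pages: host_indices is sampled every page_size entries to get each page's starting offset, and a contiguous page_size-token span is copied into the device page at d_index. The index arithmetic in isolation (toy tensors, not sglang code):

import torch

page_size = 4
host = torch.arange(32.0)                 # flat host cache for one layer
device = torch.zeros(32)
host_indices = torch.tensor([8, 9, 10, 11, 0, 1, 2, 3])  # token slots, two pages
device_indices_cpu = [20, 4]              # page-start offsets on the device

for i, d_index in enumerate(device_indices_cpu):
    h_index = host_indices[i * page_size]
    device[d_index : d_index + page_size].copy_(host[h_index : h_index + page_size])

print(device[20:24].tolist(), device[4:8].tolist())  # pages land at 20..23 and 4..7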
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -16,6 +16,7 @@
 from __future__ import annotations
 
 import bisect
+import inspect
 import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable
@@ -33,12 +34,14 @@ from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
     ForwardMode,
+    PPProxyTensors,
 )
 from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.utils import (
     get_available_gpu_memory,
     get_device_memory_capacity,
     is_hip,
+    rank0_log,
 )
 
 if TYPE_CHECKING:
@@ -135,7 +138,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
 
     gpu_mem = get_device_memory_capacity()
     # Batch size of each rank will not become so large when DP is on
-    if gpu_mem is not None and gpu_mem > 81920:
+    if gpu_mem is not None and gpu_mem > 96 * 1024:
         capture_bs += list(range(160, 257, 8))
 
     if max(capture_bs) > model_runner.req_to_token_pool.size:
@@ -188,10 +191,11 @@ class CudaGraphRunner:
         self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
+        self.pp_size = model_runner.server_args.pp_size
 
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
-
+        rank0_log(f"Capture cuda graph bs {self.capture_bs}")
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1
@@ -220,6 +224,9 @@ class CudaGraphRunner:
         if self.enable_torch_compile:
             set_torch_compile_config()
 
+        if self.model_runner.server_args.lora_paths is not None:
+            self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
+
         # Graph inputs
         with torch.device("cuda"):
             self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
@@ -231,6 +238,19 @@ class CudaGraphRunner:
             self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
 
+            # pipeline parallelism
+            if self.pp_size > 1:
+                self.pp_proxy_tensors = {
+                    "hidden_states": torch.zeros(
+                        (self.max_bs, self.model_runner.model_config.hidden_size),
+                        dtype=torch.bfloat16,
+                    ),
+                    "residual": torch.zeros(
+                        (self.max_bs, self.model_runner.model_config.hidden_size),
+                        dtype=torch.bfloat16,
+                    ),
+                }
+
             # Speculative_inference
             if (
                 model_runner.spec_algorithm.is_eagle3()
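The preallocated pp_proxy_tensors dict exists because CUDA graphs replay fixed memory addresses: inter-stage tensors must occupy max_bs-sized static buffers that are captured once, and each replay copies the live tensors into a prefix of those buffers (see replay_prepare further down). The copy-into-static-buffer pattern in isolation, a minimal sketch using nothing beyond plain torch:

import torch

MAX_BS, HIDDEN = 8, 16
static = {
    "hidden_states": torch.zeros(MAX_BS, HIDDEN, dtype=torch.bfloat16),
    "residual": torch.zeros(MAX_BS, HIDDEN, dtype=torch.bfloat16),
}

def stage_inputs(incoming):
    # Copy a variable-size batch into the leading rows of the captured buffers,
    # mirroring replay_prepare's pp_proxy_tensors handling below.
    for key, buf in static.items():
        n = incoming[key].shape[0]
        buf[:n].copy_(incoming[key])

stage_inputs({k: torch.randn(3, HIDDEN).to(torch.bfloat16) for k in static})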
@@ -381,6 +401,12 @@ class CudaGraphRunner:
             encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]
 
+        # pipeline parallelism
+        if self.pp_size > 1:
+            pp_proxy_tensors = PPProxyTensors(
+                {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
+            )
+
         if self.enable_dp_attention or self.enable_sp_layernorm:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
@@ -403,6 +429,13 @@ class CudaGraphRunner:
         self.capture_hidden_mode = (
             spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
         )
+        if self.model_runner.server_args.lora_paths is not None:
+            # Currently, if the lora_path in `lora_paths` is None, the lora backend will use a
+            # different logic to handle lora, so we need to set `lora_paths` to a list of non-None
+            # values if lora is enabled.
+            lora_paths = [next(iter(self.model_runner.server_args.lora_paths))] * bs
+        else:
+            lora_paths = None
 
         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
@@ -424,8 +457,12 @@ class CudaGraphRunner:
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=self.capture_hidden_mode,
+            lora_paths=lora_paths,
         )
 
+        if lora_paths is not None:
+            self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs,
@@ -442,8 +479,20 @@ class CudaGraphRunner:
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
 
-            logits_output = forward(input_ids, forward_batch.positions, forward_batch)
-            return logits_output
+            kwargs = {}
+            if (
+                self.pp_size > 1
+                and "pp_proxy_tensors" in inspect.signature(forward).parameters
+            ):
+                kwargs["pp_proxy_tensors"] = pp_proxy_tensors
+
+            logits_output_or_pp_proxy_tensors = forward(
+                input_ids,
+                forward_batch.positions,
+                forward_batch,
+                **kwargs,
+            )
+            return logits_output_or_pp_proxy_tensors
 
         for _ in range(2):
             torch.cuda.synchronize()
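The inspect.signature guard in run_once means pp_proxy_tensors is passed only to models whose forward declares that parameter, so single-stage model definitions keep working unmodified. The pattern in a standalone form (toy functions, not sglang code):

import inspect

def forward_a(x): return x
def forward_b(x, pp_proxy_tensors=None): return (x, pp_proxy_tensors)

def call_forward(fn, x, proxy):
    kwargs = {}
    # Only pass the kwarg if the callee declares it, as in run_once above.
    if "pp_proxy_tensors" in inspect.signature(fn).parameters:
        kwargs["pp_proxy_tensors"] = proxy
    return fn(x, **kwargs)

print(call_forward(forward_a, 1, "proxy"))  # 1
print(call_forward(forward_b, 1, "proxy"))  # (1, 'proxy')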
@@ -476,7 +525,11 @@ class CudaGraphRunner:
             self.capture_hidden_mode = hidden_mode_from_spec_info
             self.capture()
 
-    def replay_prepare(self, forward_batch: ForwardBatch):
+    def replay_prepare(
+        self,
+        forward_batch: ForwardBatch,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ):
         self.recapture_if_needed(forward_batch)
 
         raw_bs = forward_batch.batch_size
@@ -505,6 +558,11 @@ class CudaGraphRunner:
         self.seq_lens_cpu.fill_(1)
         self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
 
+        if pp_proxy_tensors:
+            for key in self.pp_proxy_tensors.keys():
+                dim = pp_proxy_tensors[key].shape[0]
+                self.pp_proxy_tensors[key][:dim].copy_(pp_proxy_tensors[key])
+
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
@@ -533,10 +591,13 @@ class CudaGraphRunner:
         self.bs = bs
 
     def replay(
-        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
-    ) -> LogitsProcessorOutput:
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
         if not skip_attn_backend_init:
-            self.replay_prepare(forward_batch)
+            self.replay_prepare(forward_batch, pp_proxy_tensors)
         else:
             # In speculative decoding, these two fields are still needed.
             self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
@@ -544,17 +605,19 @@ class CudaGraphRunner:
 
         # Replay
         self.graphs[self.bs].replay()
-        next_token_logits, hidden_states = self.output_buffers[self.bs]
-
-        logits_output = LogitsProcessorOutput(
-            next_token_logits=next_token_logits[: self.raw_num_token],
-            hidden_states=(
-                hidden_states[: self.raw_num_token]
-                if hidden_states is not None
-                else None
-            ),
-        )
-        return logits_output
+        output = self.output_buffers[self.bs]
+        if isinstance(output, LogitsProcessorOutput):
+            return LogitsProcessorOutput(
+                next_token_logits=output.next_token_logits[: self.raw_num_token],
+                hidden_states=(
+                    output.hidden_states[: self.raw_num_token]
+                    if output.hidden_states is not None
+                    else None
+                ),
+            )
+        else:
+            assert isinstance(output, PPProxyTensors)
+            return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})
 
     def get_spec_info(self, num_tokens: int):
         spec_info = None
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -31,7 +31,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 import torch
 import triton
@@ -585,6 +585,36 @@ class ForwardBatch:
         self.prepare_chunked_kv_indices(device)
 
 
+class PPProxyTensors:
+    # adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103
+    tensors: Dict[str, torch.Tensor]
+
+    def __init__(self, tensors):
+        # manually define this function, so that
+        # Dynamo knows `IntermediateTensors()` comes from this file.
+        # Otherwise, dataclass will generate this function by evaluating
+        # a string, and we will lose the information about the source file.
+        self.tensors = tensors
+
+    def __getitem__(self, key: Union[str, slice]):
+        if isinstance(key, str):
+            return self.tensors[key]
+        elif isinstance(key, slice):
+            return self.__class__({k: v[key] for k, v in self.tensors.items()})
+
+    def __setitem__(self, key: str, value: torch.Tensor):
+        self.tensors[key] = value
+
+    def __len__(self):
+        return len(self.tensors)
+
+    def __eq__(self, other: object):
+        return isinstance(other, self.__class__) and self.tensors == other.tensors
+
+    def __repr__(self) -> str:
+        return f"PPProxyTensors(tensors={self.tensors})"
+
+
 def compute_position_triton(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
 ):