sglang 0.4.2.post1__py3-none-any.whl → 0.4.2.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/constrained/outlines_backend.py +9 -1
- sglang/srt/custom_op.py +40 -0
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/layers/activation.py +10 -5
- sglang/srt/layers/attention/flashinfer_backend.py +284 -39
- sglang/srt/layers/attention/triton_backend.py +71 -7
- sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
- sglang/srt/layers/layernorm.py +1 -5
- sglang/srt/layers/moe/ep_moe/layer.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -11
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
- sglang/srt/layers/moe/topk.py +4 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8_kernel.py +140 -2
- sglang/srt/layers/rotary_embedding.py +1 -3
- sglang/srt/layers/sampler.py +4 -4
- sglang/srt/lora/backend/__init__.py +8 -0
- sglang/srt/lora/backend/base_backend.py +95 -0
- sglang/srt/lora/backend/flashinfer_backend.py +91 -0
- sglang/srt/lora/backend/triton_backend.py +61 -0
- sglang/srt/lora/lora.py +127 -112
- sglang/srt/lora/lora_manager.py +50 -18
- sglang/srt/lora/triton_ops/__init__.py +5 -0
- sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
- sglang/srt/model_executor/cuda_graph_runner.py +77 -80
- sglang/srt/model_executor/forward_batch_info.py +58 -59
- sglang/srt/model_executor/model_runner.py +2 -2
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/server_args.py +13 -2
- sglang/srt/speculative/build_eagle_tree.py +4 -2
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
- sglang/srt/speculative/eagle_utils.py +361 -372
- sglang/srt/speculative/eagle_worker.py +177 -45
- sglang/srt/utils.py +7 -0
- sglang/test/runners.py +2 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/METADATA +15 -6
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/RECORD +72 -33
- sglang/srt/layers/custom_op_util.py +0 -25
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -21,8 +21,8 @@ from typing import TYPE_CHECKING, Callable

 import torch
 import tqdm
-from vllm.model_executor.custom_op import CustomOp

+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -103,69 +103,75 @@ def set_torch_compile_config():
     torch._dynamo.config.cache_size_limit = 1024


+def get_batch_sizes_to_capture(model_runner: ModelRunner):
+    server_args = model_runner.server_args
+    capture_bs = server_args.cuda_graph_bs
+    if capture_bs is None:
+        if server_args.disable_cuda_graph_padding:
+            capture_bs = list(range(1, 33)) + [64, 128]
+        else:
+            capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+    if max(capture_bs) > model_runner.req_to_token_pool.size:
+        # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
+        # is very samll. We add more values here to make sure we capture the maximum bs.
+        capture_bs = list(
+            sorted(
+                set(
+                    capture_bs
+                    + [model_runner.req_to_token_pool.size - 1]
+                    + [model_runner.req_to_token_pool.size]
+                )
+            )
+        )
+    capture_bs = [
+        bs
+        for bs in capture_bs
+        if bs <= model_runner.req_to_token_pool.size
+        and bs <= server_args.cuda_graph_max_bs
+    ]
+    compile_bs = (
+        [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
+        if server_args.enable_torch_compile
+        else []
+    )
+    return capture_bs, compile_bs
+
+
+# Reuse this memory pool across all cuda graph runners.
+global_graph_memory_pool = None
+
+
+def get_global_graph_memory_pool():
+    return global_graph_memory_pool
+
+
+def set_global_graph_memory_pool(val):
+    global global_graph_memory_pool
+    global_graph_memory_pool = val
+
+
 class CudaGraphRunner:
     """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""

-    def __init__(self, model_runner:
+    def __init__(self, model_runner: ModelRunner):
         # Parse args
         self.model_runner = model_runner
         self.graphs = {}
-        self.input_buffers = {}
         self.output_buffers = {}
-        self.
-        self.graph_memory_pool = None
-        self.use_torch_compile = model_runner.server_args.enable_torch_compile
+        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
-        self.is_encoder_decoder =
-        self.enable_dp_attention =
-        self.tp_size =
-        self.dp_size =
+        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
+        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.tp_size = model_runner.server_args.tp_size
+        self.dp_size = model_runner.server_args.dp_size

         # Batch sizes to capture
-        self.capture_bs =
-        if self.capture_bs is None:
-            if model_runner.server_args.disable_cuda_graph_padding:
-                self.capture_bs = list(range(1, 33)) + [64, 128]
-            else:
-                self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-
-        if max(self.capture_bs) > model_runner.req_to_token_pool.size:
-            # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
-            # is very samll. We add more values here to make sure we capture the maximum bs.
-            self.capture_bs = list(
-                sorted(
-                    set(
-                        self.capture_bs
-                        + [model_runner.req_to_token_pool.size - 1]
-                        + [model_runner.req_to_token_pool.size]
-                    )
-                )
-            )
-
-        self.capture_bs = [
-            bs
-            for bs in self.capture_bs
-            if bs <= model_runner.req_to_token_pool.size
-            and bs <= model_runner.server_args.cuda_graph_max_bs
-        ]
-
-        self.compile_bs = (
-            [
-                bs
-                for bs in self.capture_bs
-                if bs <= self.model_runner.server_args.torch_compile_max_bs
-            ]
-            if self.use_torch_compile
-            else []
-        )
-
+        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         self.capture_forward_mode = ForwardMode.DECODE
         self.num_tokens_per_bs = 1
         if model_runner.spec_algorithm.is_eagle():
             if self.model_runner.is_draft_worker:
-
-                    self.model_runner.server_args.speculative_eagle_topk
-                )
+                raise RuntimeError("This should not happen")
             else:
                 self.capture_forward_mode = ForwardMode.TARGET_VERIFY
                 self.num_tokens_per_bs = (
@@ -182,10 +188,10 @@ class CudaGraphRunner:
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0

-        if self.
+        if self.enable_torch_compile:
             set_torch_compile_config()

-        #
+        # Graph inputs
         with torch.device("cuda"):
             self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
@@ -301,7 +307,7 @@
         stream = self.stream
         num_tokens = bs * self.num_tokens_per_bs

-        #
+        # Graph inputs
         input_ids = self.input_ids[:num_tokens]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
@@ -320,7 +326,7 @@
             global_num_tokens = None
             gathered_buffer = None

-        spec_info = self.get_spec_info(num_tokens
+        spec_info = self.get_spec_info(num_tokens)

         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
@@ -335,7 +341,6 @@
             seq_lens_sum=seq_lens.sum(),
             encoder_lens=encoder_lens,
             return_logprob=False,
-            top_logprobs_nums=[0] * bs,
             positions=positions,
             global_num_tokens=global_num_tokens,
             gathered_buffer=gathered_buffer,
@@ -375,13 +380,14 @@
         torch.cuda.synchronize()
         self.model_runner.tp_group.barrier()

-
+        global global_graph_memory_pool
+        with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
             out = run_once()

         torch.cuda.synchronize()
         self.model_runner.tp_group.barrier()

-
+        global_graph_memory_pool = graph.pool()
         return graph, out

     def replay(self, forward_batch: ForwardBatch):
@@ -439,35 +445,26 @@
         )
         return logits_output

-    def get_spec_info(self, num_tokens: int
+    def get_spec_info(self, num_tokens: int):
         spec_info = None
         if self.model_runner.spec_algorithm.is_eagle():
-            from sglang.srt.speculative.eagle_utils import
-                EAGLEDraftInput,
-                EagleVerifyInput,
-            )
+            from sglang.srt.speculative.eagle_utils import EagleVerifyInput

             if self.model_runner.is_draft_worker:
-
-                spec_info.load_server_args(self.model_runner.server_args)
-                spec_info.hidden_states = self.hidden_states[:num_tokens]
-                spec_info.positions = positions
-                spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+                raise RuntimeError("This should not happen.")
             else:
                 spec_info = EagleVerifyInput(
-                    None,
-
-
-
-
-
-
-
-
-
-
-                    device="cuda",
+                    draft_token=None,
+                    custom_mask=torch.zeros(
+                        (num_tokens * self.model_runner.model_config.context_len),
+                        dtype=torch.bool,
+                        device="cuda",
+                    ),
+                    positions=None,
+                    retrive_index=None,
+                    retrive_cum_len=None,
+                    draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
+                    capture_hidden_mode=CaptureHiddenMode.FULL,
                 )
-                spec_info.capture_hidden_mode = CaptureHiddenMode.FULL

         return spec_info
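Note: the capture-list logic above is unchanged in substance; it only moves out of CudaGraphRunner.__init__ into the module-level helper get_batch_sizes_to_capture so that the new EAGLE draft runner can reuse it. A minimal standalone sketch of the default selection (the pool size and limit below are made-up placeholder values, not sglang defaults):

    # Sketch of the default CUDA-graph batch-size selection shown in the hunk above.
    # Real values come from ServerArgs and the request-to-token pool.
    def default_capture_bs(disable_cuda_graph_padding: bool, cuda_graph_max_bs: int, pool_size: int):
        if disable_cuda_graph_padding:
            capture_bs = list(range(1, 33)) + [64, 128]
        else:
            capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]  # 8, 16, ..., 160
        if max(capture_bs) > pool_size:
            # make sure the maximum runnable batch size is captured as well
            capture_bs = sorted(set(capture_bs + [pool_size - 1, pool_size]))
        return [bs for bs in capture_bs if bs <= pool_size and bs <= cuda_graph_max_bs]

    print(default_capture_bs(False, cuda_graph_max_bs=160, pool_size=48))
    # -> [1, 2, 4, 8, 16, 24, 32, 40, 47, 48]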
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -197,64 +197,6 @@ class ForwardBatch:
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None

-    def compute_mrope_positions(
-        self, model_runner: ModelRunner, batch: ModelWorkerBatch
-    ):
-        device = model_runner.device
-        hf_config = model_runner.model_config.hf_config
-        mrope_positions_list = [None] * self.seq_lens.shape[0]
-        if self.forward_mode.is_decode():
-            for i, _ in enumerate(mrope_positions_list):
-                mrope_position_delta = (
-                    0
-                    if batch.image_inputs[i] is None
-                    else batch.image_inputs[i].mrope_position_delta
-                )
-                mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
-                    mrope_position_delta,
-                    int(self.seq_lens[i]) - 1,
-                    int(self.seq_lens[i]),
-                )
-        elif self.forward_mode.is_extend():
-            extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
-            for i, image_inputs in enumerate(batch.image_inputs):
-                extend_start_loc, extend_seq_len, extend_prefix_len = (
-                    extend_start_loc_cpu[i],
-                    batch.extend_seq_lens[i],
-                    batch.extend_prefix_lens[i],
-                )
-                if image_inputs is None:
-                    # text only
-                    mrope_positions = [
-                        [
-                            pos
-                            for pos in range(
-                                extend_prefix_len, extend_prefix_len + extend_seq_len
-                            )
-                        ]
-                    ] * 3
-                else:
-                    # TODO: current qwen2-vl do not support radix cache since mrope position calculation
-                    mrope_positions, mrope_position_delta = (
-                        MRotaryEmbedding.get_input_positions(
-                            input_tokens=self.input_ids[
-                                extend_start_loc : extend_start_loc + extend_seq_len
-                            ],
-                            image_grid_thw=image_inputs.image_grid_thws,
-                            vision_start_token_id=hf_config.vision_start_token_id,
-                            spatial_merge_size=hf_config.vision_config.spatial_merge_size,
-                            context_len=0,
-                        )
-                    )
-                    batch.image_inputs[i].mrope_position_delta = mrope_position_delta
-                mrope_positions_list[i] = mrope_positions
-
-        self.mrope_positions = torch.concat(
-            [torch.tensor(pos, device=device) for pos in mrope_positions_list],
-            axis=1,
-        )
-        self.mrope_positions = self.mrope_positions.to(torch.int64)
-
     @classmethod
     def init_new(
         cls,
@@ -337,7 +279,7 @@ class ForwardBatch:
         ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens

         if model_runner.model_is_mrope:
-            ret.
+            ret._compute_mrope_positions(model_runner, batch)

         # Init lora information
         if model_runner.server_args.lora_paths is not None:
@@ -345,6 +287,63 @@ class ForwardBatch:

         return ret

+    def _compute_mrope_positions(
+        self, model_runner: ModelRunner, batch: ModelWorkerBatch
+    ):
+        device = model_runner.device
+        hf_config = model_runner.model_config.hf_config
+        mrope_positions_list = [None] * self.seq_lens.shape[0]
+        if self.forward_mode.is_decode():
+            for i, _ in enumerate(mrope_positions_list):
+                mrope_position_delta = (
+                    0
+                    if batch.image_inputs[i] is None
+                    else batch.image_inputs[i].mrope_position_delta
+                )
+                mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
+                    mrope_position_delta,
+                    int(self.seq_lens[i]) - 1,
+                    int(self.seq_lens[i]),
+                )
+        elif self.forward_mode.is_extend():
+            extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
+            for i, image_inputs in enumerate(batch.image_inputs):
+                extend_start_loc, extend_seq_len, extend_prefix_len = (
+                    extend_start_loc_cpu[i],
+                    batch.extend_seq_lens[i],
+                    batch.extend_prefix_lens[i],
+                )
+                if image_inputs is None:
+                    # text only
+                    mrope_positions = [
+                        [
+                            pos
+                            for pos in range(
+                                extend_prefix_len, extend_prefix_len + extend_seq_len
+                            )
+                        ]
+                    ] * 3
+                else:
+                    # TODO: current qwen2-vl do not support radix cache since mrope position calculation
+                    mrope_positions, mrope_position_delta = (
+                        MRotaryEmbedding.get_input_positions(
+                            input_tokens=self.input_ids[
+                                extend_start_loc : extend_start_loc + extend_seq_len
+                            ],
+                            image_grid_thw=image_inputs.image_grid_thws,
+                            vision_start_token_id=hf_config.vision_start_token_id,
+                            spatial_merge_size=hf_config.vision_config.spatial_merge_size,
+                            context_len=0,
+                        )
+                    )
+                    batch.image_inputs[i].mrope_position_delta = mrope_position_delta
+                mrope_positions_list[i] = mrope_positions
+        self.mrope_positions = torch.concat(
+            [torch.tensor(pos, device=device) for pos in mrope_positions_list],
+            axis=1,
+        )
+        self.mrope_positions = self.mrope_positions.to(torch.int64)
+

 def compute_position_triton(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -52,6 +52,7 @@ from sglang.srt.mem_cache.memory_pool import (
     MLATokenToKVPool,
     ReqToTokenPool,
 )
+from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
 from sglang.srt.server_args import ServerArgs
@@ -529,6 +530,7 @@ class ModelRunner:
             max_loras_per_batch=self.server_args.max_loras_per_batch,
             load_config=self.load_config,
             dtype=self.dtype,
+            lora_backend=self.server_args.lora_backend,
         )
         logger.info("LoRA manager ready.")

@@ -714,8 +716,6 @@ class ModelRunner:

     def init_cuda_graphs(self):
         """Capture cuda graphs."""
-        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-
         self.cuda_graph_runner = None

         if not self.is_generation:
sglang/srt/models/qwen2_vl.py
CHANGED
@@ -31,10 +31,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from vllm.model_executor.layers.activation import QuickGELU

 from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
 from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.layers.activation import QuickGELU
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
sglang/srt/server_args.py
CHANGED
@@ -113,6 +113,7 @@ class ServerArgs:
     # LoRA
     lora_paths: Optional[List[str]] = None
     max_loras_per_batch: int = 8
+    lora_backend: str = "triton"

     # Kernel backend
     attention_backend: Optional[str] = None
@@ -273,6 +274,10 @@
         ) and check_gguf_file(self.model_path):
             self.quantization = self.load_format = "gguf"

+        # AMD-specific Triton attention KV splits default number
+        if is_hip():
+            self.triton_attention_num_kv_splits = 16
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and port args
@@ -649,13 +654,19 @@
             nargs="*",
             default=None,
             action=LoRAPathAction,
-            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
+            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}.",
         )
         parser.add_argument(
             "--max-loras-per-batch",
             type=int,
             default=8,
-            help="Maximum number of adapters for a running batch, include base-only request",
+            help="Maximum number of adapters for a running batch, include base-only request.",
+        )
+        parser.add_argument(
+            "--lora-backend",
+            type=str,
+            default="triton",
+            help="Choose the kernel backend for multi-LoRA serving.",
         )

         # Kernel backend
sglang/srt/speculative/build_eagle_tree.py
CHANGED
@@ -79,11 +79,13 @@ __global__ void build_tree(Tensor<long, 2> parent_list, Tensor<long, 2> selected
 )


-def build_tree_kernel(
+def build_tree_kernel(
+    parent_list, top_score_index, seq_lens, seq_lens_sum, topk, depth, draft_token
+):
     bs = seq_lens.numel()
     device = parent_list.device
     tree_mask = torch.full(
-        (
+        (seq_lens_sum * draft_token + draft_token * draft_token * bs,),
         True,
         device=device,
     )
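The functional change in this hunk is that the flattened tree mask is now sized directly from seq_lens_sum. A small sketch of that size arithmetic with made-up values (the real inputs come from the EAGLE worker):

    # Sketch of the tree-mask size computed in build_tree_kernel above:
    # draft_token mask rows against every prefix token (seq_lens_sum * draft_token)
    # plus a draft_token x draft_token block per request for the draft tokens themselves.
    import torch

    seq_lens = torch.tensor([13, 7])   # placeholder prefix lengths for 2 requests
    draft_token = 4                    # placeholder number of draft tokens
    bs = seq_lens.numel()
    seq_lens_sum = int(seq_lens.sum())

    mask_numel = seq_lens_sum * draft_token + draft_token * draft_token * bs
    tree_mask = torch.full((mask_numel,), True, dtype=torch.bool)
    print(mask_numel)  # (13 + 7) * 4 + 4 * 4 * 2 = 112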
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
ADDED
@@ -0,0 +1,213 @@
+from __future__ import annotations
+
+import bisect
+import time
+from typing import TYPE_CHECKING, Callable
+
+import torch
+
+from sglang.srt.model_executor.cuda_graph_runner import (
+    CudaGraphRunner,
+    get_batch_sizes_to_capture,
+    get_global_graph_memory_pool,
+    set_global_graph_memory_pool,
+    set_torch_compile_config,
+)
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
+from sglang.srt.speculative.eagle_utils import EagleDraftInput
+
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.eagle_worker import EAGLEWorker
+
+
+class EAGLEDraftCudaGraphRunner:
+    def __init__(self, eagle_worker: EAGLEWorker):
+        # Parse args
+        self.eagle_worker = eagle_worker
+        self.model_runner = model_runner = eagle_worker.model_runner
+        self.graphs = {}
+        self.output_buffers = {}
+        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
+        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+        self.tp_size = self.model_runner.tp_size
+        self.dp_size = model_runner.server_args.dp_size
+        self.topk = model_runner.server_args.speculative_eagle_topk
+        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
+        server_args = model_runner.server_args
+
+        assert self.disable_padding
+
+        # Batch sizes to capture
+        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
+        self.num_tokens_per_bs = server_args.speculative_eagle_topk
+
+        # Attention backend
+        self.max_bs = max(self.capture_bs)
+        self.max_num_token = self.max_bs * self.num_tokens_per_bs
+        self.model_runner.draft_attn_backend.init_cuda_graph_state(self.max_num_token)
+        self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
+            0
+        ].get_cuda_graph_seq_len_fill_value()
+
+        if self.enable_torch_compile:
+            set_torch_compile_config()
+
+        # Graph inputs
+        with torch.device("cuda"):
+            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
+            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
+            self.seq_lens = torch.full(
+                (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
+            )
+            self.out_cache_loc = torch.zeros(
+                (self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
+            )
+            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+            self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
+            self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
+            self.hidden_states = torch.zeros(
+                (self.max_bs, self.model_runner.model_config.hidden_size),
+                dtype=self.model_runner.dtype,
+            )
+
+        # Capture
+        try:
+            self.capture()
+        except RuntimeError as e:
+            raise Exception(
+                f"Capture cuda graph failed: {e}\n"
+                "Possible solutions:\n"
+                "1. disable cuda graph by --disable-cuda-graph\n"
+                "2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
+                "3. disable torch compile by not using --enable-torch-compile\n"
+                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
+            )
+
+    def can_run(self, forward_batch: ForwardBatch):
+        is_bs_supported = (
+            forward_batch.batch_size in self.graphs
+            if self.disable_padding
+            else forward_batch.batch_size <= self.max_bs
+        )
+        return is_bs_supported
+
+    def capture(self):
+        CudaGraphRunner.capture(self)
+
+    def capture_one_batch_size(self, num_seqs: int, forward: Callable):
+        graph = torch.cuda.CUDAGraph()
+        stream = self.stream
+        num_tokens = num_seqs * self.num_tokens_per_bs
+
+        # Graph inputs
+        req_pool_indices = self.req_pool_indices[:num_seqs]
+        seq_lens = self.seq_lens[:num_seqs]
+        out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
+        positions = self.positions[:num_tokens]
+        topk_p = self.topk_p[:num_seqs]
+        topk_index = self.topk_index[:num_seqs]
+        hidden_states = self.hidden_states[:num_seqs]
+
+        spec_info = EagleDraftInput(
+            topk_p=topk_p,
+            topk_index=topk_index,
+            hidden_states=hidden_states,
+        )
+
+        # Forward batch
+        forward_batch = ForwardBatch(
+            forward_mode=ForwardMode.DECODE,
+            batch_size=num_seqs,
+            input_ids=None,
+            req_pool_indices=req_pool_indices,
+            seq_lens=seq_lens,
+            req_to_token_pool=self.model_runner.req_to_token_pool,
+            token_to_kv_pool=self.model_runner.token_to_kv_pool,
+            out_cache_loc=out_cache_loc,
+            seq_lens_sum=seq_lens.sum(),
+            return_logprob=False,
+            positions=positions,
+            spec_algorithm=self.model_runner.spec_algorithm,
+            spec_info=spec_info,
+            capture_hidden_mode=(
+                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
+            ),
+        )
+
+        # Attention backend
+        self.model_runner.draft_attn_backend.init_forward_metadata_capture_cuda_graph(
+            forward_batch
+        )
+
+        # Run and capture
+        def run_once():
+            # Backup two fileds, which will be modified in-place in `draft_forward`.
+            output_cache_loc_backup = forward_batch.out_cache_loc
+            hidden_states_backup = forward_batch.spec_info.hidden_states
+
+            ret = self.eagle_worker.draft_forward(forward_batch)
+
+            forward_batch.out_cache_loc = output_cache_loc_backup
+            forward_batch.spec_info.hidden_states = hidden_states_backup
+            return ret
+
+        for _ in range(2):
+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+
+            run_once()
+
+        torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
+        torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
+        with torch.cuda.graph(
+            graph, pool=get_global_graph_memory_pool(), stream=stream
+        ):
+            out = run_once()
+
+        torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
+        set_global_graph_memory_pool(graph.pool())
+        return graph, out
+
+    def replay(self, forward_batch: ForwardBatch):
+        assert forward_batch.out_cache_loc is not None
+        raw_bs = forward_batch.batch_size
+        raw_num_token = raw_bs * self.num_tokens_per_bs
+
+        # Pad
+        index = bisect.bisect_left(self.capture_bs, raw_bs)
+        bs = self.capture_bs[index]
+        if bs != raw_bs:
+            self.seq_lens.fill_(1)
+            self.out_cache_loc.zero_()
+
+        # Common inputs
+        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
+        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
+        self.out_cache_loc[: raw_num_token * self.speculative_num_steps].copy_(
+            forward_batch.out_cache_loc
+        )
+        self.positions[:raw_num_token].copy_(forward_batch.positions)
+        self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p)
+        self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
+        self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
+
+        # Attention backend
+        self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
+            forward_batch
+        )
+
+        # Replay
+        self.graphs[bs].replay()
+
+        return self.output_buffers[bs]
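In replay above, the raw batch size is mapped to a captured batch size with bisect_left. A standalone sketch of that lookup with an illustrative capture list:

    # Sketch of the batch-size padding used in EAGLEDraftCudaGraphRunner.replay:
    # bisect_left finds the smallest captured batch size that is >= the raw batch size.
    import bisect

    capture_bs = [1, 2, 4, 8, 16, 24, 32]   # illustrative captured sizes (kept sorted)
    for raw_bs in (3, 8, 17):
        padded = capture_bs[bisect.bisect_left(capture_bs, raw_bs)]
        print(raw_bs, "->", padded)
    # 3 -> 4, 8 -> 8, 17 -> 24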