sglang 0.4.2__py3-none-any.whl → 0.4.2.post2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
- sglang/srt/constrained/outlines_backend.py +9 -1
- sglang/srt/custom_op.py +40 -0
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/layers/activation.py +10 -5
- sglang/srt/layers/attention/flashinfer_backend.py +284 -39
- sglang/srt/layers/attention/triton_backend.py +71 -7
- sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +6 -0
- sglang/srt/layers/attention/vision.py +243 -40
- sglang/srt/layers/layernorm.py +1 -5
- sglang/srt/layers/moe/ep_moe/layer.py +1 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -11
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
- sglang/srt/layers/moe/topk.py +4 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +7 -0
- sglang/srt/layers/quantization/fp8_kernel.py +140 -2
- sglang/srt/layers/rotary_embedding.py +29 -15
- sglang/srt/layers/sampler.py +9 -6
- sglang/srt/lora/backend/__init__.py +8 -0
- sglang/srt/lora/backend/base_backend.py +95 -0
- sglang/srt/lora/backend/flashinfer_backend.py +91 -0
- sglang/srt/lora/backend/triton_backend.py +61 -0
- sglang/srt/lora/lora.py +127 -112
- sglang/srt/lora/lora_manager.py +50 -18
- sglang/srt/lora/triton_ops/__init__.py +5 -0
- sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
- sglang/srt/managers/image_processor.py +77 -38
- sglang/srt/managers/scheduler.py +17 -3
- sglang/srt/mem_cache/base_prefix_cache.py +4 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +30 -1
- sglang/srt/model_executor/cuda_graph_runner.py +77 -80
- sglang/srt/model_executor/forward_batch_info.py +58 -59
- sglang/srt/model_executor/model_runner.py +2 -2
- sglang/srt/models/minicpmv.py +129 -76
- sglang/srt/models/mllama.py +16 -56
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_vl.py +19 -9
- sglang/srt/server_args.py +19 -2
- sglang/srt/speculative/build_eagle_tree.py +4 -2
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
- sglang/srt/speculative/eagle_utils.py +361 -372
- sglang/srt/speculative/eagle_worker.py +177 -45
- sglang/srt/utils.py +7 -2
- sglang/test/runners.py +2 -0
- sglang/utils.py +42 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/METADATA +16 -7
- {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/RECORD +84 -45
- sglang/srt/layers/custom_op_util.py +0 -25
- {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.dist-info → sglang-0.4.2.post2.dist-info}/top_level.txt +0 -0
Selected hunks from the changed files follow.

sglang/srt/mem_cache/radix_cache.py:

@@ -34,7 +34,10 @@ if TYPE_CHECKING:


 class TreeNode:
-    def __init__(self):
+
+    counter = 0
+
+    def __init__(self, id: Optional[int] = None):
         self.children = defaultdict(TreeNode)
         self.parent = None
         self.key = None
@@ -42,6 +45,23 @@
         self.lock_ref = 0
         self.last_access_time = time.time()

+        self.hit_count = 0
+        # indicating the node is loading KV cache from host
+        self.loading = False
+        # store the host indices of KV cache
+        self.host_value = None
+
+        self.id = TreeNode.counter if id is None else id
+        TreeNode.counter += 1
+
+    @property
+    def evicted(self):
+        return self.value is None
+
+    @property
+    def backuped(self):
+        return self.host_value is not None
+
     def __lt__(self, other: "TreeNode"):
         return self.last_access_time < other.last_access_time

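The new TreeNode fields give every node a stable id plus the bookkeeping for host offload: evicted means the device-side value is gone, backuped means a host copy exists. Below is a minimal standalone sketch of just these fields (not the real class, which also tracks children, parent, key, and lock_ref):

# Stripped-down illustration of the new TreeNode bookkeeping; field names
# mirror the diff above, everything else is simplified for the example.
class MiniNode:
    counter = 0

    def __init__(self, id=None):
        self.value = None        # device-side KV indices (None once evicted)
        self.host_value = None   # host-side copy, if offloaded
        self.id = MiniNode.counter if id is None else id
        MiniNode.counter += 1

    @property
    def evicted(self):
        return self.value is None

    @property
    def backuped(self):
        return self.host_value is not None


node = MiniNode()
node.value = [0, 1, 2]
print(node.id, node.evicted)        # 0 False
node.host_value = list(node.value)  # pretend the KV pages were copied to host
node.value = None                   # then evicted from the device pool
print(node.evicted, node.backuped)  # True True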
@@ -75,6 +95,7 @@ class RadixCache(BasePrefixCache):
         self.root_node.value = []
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0
+        self.protected_size_ = 0

     def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
         """Find the matching prefix from the radix tree.
@@ -203,6 +224,7 @@
         while node != self.root_node:
             if node.lock_ref == 0:
                 self.evictable_size_ -= len(node.value)
+                self.protected_size_ += len(node.value)
                 delta -= len(node.value)
             node.lock_ref += 1
             node = node.parent
@@ -216,6 +238,7 @@
         while node != self.root_node:
             if node.lock_ref == 1:
                 self.evictable_size_ += len(node.value)
+                self.protected_size_ -= len(node.value)
                 delta += len(node.value)
             node.lock_ref -= 1
             node = node.parent
@@ -224,6 +247,10 @@
     def evictable_size(self):
         return self.evictable_size_

+    def protected_size(self):
+        # protected size refers to the size of the cache that is locked
+        return self.protected_size_
+
     ##### Internal Helper Functions #####

     def _match_prefix_helper(
@@ -303,6 +330,8 @@
             self.evictable_size_ -= len(node.key)

     def _total_size_helper(self, node: TreeNode):
+        if node.evicted:
+            return 0
         x = len(node.value)
         for child in node.children.values():
             x += self._total_size_helper(child)
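RadixCache now tracks a second counter, protected_size_, alongside evictable_size_: when a running request pins a path (a node's lock_ref goes from 0 to 1) that node's tokens move from the evictable pool to the protected pool, and they move back when the last lock is released, so the two counters together always account for the resident cache. A small self-contained sketch of that invariant (simplified nodes, not the actual RadixCache API):

# Sketch of the lock-ref accounting added in this release (simplified).
class Node:
    def __init__(self, num_tokens, parent=None):
        self.value = list(range(num_tokens))  # stand-in for KV indices
        self.lock_ref = 0
        self.parent = parent


def inc_lock_ref(node, root, sizes):
    # Pin a path: tokens on newly locked nodes stop being evictable.
    while node is not root:
        if node.lock_ref == 0:
            sizes["evictable"] -= len(node.value)
            sizes["protected"] += len(node.value)
        node.lock_ref += 1
        node = node.parent


def dec_lock_ref(node, root, sizes):
    # Release a path: tokens become evictable again once the last lock drops.
    while node is not root:
        if node.lock_ref == 1:
            sizes["evictable"] += len(node.value)
            sizes["protected"] -= len(node.value)
        node.lock_ref -= 1
        node = node.parent


root = Node(0)
leaf = Node(8, parent=Node(4, parent=root))
sizes = {"evictable": 12, "protected": 0}  # 8 + 4 resident tokens
inc_lock_ref(leaf, root, sizes)
print(sizes)  # {'evictable': 0, 'protected': 12}
dec_lock_ref(leaf, root, sizes)
print(sizes)  # {'evictable': 12, 'protected': 0}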
sglang/srt/model_executor/cuda_graph_runner.py:

@@ -21,8 +21,8 @@ from typing import TYPE_CHECKING, Callable

 import torch
 import tqdm
-from vllm.model_executor.custom_op import CustomOp

+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
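The only change here is the import: the vLLM CustomOp base class is replaced by an in-tree sglang.srt.custom_op.CustomOp (the new 40-line module listed above). That module's body is not shown in this diff; the sketch below is only a plausible guess at what such a backend-dispatching base class looks like, with every method name assumed:

import torch
from torch import nn


class CustomOp(nn.Module):
    # Hypothetical sketch: choose a per-backend forward implementation once,
    # at construction time, instead of branching on every call.
    def __init__(self):
        super().__init__()
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        # Assume ROCm falls back to the CUDA path unless overridden.
        return self.forward_cuda(*args, **kwargs)

    def dispatch_forward(self):
        if not torch.cuda.is_available():
            return self.forward_native
        if torch.version.hip is not None:
            return self.forward_hip
        return self.forward_cuda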
@@ -103,69 +103,75 @@ def set_torch_compile_config():
     torch._dynamo.config.cache_size_limit = 1024


+def get_batch_sizes_to_capture(model_runner: ModelRunner):
+    server_args = model_runner.server_args
+    capture_bs = server_args.cuda_graph_bs
+    if capture_bs is None:
+        if server_args.disable_cuda_graph_padding:
+            capture_bs = list(range(1, 33)) + [64, 128]
+        else:
+            capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+    if max(capture_bs) > model_runner.req_to_token_pool.size:
+        # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
+        # is very samll. We add more values here to make sure we capture the maximum bs.
+        capture_bs = list(
+            sorted(
+                set(
+                    capture_bs
+                    + [model_runner.req_to_token_pool.size - 1]
+                    + [model_runner.req_to_token_pool.size]
+                )
+            )
+        )
+    capture_bs = [
+        bs
+        for bs in capture_bs
+        if bs <= model_runner.req_to_token_pool.size
+        and bs <= server_args.cuda_graph_max_bs
+    ]
+    compile_bs = (
+        [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
+        if server_args.enable_torch_compile
+        else []
+    )
+    return capture_bs, compile_bs
+
+
+# Reuse this memory pool across all cuda graph runners.
+global_graph_memory_pool = None
+
+
+def get_global_graph_memory_pool():
+    return global_graph_memory_pool
+
+
+def set_global_graph_memory_pool(val):
+    global global_graph_memory_pool
+    global_graph_memory_pool = val
+
 class CudaGraphRunner:
     """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""

-    def __init__(self, model_runner:
+    def __init__(self, model_runner: ModelRunner):
         # Parse args
         self.model_runner = model_runner
         self.graphs = {}
-        self.input_buffers = {}
         self.output_buffers = {}
-        self.
-        self.graph_memory_pool = None
-        self.use_torch_compile = model_runner.server_args.enable_torch_compile
+        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
-        self.is_encoder_decoder =
-        self.enable_dp_attention =
-        self.tp_size =
-        self.dp_size =
+        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
+        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.tp_size = model_runner.server_args.tp_size
+        self.dp_size = model_runner.server_args.dp_size

         # Batch sizes to capture
-        self.capture_bs =
-        if self.capture_bs is None:
-            if model_runner.server_args.disable_cuda_graph_padding:
-                self.capture_bs = list(range(1, 33)) + [64, 128]
-            else:
-                self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-
-        if max(self.capture_bs) > model_runner.req_to_token_pool.size:
-            # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
-            # is very samll. We add more values here to make sure we capture the maximum bs.
-            self.capture_bs = list(
-                sorted(
-                    set(
-                        self.capture_bs
-                        + [model_runner.req_to_token_pool.size - 1]
-                        + [model_runner.req_to_token_pool.size]
-                    )
-                )
-            )
-
-        self.capture_bs = [
-            bs
-            for bs in self.capture_bs
-            if bs <= model_runner.req_to_token_pool.size
-            and bs <= model_runner.server_args.cuda_graph_max_bs
-        ]
-
-        self.compile_bs = (
-            [
-                bs
-                for bs in self.capture_bs
-                if bs <= self.model_runner.server_args.torch_compile_max_bs
-            ]
-            if self.use_torch_compile
-            else []
-        )
-
+        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         self.capture_forward_mode = ForwardMode.DECODE
         self.num_tokens_per_bs = 1
         if model_runner.spec_algorithm.is_eagle():
             if self.model_runner.is_draft_worker:
-                self.num_tokens_per_bs = (
-                    self.model_runner.server_args.speculative_eagle_topk
-                )
+                raise RuntimeError("This should not happen")
             else:
                 self.capture_forward_mode = ForwardMode.TARGET_VERIFY
                 self.num_tokens_per_bs = (
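The batch-size selection previously inlined in CudaGraphRunner.__init__ is now the free function get_batch_sizes_to_capture, so other graph runners (such as the new EAGLE draft runner) can reuse it. With the defaults (padding enabled, no explicit cuda_graph_bs), the raw candidate list is 1, 2, 4 and then multiples of 8 up to 160, before it is clipped to the request pool size and cuda_graph_max_bs:

# Default candidate batch sizes, taken directly from the diff above.
padded_default = [1, 2, 4] + [i * 8 for i in range(1, 21)]
no_padding_default = list(range(1, 33)) + [64, 128]
print(padded_default)      # [1, 2, 4, 8, 16, 24, ..., 152, 160]
print(no_padding_default)  # [1, 2, 3, ..., 32, 64, 128]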
@@ -182,10 +188,10 @@ class CudaGraphRunner:
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0

-        if self.use_torch_compile:
+        if self.enable_torch_compile:
             set_torch_compile_config()

-        #
+        # Graph inputs
         with torch.device("cuda"):
             self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
@@ -301,7 +307,7 @@
         stream = self.stream
         num_tokens = bs * self.num_tokens_per_bs

-        #
+        # Graph inputs
         input_ids = self.input_ids[:num_tokens]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
@@ -320,7 +326,7 @@
             global_num_tokens = None
             gathered_buffer = None

-        spec_info = self.get_spec_info(num_tokens
+        spec_info = self.get_spec_info(num_tokens)

         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
@@ -335,7 +341,6 @@
             seq_lens_sum=seq_lens.sum(),
             encoder_lens=encoder_lens,
             return_logprob=False,
-            top_logprobs_nums=[0] * bs,
             positions=positions,
             global_num_tokens=global_num_tokens,
             gathered_buffer=gathered_buffer,
@@ -375,13 +380,14 @@
         torch.cuda.synchronize()
         self.model_runner.tp_group.barrier()

-        with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
+        global global_graph_memory_pool
+        with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
             out = run_once()

         torch.cuda.synchronize()
         self.model_runner.tp_group.barrier()

-        self.graph_memory_pool = graph.pool()
+        global_graph_memory_pool = graph.pool()
         return graph, out

     def replay(self, forward_batch: ForwardBatch):
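The per-instance graph_memory_pool attribute is gone; captures now share the module-level global_graph_memory_pool, so separate runner instances (for example the target-model runner and the new EAGLE draft runner) allocate from one CUDA graph pool. The mechanism is the stock PyTorch one: pass pool= to torch.cuda.graph and read graph.pool() back after the first capture. A minimal standalone sketch of that pattern (toy computation, needs a CUDA device, not sglang code):

import torch

# Capture two graphs that share one memory pool, mirroring the change above.
x = torch.zeros(8, device="cuda")

# Warm up on a side stream, as recommended before graph capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for scale in (1.0, 2.0):
        _ = x * scale
torch.cuda.current_stream().wait_stream(s)

pool = None
graphs = []
for scale in (1.0, 2.0):
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, pool=pool):
        y = x * scale
    pool = g.pool()  # reuse this pool for every later capture
    graphs.append((g, y))

x.copy_(torch.arange(8, dtype=torch.float32, device="cuda"))
graphs[1][0].replay()
print(graphs[1][1])  # x * 2.0, produced by the captured graph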
@@ -439,35 +445,26 @@
         )
         return logits_output

-    def get_spec_info(self, num_tokens: int
+    def get_spec_info(self, num_tokens: int):
         spec_info = None
         if self.model_runner.spec_algorithm.is_eagle():
-            from sglang.srt.speculative.eagle_utils import (
-                EAGLEDraftInput,
-                EagleVerifyInput,
-            )
+            from sglang.srt.speculative.eagle_utils import EagleVerifyInput

             if self.model_runner.is_draft_worker:
-
-                spec_info.load_server_args(self.model_runner.server_args)
-                spec_info.hidden_states = self.hidden_states[:num_tokens]
-                spec_info.positions = positions
-                spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
+                raise RuntimeError("This should not happen.")
             else:
                 spec_info = EagleVerifyInput(
-                    None,
-
-
-
-
-
-
-
-
-
-                    device="cuda",
+                    draft_token=None,
+                    custom_mask=torch.zeros(
+                        (num_tokens * self.model_runner.model_config.context_len),
+                        dtype=torch.bool,
+                        device="cuda",
+                    ),
+                    positions=None,
+                    retrive_index=None,
+                    retrive_cum_len=None,
+                    draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
+                    capture_hidden_mode=CaptureHiddenMode.FULL,
                 )
-                spec_info.capture_hidden_mode = CaptureHiddenMode.FULL

         return spec_info
sglang/srt/model_executor/forward_batch_info.py:

@@ -197,64 +197,6 @@ class ForwardBatch:
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None

-    def compute_mrope_positions(
-        self, model_runner: ModelRunner, batch: ModelWorkerBatch
-    ):
-        device = model_runner.device
-        hf_config = model_runner.model_config.hf_config
-        mrope_positions_list = [None] * self.seq_lens.shape[0]
-        if self.forward_mode.is_decode():
-            for i, _ in enumerate(mrope_positions_list):
-                mrope_position_delta = (
-                    0
-                    if batch.image_inputs[i] is None
-                    else batch.image_inputs[i].mrope_position_delta
-                )
-                mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
-                    mrope_position_delta,
-                    int(self.seq_lens[i]) - 1,
-                    int(self.seq_lens[i]),
-                )
-        elif self.forward_mode.is_extend():
-            extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
-            for i, image_inputs in enumerate(batch.image_inputs):
-                extend_start_loc, extend_seq_len, extend_prefix_len = (
-                    extend_start_loc_cpu[i],
-                    batch.extend_seq_lens[i],
-                    batch.extend_prefix_lens[i],
-                )
-                if image_inputs is None:
-                    # text only
-                    mrope_positions = [
-                        [
-                            pos
-                            for pos in range(
-                                extend_prefix_len, extend_prefix_len + extend_seq_len
-                            )
-                        ]
-                    ] * 3
-                else:
-                    # TODO: current qwen2-vl do not support radix cache since mrope position calculation
-                    mrope_positions, mrope_position_delta = (
-                        MRotaryEmbedding.get_input_positions(
-                            input_tokens=self.input_ids[
-                                extend_start_loc : extend_start_loc + extend_seq_len
-                            ],
-                            image_grid_thw=image_inputs.image_grid_thws,
-                            vision_start_token_id=hf_config.vision_start_token_id,
-                            spatial_merge_size=hf_config.vision_config.spatial_merge_size,
-                            context_len=0,
-                        )
-                    )
-                    batch.image_inputs[i].mrope_position_delta = mrope_position_delta
-                    mrope_positions_list[i] = mrope_positions
-
-        self.mrope_positions = torch.concat(
-            [torch.tensor(pos, device=device) for pos in mrope_positions_list],
-            axis=1,
-        )
-        self.mrope_positions = self.mrope_positions.to(torch.int64)
-
     @classmethod
     def init_new(
         cls,
@@ -337,7 +279,7 @@
         ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens

         if model_runner.model_is_mrope:
-            ret.
+            ret._compute_mrope_positions(model_runner, batch)

         # Init lora information
         if model_runner.server_args.lora_paths is not None:
@@ -345,6 +287,63 @@

         return ret

+    def _compute_mrope_positions(
+        self, model_runner: ModelRunner, batch: ModelWorkerBatch
+    ):
+        device = model_runner.device
+        hf_config = model_runner.model_config.hf_config
+        mrope_positions_list = [None] * self.seq_lens.shape[0]
+        if self.forward_mode.is_decode():
+            for i, _ in enumerate(mrope_positions_list):
+                mrope_position_delta = (
+                    0
+                    if batch.image_inputs[i] is None
+                    else batch.image_inputs[i].mrope_position_delta
+                )
+                mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
+                    mrope_position_delta,
+                    int(self.seq_lens[i]) - 1,
+                    int(self.seq_lens[i]),
+                )
+        elif self.forward_mode.is_extend():
+            extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
+            for i, image_inputs in enumerate(batch.image_inputs):
+                extend_start_loc, extend_seq_len, extend_prefix_len = (
+                    extend_start_loc_cpu[i],
+                    batch.extend_seq_lens[i],
+                    batch.extend_prefix_lens[i],
+                )
+                if image_inputs is None:
+                    # text only
+                    mrope_positions = [
+                        [
+                            pos
+                            for pos in range(
+                                extend_prefix_len, extend_prefix_len + extend_seq_len
+                            )
+                        ]
+                    ] * 3
+                else:
+                    # TODO: current qwen2-vl do not support radix cache since mrope position calculation
+                    mrope_positions, mrope_position_delta = (
+                        MRotaryEmbedding.get_input_positions(
+                            input_tokens=self.input_ids[
+                                extend_start_loc : extend_start_loc + extend_seq_len
+                            ],
+                            image_grid_thw=image_inputs.image_grid_thws,
+                            vision_start_token_id=hf_config.vision_start_token_id,
+                            spatial_merge_size=hf_config.vision_config.spatial_merge_size,
+                            context_len=0,
+                        )
+                    )
+                    batch.image_inputs[i].mrope_position_delta = mrope_position_delta
+                    mrope_positions_list[i] = mrope_positions
+        self.mrope_positions = torch.concat(
+            [torch.tensor(pos, device=device) for pos in mrope_positions_list],
+            axis=1,
+        )
+        self.mrope_positions = self.mrope_positions.to(torch.int64)
+

 def compute_position_triton(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
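Functionally the method is unchanged: it moves below init_new and becomes the private helper _compute_mrope_positions that init_new calls. For a text-only extend segment it simply repeats the ordinary token positions across the three M-RoPE channels; a tiny check of that branch with made-up lengths:

import torch

# Text-only branch of the mrope computation, with assumed example lengths.
extend_prefix_len, extend_seq_len = 5, 3
mrope_positions = [
    [pos for pos in range(extend_prefix_len, extend_prefix_len + extend_seq_len)]
] * 3
print(torch.tensor(mrope_positions))
# tensor([[5, 6, 7],
#         [5, 6, 7],
#         [5, 6, 7]])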
sglang/srt/model_executor/model_runner.py:

@@ -52,6 +52,7 @@ from sglang.srt.mem_cache.memory_pool import (
     MLATokenToKVPool,
     ReqToTokenPool,
 )
+from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
 from sglang.srt.server_args import ServerArgs
@@ -529,6 +530,7 @@
             max_loras_per_batch=self.server_args.max_loras_per_batch,
             load_config=self.load_config,
             dtype=self.dtype,
+            lora_backend=self.server_args.lora_backend,
         )
         logger.info("LoRA manager ready.")

@@ -714,8 +716,6 @@

     def init_cuda_graphs(self):
         """Capture cuda graphs."""
-        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-
         self.cuda_graph_runner = None

         if not self.is_generation:
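The LoRAManager construction now receives lora_backend from the new server argument (the lora/backend/ package in the file list ships a base class plus flashinfer and triton implementations). The diff does not show how the name is resolved; below is a hedged sketch of the obvious name-to-class lookup, with all identifiers hypothetical:

# Hypothetical name -> backend lookup illustrating the lora_backend plumbing.
class BaseLoRABackend: ...
class TritonLoRABackend(BaseLoRABackend): ...
class FlashInferLoRABackend(BaseLoRABackend): ...

_BACKENDS = {
    "triton": TritonLoRABackend,
    "flashinfer": FlashInferLoRABackend,
}


def get_lora_backend(name):
    try:
        return _BACKENDS[name]
    except KeyError:
        raise ValueError(f"Unsupported LoRA backend: {name}") from None


print(get_lora_backend("triton").__name__)  # TritonLoRABackend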
|