sglang 0.2.12__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +7 -1
- sglang/bench_latency.py +3 -2
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/interpreter.py +4 -2
- sglang/lang/ir.py +13 -4
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/layers/activation.py +0 -1
- sglang/srt/layers/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/radix_attention.py +38 -14
- sglang/srt/managers/schedule_batch.py +9 -14
- sglang/srt/managers/tokenizer_manager.py +1 -1
- sglang/srt/managers/tp_worker.py +1 -7
- sglang/srt/model_executor/cuda_graph_runner.py +48 -17
- sglang/srt/model_executor/forward_batch_info.py +132 -58
- sglang/srt/model_executor/model_runner.py +61 -28
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/deepseek.py +2 -2
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +11 -5
- sglang/srt/models/grok.py +50 -396
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/mixtral.py +56 -254
- sglang/srt/models/mixtral_quant.py +1 -4
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_moe.py +2 -2
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +32 -21
- sglang/srt/sampling_params.py +0 -4
- sglang/srt/server.py +23 -15
- sglang/srt/server_args.py +7 -1
- sglang/srt/utils.py +1 -2
- sglang/test/runners.py +18 -10
- sglang/test/test_programs.py +32 -5
- sglang/test/test_utils.py +5 -1
- sglang/version.py +1 -1
- {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/METADATA +12 -4
- {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/RECORD +48 -48
- {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
- {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
sglang/srt/layers/radix_attention.py
CHANGED
@@ -15,6 +15,8 @@ limitations under the License.
 
 """Radix attention."""
 
+from typing import Optional
+
 import torch
 from flashinfer.cascade import merge_state
 from torch import nn
@@ -34,6 +36,7 @@ class RadixAttention(nn.Module):
         scaling: float,
         num_kv_heads: int,
         layer_id: int,
+        sliding_window_size: Optional[int] = None,
         logit_cap: int = -1,
         v_head_dim: int = -1,
     ):
@@ -46,6 +49,7 @@ class RadixAttention(nn.Module):
         self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
+        self.sliding_window_size = sliding_window_size if sliding_window_size else -1
 
         if (
             not global_server_args_dict.get("disable_flashinfer", False)
@@ -113,14 +117,25 @@ class RadixAttention(nn.Module):
         return o
 
     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
+        # using two wrappers is unnecessary in the current PR, but are prepared for future PRs
+        prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
+        if self.sliding_window_size != -1:
+            prefill_wrapper_paged = prefill_wrapper_paged[0]
+        else:
+            if isinstance(prefill_wrapper_paged, list):
+                prefill_wrapper_paged = prefill_wrapper_paged[1]
+
         if not input_metadata.flashinfer_use_ragged:
-            self.store_kv_cache(k, v, input_metadata)
+            if k is not None:
+                assert v is not None
+                self.store_kv_cache(k, v, input_metadata)
 
-            o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
+            o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                 input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
                 causal=True,
                 sm_scale=self.scaling,
+                window_left=self.sliding_window_size,
                 logits_soft_cap=self.logit_cap,
             )
         else:
@@ -138,14 +153,12 @@ class RadixAttention(nn.Module):
             if input_metadata.extend_no_prefix:
                 o = o1
             else:
-                o2, s2 = (
-                    input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
-                        q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                        input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
-                        causal=False,
-                        sm_scale=self.scaling,
-                        logits_soft_cap=self.logit_cap,
-                    )
+                o2, s2 = prefill_wrapper_paged.forward_return_lse(
+                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                    input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
+                    causal=False,
+                    sm_scale=self.scaling,
+                    logits_soft_cap=self.logit_cap,
                 )
 
                 o, _ = merge_state(o1, s1, o2, s2)
@@ -158,9 +171,18 @@ class RadixAttention(nn.Module):
         return o.view(-1, self.tp_q_head_num * self.head_dim)
 
     def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        self.store_kv_cache(k, v, input_metadata)
+        decode_wrapper = input_metadata.flashinfer_decode_wrapper
+        if self.sliding_window_size != -1:
+            decode_wrapper = decode_wrapper[0]
+        else:
+            if isinstance(decode_wrapper, list):
+                decode_wrapper = decode_wrapper[1]
+
+        if k is not None:
+            assert v is not None
+            self.store_kv_cache(k, v, input_metadata)
 
-        o = input_metadata.flashinfer_decode_wrapper.forward(
+        o = decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
             sm_scale=self.scaling,
@@ -170,8 +192,10 @@ class RadixAttention(nn.Module):
         return o.view(-1, self.tp_q_head_num * self.head_dim)
 
     def forward(self, q, k, v, input_metadata: InputMetadata):
-        k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
-        v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+        if k is not None:
+            assert v is not None
+            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
 
         if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)
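The wrapper-selection pattern repeated in extend_forward_flashinfer and decode_forward_flashinfer is the heart of the new sliding-window support: when the model reports a window, input_metadata carries a [window, full] pair of flashinfer wrappers, and each layer indexes the pair with its own sliding_window_size (stored as -1 when absent); window_left=self.sliding_window_size is then forwarded to the paged prefill call, where -1 leaves attention un-windowed in flashinfer. A minimal sketch of that selection rule, pulled out of the methods above purely for illustration (select_wrapper is not an sglang function):

def select_wrapper(wrappers, sliding_window_size):
    # wrappers: either a single flashinfer wrapper (model without window
    # attention) or a [window_wrapper, full_attention_wrapper] pair.
    if sliding_window_size != -1:
        return wrappers[0]
    if isinstance(wrappers, list):
        return wrappers[1]
    return wrappers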
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -235,10 +235,12 @@ class Req:
             return
 
         last_token_id = self.output_ids[-1]
-
-
-
-
+
+        matched_eos = last_token_id in self.sampling_params.stop_token_ids
+
+        if self.tokenizer is not None:
+            matched_eos |= last_token_id == self.tokenizer.eos_token_id
+
         if matched_eos and not self.sampling_params.ignore_eos:
             self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
             return
@@ -383,7 +385,7 @@ class ScheduleBatch:
 
         return out_cache_loc
 
-    def batch_sampling_params(self, vocab_size
+    def batch_sampling_params(self, vocab_size):
         device = "cuda"
         bs, reqs = self.batch_size(), self.reqs
         self.temperatures = torch.tensor(
@@ -419,15 +421,8 @@ class ScheduleBatch:
 
         # Handle logit bias but only allocate when needed
         self.logit_bias = None
-        for i in range(bs):
-            if reqs[i].sampling_params.dtype == "int":
-                if self.logit_bias is None:
-                    self.logit_bias = torch.zeros(
-                        (bs, vocab_size), dtype=torch.float32, device=device
-                    )
-                self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
 
-    def prepare_for_extend(self, vocab_size: int
+    def prepare_for_extend(self, vocab_size: int):
         bs = self.batch_size()
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
@@ -466,7 +461,7 @@ class ScheduleBatch:
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
 
-        self.batch_sampling_params(vocab_size
+        self.batch_sampling_params(vocab_size)
 
     def check_decode_mem(self):
         bs = self.batch_size()
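Besides dropping the int_token_logit_bias plumbing, schedule_batch.py changes the finish check: a request now finishes when the last token is in sampling_params.stop_token_ids or, if a tokenizer is attached, equals tokenizer.eos_token_id. A tiny illustrative version of that rule (not sglang code):

def reached_eos(last_token_id, stop_token_ids, eos_token_id=None):
    matched = last_token_id in stop_token_ids
    if eos_token_id is not None:
        matched |= last_token_id == eos_token_id
    return matched

print(reached_eos(7, {7}))                    # True: explicit stop token id
print(reached_eos(2, set(), eos_token_id=2))  # True: tokenizer EOS id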
sglang/srt/managers/tp_worker.py
CHANGED
@@ -54,7 +54,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
     suppress_other_loggers,
@@ -132,9 +131,6 @@ class ModelTpServer:
             ),
             self.model_runner.req_to_token_pool.size - 1,
         )
-        self.int_token_logit_bias = torch.tensor(
-            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
-        )
         self.max_req_input_len = min(
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,
@@ -442,9 +438,7 @@ class ModelTpServer:
 
     def forward_prefill_batch(self, batch: ScheduleBatch):
         # Build batch tensors
-        batch.prepare_for_extend(
-            self.model_config.vocab_size, self.int_token_logit_bias
-        )
+        batch.prepare_for_extend(self.model_config.vocab_size)
 
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -98,8 +98,8 @@ class CudaGraphRunner:
         self.req_pool_indices = torch.zeros(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
-        self.seq_lens = torch.
-        self.position_ids_offsets = torch.
+        self.seq_lens = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.position_ids_offsets = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
         self.out_cache_loc = torch.zeros(
@@ -107,9 +107,6 @@ class CudaGraphRunner:
         )
 
         # FlashInfer inputs
-        self.flashinfer_workspace_buffer = (
-            self.model_runner.flashinfer_workspace_buffers[0]
-        )
         self.flashinfer_kv_indptr = torch.zeros(
             (self.max_bs + 1,), dtype=torch.int32, device="cuda"
         )
@@ -121,6 +118,23 @@ class CudaGraphRunner:
         self.flashinfer_kv_last_page_len = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
+        if model_runner.sliding_window_size is None:
+            self.flashinfer_workspace_buffer = (
+                self.model_runner.flashinfer_workspace_buffer
+            )
+        else:
+            self.flashinfer_workspace_buffer = (
+                self.model_runner.flashinfer_workspace_buffer
+            )
+
+            self.flashinfer_kv_indptr = [
+                self.flashinfer_kv_indptr,
+                self.flashinfer_kv_indptr.clone(),
+            ]
+            self.flashinfer_kv_indices = [
+                self.flashinfer_kv_indices,
+                self.flashinfer_kv_indices.clone(),
+            ]
 
         self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
 
@@ -171,15 +185,32 @@ class CudaGraphRunner:
             use_tensor_cores = True
         else:
             use_tensor_cores = False
-        flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffer,
-            "NHD",
-            use_cuda_graph=True,
-            use_tensor_cores=use_tensor_cores,
-            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
-            paged_kv_indices_buffer=self.flashinfer_kv_indices,
-            paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
-        )
+        if self.model_runner.sliding_window_size is None:
+            flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                self.flashinfer_workspace_buffer,
+                "NHD",
+                use_cuda_graph=True,
+                use_tensor_cores=use_tensor_cores,
+                paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
+                paged_kv_indices_buffer=self.flashinfer_kv_indices,
+                paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
+            )
+        else:
+            flashinfer_decode_wrapper = []
+            for i in range(2):
+                flashinfer_decode_wrapper.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.flashinfer_workspace_buffer,
+                        "NHD",
+                        use_cuda_graph=True,
+                        use_tensor_cores=use_tensor_cores,
+                        paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1],
+                        paged_kv_indices_buffer=self.flashinfer_kv_indices[i],
+                        paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[
+                            :bs
+                        ],
+                    )
+                )
         update_flashinfer_indices(
             ForwardMode.DECODE,
             self.model_runner,
@@ -201,7 +232,7 @@ class CudaGraphRunner:
             out_cache_loc=out_cache_loc,
             return_logprob=False,
             top_logprobs_nums=0,
-            positions=(seq_lens - 1).to(torch.int64),
+            positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
        )
 
@@ -225,8 +256,8 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.batch_size_list, raw_bs)
         bs = self.batch_size_list[index]
         if bs != raw_bs:
-            self.seq_lens.
-            self.position_ids_offsets.
+            self.seq_lens.zero_()
+            self.position_ids_offsets.fill_(1)
             self.out_cache_loc.zero_()
 
         # Common inputs
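When a sliding window is present, the CUDA-graph path duplicates the paged-KV index buffers and captures a pair of BatchDecodeWithPagedKVCacheWrapper objects that share one workspace buffer, matching the wrapper pair RadixAttention indexes into; the captured positions also become seq_lens - 1 + position_ids_offsets so graph replays respect the offset. A rough, self-contained sketch of the paired construction; the buffer sizes below are placeholders, not the values sglang allocates:

import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper

max_bs = 32              # illustrative capture batch size
max_kv_tokens = 1 << 16  # illustrative paged-KV index capacity
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

# Two sets of index buffers: slot 0 feeds sliding-window layers, slot 1 full-attention layers.
kv_indptr = [torch.zeros((max_bs + 1,), dtype=torch.int32, device="cuda") for _ in range(2)]
kv_indices = [torch.zeros((max_kv_tokens,), dtype=torch.int32, device="cuda") for _ in range(2)]
kv_last_page_len = torch.ones((max_bs,), dtype=torch.int32, device="cuda")

decode_wrappers = [
    BatchDecodeWithPagedKVCacheWrapper(
        workspace,            # one shared workspace buffer
        "NHD",
        use_cuda_graph=True,  # index buffers must stay fixed across replays
        use_tensor_cores=False,
        paged_kv_indptr_buffer=kv_indptr[i],
        paged_kv_indices_buffer=kv_indices[i],
        paged_kv_last_page_len_buffer=kv_last_page_len,
    )
    for i in range(2)
]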
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -16,7 +16,7 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
 import numpy as np
 import torch
@@ -194,6 +194,7 @@ class InputMetadata:
            if (
                forward_mode != ForwardMode.DECODE
                and int(torch.sum(ret.seq_lens)) > 4096
+                and model_runner.sliding_window_size is None
            ):
                flashinfer_use_ragged = True
            ret.init_flashinfer_handlers(
@@ -216,7 +217,10 @@ class InputMetadata:
         self.triton_max_extend_len = int(torch.max(extend_seq_lens))
 
     def init_flashinfer_handlers(
-        self,
+        self,
+        model_runner,
+        prefix_lens,
+        flashinfer_use_ragged,
     ):
         update_flashinfer_indices(
             self.forward_mode,
@@ -255,65 +259,135 @@ def update_flashinfer_indices(
     head_dim = model_runner.model_config.head_dim
     batch_size = len(req_pool_indices)
 
-    if flashinfer_use_ragged:
-        paged_kernel_lens = prefix_lens
-    else:
-        paged_kernel_lens = seq_lens
-
-    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-    kv_indices = torch.cat(
-        [
-            model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-            ]
-            for i in range(batch_size)
-        ],
-        dim=0,
-    ).contiguous()
-    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-
-    if forward_mode == ForwardMode.DECODE:
-        # CUDA graph uses different flashinfer_decode_wrapper
-        if flashinfer_decode_wrapper is None:
-            flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
-
-        flashinfer_decode_wrapper.end_forward()
-        flashinfer_decode_wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            1,
-        )
-    else:
-        # extend part
-        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-
+    if model_runner.sliding_window_size is None:
         if flashinfer_use_ragged:
-            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-                qo_indptr,
-                qo_indptr,
+            paged_kernel_lens = prefix_lens
+        else:
+            paged_kernel_lens = seq_lens
+
+        kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+        kv_indices = torch.cat(
+            [
+                model_runner.req_to_token_pool.req_to_token[
+                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+                ]
+                for i in range(batch_size)
+            ],
+            dim=0,
+        ).contiguous()
+        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+
+        if forward_mode == ForwardMode.DECODE:
+            # CUDA graph uses different flashinfer_decode_wrapper
+            if flashinfer_decode_wrapper is None:
+                flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
+
+            flashinfer_decode_wrapper.end_forward()
+            flashinfer_decode_wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
                 num_qo_heads,
                 num_kv_heads,
                 head_dim,
+                1,
             )
+        else:
+            # extend part
+            qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+            qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+            if flashinfer_use_ragged:
+                model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+                model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+                    qo_indptr,
+                    qo_indptr,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim,
+                )
 
-        # cached part
-        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
-        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
-            qo_indptr,
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            1,
-        )
+            # cached part
+            model_runner.flashinfer_prefill_wrapper_paged.end_forward()
+            model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1,
+            )
+    else:
+        # window attention use paged only
+        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                if forward_mode == ForwardMode.DECODE:
+                    paged_kernel_lens = torch.minimum(
+                        seq_lens, torch.tensor(model_runner.sliding_window_size + 1)
+                    )
+                else:
+                    paged_kernel_lens = torch.minimum(
+                        seq_lens,
+                        torch.tensor(model_runner.sliding_window_size)
+                        + seq_lens
+                        - prefix_lens,
+                    )
+            else:
+                paged_kernel_lens = seq_lens
+
+            kv_start_idx = seq_lens - paged_kernel_lens
+
+            kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+            kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+            req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+            paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+            kv_indices = torch.cat(
+                [
+                    model_runner.req_to_token_pool.req_to_token[
+                        req_pool_indices_cpu[i],
+                        kv_start_idx[i] : kv_start_idx[i] + paged_kernel_lens_cpu[i],
+                    ]
+                    for i in range(batch_size)
+                ],
+                dim=0,
+            ).contiguous()
+
+            if forward_mode == ForwardMode.DECODE:
+                # CUDA graph uses different flashinfer_decode_wrapper
+                if flashinfer_decode_wrapper is None:
+                    flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
+
+                flashinfer_decode_wrapper[wrapper_id].end_forward()
+                flashinfer_decode_wrapper[wrapper_id].begin_forward(
+                    kv_indptr,
+                    kv_indices,
+                    kv_last_page_len,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim,
+                    1,
+                )
+            else:
+                # extend part
+                qo_indptr = torch.zeros(
+                    (batch_size + 1,), dtype=torch.int32, device="cuda"
+                )
+                qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].end_forward()
+                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].begin_forward(
+                    qo_indptr,
+                    kv_indptr,
+                    kv_indices,
+                    kv_last_page_len,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim,
+                    1,
+                )
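The new else-branch above decides how much KV history each wrapper set sees: wrapper 0 (used by sliding-window layers) gets at most sliding_window_size + 1 entries during decode, or the window plus the tokens being extended during prefill, with kv_start_idx skipping the older slots; wrapper 1 always gets the full sequence. A small worked example of that arithmetic with made-up lengths (plain torch, independent of sglang):

import torch

sliding_window_size = 4
seq_lens = torch.tensor([3, 10])    # total tokens per request after this step
prefix_lens = torch.tensor([0, 6])  # tokens already cached before this extend

# Decode: keep at most window + 1 of the most recent KV entries per request.
decode_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size + 1))

# Extend: keep the window plus everything newly appended in this batch.
extend_lens = torch.minimum(
    seq_lens, torch.tensor(sliding_window_size) + seq_lens - prefix_lens
)

kv_start_idx = seq_lens - extend_lens  # first KV slot handed to the kernel

print(decode_lens)   # tensor([3, 5])
print(extend_lens)   # tensor([3, 8])
print(kv_start_idx)  # tensor([0, 2])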
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -38,6 +38,7 @@ from vllm.distributed import (
     init_distributed_environment,
     initialize_model_parallel,
 )
+from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
@@ -53,7 +54,7 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
     is_generation_model,
-
+    is_llama3_405b_fp8_head_16,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
@@ -158,7 +159,7 @@ class ModelRunner:
             skip_tokenizer_init=True,
         )
 
-        if
+        if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
             # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
             self.model_config.hf_config.num_key_value_heads = 8
             vllm_model_config.hf_config.num_key_value_heads = 8
@@ -168,15 +169,6 @@ class ModelRunner:
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)
 
-        if (
-            self.server_args.efficient_weight_load
-            and "llama" in self.server_args.model_path.lower()
-            and self.server_args.quantization == "fp8"
-        ):
-            from sglang.srt.model_loader.model_loader import get_model
-        else:
-            from vllm.model_executor.model_loader import get_model
-
         self.model = get_model(
             model_config=vllm_model_config,
             device_config=device_config,
@@ -187,6 +179,11 @@ class ModelRunner:
             scheduler_config=None,
             cache_config=None,
         )
+        self.sliding_window_size = (
+            self.model.get_window_size()
+            if hasattr(self.model, "get_window_size")
+            else None
+        )
         self.is_generation = is_generation_model(
             self.model_config.hf_config.architectures
         )
@@ -296,6 +293,9 @@ class ModelRunner:
 
     def init_flashinfer(self):
         if self.server_args.disable_flashinfer:
+            assert (
+                self.sliding_window_size is None
+            ), "turn on flashinfer to support window attention"
             self.flashinfer_prefill_wrapper_ragged = None
             self.flashinfer_prefill_wrapper_paged = None
             self.flashinfer_decode_wrapper = None
@@ -309,20 +309,47 @@ class ModelRunner:
         else:
             use_tensor_cores = False
 
-        self.
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if self.sliding_window_size is None:
+            self.flashinfer_workspace_buffer = torch.empty(
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device="cuda",
+            )
+            self.flashinfer_prefill_wrapper_ragged = (
+                BatchPrefillWithRaggedKVCacheWrapper(
+                    self.flashinfer_workspace_buffer, "NHD"
+                )
+            )
+            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+                self.flashinfer_workspace_buffer, "NHD"
+            )
+            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                self.flashinfer_workspace_buffer,
+                "NHD",
+                use_tensor_cores=use_tensor_cores,
+            )
+        else:
+            self.flashinfer_workspace_buffer = torch.empty(
+                global_config.flashinfer_workspace_size,
+                dtype=torch.uint8,
+                device="cuda",
+            )
+            self.flashinfer_prefill_wrapper_ragged = None
+            self.flashinfer_prefill_wrapper_paged = []
+            self.flashinfer_decode_wrapper = []
+            for i in range(2):
+                self.flashinfer_prefill_wrapper_paged.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.flashinfer_workspace_buffer, "NHD"
+                    )
+                )
+                self.flashinfer_decode_wrapper.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.flashinfer_workspace_buffer,
+                        "NHD",
+                        use_tensor_cores=use_tensor_cores,
+                    )
+                )
 
     def init_cuda_graphs(self):
         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
@@ -358,7 +385,9 @@ class ModelRunner:
             return self.cuda_graph_runner.replay(batch)
 
         input_metadata = InputMetadata.from_schedule_batch(
-            self,
+            self,
+            batch,
+            ForwardMode.DECODE,
         )
 
         return self.model.forward(
@@ -368,7 +397,9 @@ class ModelRunner:
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.from_schedule_batch(
-            self,
+            self,
+            batch,
+            forward_mode=ForwardMode.EXTEND,
         )
         return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
@@ -377,7 +408,9 @@ class ModelRunner:
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.from_schedule_batch(
-            self,
+            self,
+            batch,
+            forward_mode=ForwardMode.EXTEND,
         )
         return self.model.forward(
             batch.input_ids,