sglang-0.2.11-py3-none-any.whl → sglang-0.2.13-py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry and is provided for informational purposes only.
- sglang/api.py +7 -1
- sglang/bench_latency.py +9 -6
- sglang/bench_serving.py +46 -22
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +4 -2
- sglang/lang/ir.py +16 -7
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/layers/activation.py +32 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +9 -2
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +7 -2
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +40 -16
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +115 -97
- sglang/srt/managers/tokenizer_manager.py +194 -112
- sglang/srt/managers/tp_worker.py +290 -359
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +71 -25
- sglang/srt/model_executor/forward_batch_info.py +293 -156
- sglang/srt/model_executor/model_runner.py +77 -57
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/deepseek.py +2 -2
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +11 -6
- sglang/srt/models/grok.py +50 -396
- sglang/srt/models/internlm2.py +2 -7
- sglang/srt/models/llama2.py +4 -4
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/mixtral.py +56 -254
- sglang/srt/models/mixtral_quant.py +1 -4
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_moe.py +2 -13
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +187 -48
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -8
- sglang/srt/server.py +91 -29
- sglang/srt/server_args.py +32 -19
- sglang/srt/utils.py +32 -15
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +81 -73
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +36 -7
- sglang/test/test_utils.py +24 -2
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/METADATA +33 -16
- sglang-0.2.13.dist-info/RECORD +112 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.11.dist-info/RECORD +0 -102
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
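Judging from the file list, the headline changes in 0.2.13 are the new fused_moe layer package, the penaltylib sampling penalizers, embedding support (pooler, llama_embedding), new layernorm/activation kernels, and the removal of the bundled linear, quantization, and model_loader modules. The hunks rendered below appear to come from sglang/srt/model_executor/forward_batch_info.py (+293 -156 above). To move an existing environment onto the new wheel, an install along these lines should work (a sketch; the [all] extra mirrors the project's usual install instructions, adjust to your setup):

pip install "sglang[all]==0.2.13"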
@@ -16,13 +16,17 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import List
+from typing import TYPE_CHECKING, List, Optional

 import numpy as np
 import torch

+from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool

+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+

 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
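The new TYPE_CHECKING guard lets this module annotate against ModelRunner without importing it at runtime; since model_runner imports this module, a plain import here would presumably be circular. A minimal sketch of the pattern (the describe helper is hypothetical, only the import shape is taken from the hunk above):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime,
    # so importing ModelRunner here cannot create an import cycle.
    from sglang.srt.model_executor.model_runner import ModelRunner


def describe(runner: "ModelRunner") -> str:
    # Hypothetical helper: the string annotation keeps runtime imports clean.
    return f"tp_size={runner.tp_size}"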
@@ -39,25 +43,33 @@ class InputMetadata:

     forward_mode: ForwardMode
     batch_size: int
-    total_num_tokens: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
-    positions: torch.Tensor
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: BaseTokenToKVPool

-    # For extend
-    extend_seq_lens: torch.Tensor
-    extend_start_loc: torch.Tensor
-    extend_no_prefix: bool
-
     # Output location of the KV cache
-    out_cache_loc: torch.Tensor
+    out_cache_loc: torch.Tensor
+
+    total_num_tokens: int = None
+
+    # Position information
+    positions: torch.Tensor = None
+
+    # For extend
+    extend_seq_lens: torch.Tensor = None
+    extend_start_loc: torch.Tensor = None
+    extend_no_prefix: bool = None

     # Output options
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None

+    # For multimodal
+    pixel_values: List[torch.Tensor] = None
+    image_sizes: List[List[int]] = None
+    image_offsets: List[int] = None
+
     # Trition attention backend
     triton_max_seq_len: int = 0
     triton_max_extend_len: int = 0
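The extend_* fields now default to None and are filled per batch by compute_extend_infos in the next hunk. A toy run of that bookkeeping with made-up request lengths, to show what the tensors end up holding:

import torch

# Hypothetical batch: cached prefix length and total fill length per request.
prefix_lens = [4, 0, 2]
fill_lens = [10, 6, 5]

extend_seq_lens = torch.tensor([f - p for f, p in zip(fill_lens, prefix_lens)])
extend_start_loc = torch.zeros_like(extend_seq_lens)
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
extend_no_prefix = all(p == 0 for p in prefix_lens)

print(extend_seq_lens)   # tensor([6, 6, 3])   new tokens per request
print(extend_start_loc)  # tensor([0, 6, 12])  offsets into the packed extend batch
print(extend_no_prefix)  # False               some requests reuse a cached prefix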
@@ -70,107 +82,175 @@ class InputMetadata:
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
     flashinfer_use_ragged: bool = False

-
-
-
-
-
-
-
-
-
-        out_cache_loc,
-        top_logprobs_nums=None,
-        return_logprob=False,
-        skip_flashinfer_init=False,
-    ):
-        flashinfer_use_ragged = False
-        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
-            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
-                flashinfer_use_ragged = True
-            init_flashinfer_args(
-                forward_mode,
-                model_runner,
-                req_pool_indices,
-                seq_lens,
-                prefix_lens,
-                model_runner.flashinfer_decode_wrapper,
-                flashinfer_use_ragged,
+    def init_multimuldal_info(self, batch: ScheduleBatch):
+        reqs = batch.reqs
+        self.pixel_values = [r.pixel_values for r in reqs]
+        self.image_sizes = [r.image_size for r in reqs]
+        self.image_offsets = [
+            (
+                (r.image_offset - len(r.prefix_indices))
+                if r.image_offset is not None
+                else 0
             )
+            for r in reqs
+        ]

-
+    def compute_positions(self, batch: ScheduleBatch):
+        position_ids_offsets = batch.position_ids_offsets

-        if forward_mode == ForwardMode.DECODE:
-
-
-            if not model_runner.server_args.disable_flashinfer:
-                # This variable is not needed in this case,
-                # we do not compute it to make it compatbile with cuda graph.
-                total_num_tokens = None
+        if self.forward_mode == ForwardMode.DECODE:
+            if True:
+                self.positions = self.seq_lens - 1
             else:
-
+                # Deprecated
+                self.positions = (self.seq_lens - 1) + position_ids_offsets
         else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-            )
-
-
-
-
-
-
-
+            if True:
+                self.positions = torch.tensor(
+                    np.concatenate(
+                        [
+                            np.arange(len(req.prefix_indices), len(req.fill_ids))
+                            for req in batch.reqs
+                        ],
+                        axis=0,
+                    ),
+                    device="cuda",
+                )
+            else:
+                # Deprecated
+                position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+                self.positions = torch.tensor(
+                    np.concatenate(
+                        [
+                            np.arange(
+                                len(req.prefix_indices) + position_ids_offsets_cpu[i],
+                                len(req.fill_ids) + position_ids_offsets_cpu[i],
+                            )
+                            for i, req in enumerate(batch.reqs)
+                        ],
+                        axis=0,
+                    ),
+                    device="cuda",
+                )
+
+        # Positions should be in long type
+        self.positions = self.positions.to(torch.int64)
+
+    def compute_extend_infos(self, batch: ScheduleBatch):
+        if self.forward_mode == ForwardMode.DECODE:
+            self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
+        else:
+            extend_lens_cpu = [
+                len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs
+            ]
+            self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
+            self.extend_start_loc = torch.zeros_like(self.seq_lens)
+            self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
+            self.extend_no_prefix = all(len(r.prefix_indices) == 0 for r in batch.reqs)

+    @classmethod
+    def from_schedule_batch(
+        cls,
+        model_runner: "ModelRunner",
+        batch: ScheduleBatch,
+        forward_mode: ForwardMode,
+    ):
         ret = cls(
             forward_mode=forward_mode,
-            batch_size=batch_size,
-
-
-            seq_lens=seq_lens,
-            positions=positions,
+            batch_size=batch.batch_size(),
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
-            out_cache_loc=out_cache_loc,
-
-
-            extend_no_prefix=extend_no_prefix,
-            return_logprob=return_logprob,
-            top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
-            flashinfer_use_ragged=flashinfer_use_ragged,
+            out_cache_loc=batch.out_cache_loc,
+            return_logprob=batch.return_logprob,
+            top_logprobs_nums=batch.top_logprobs_nums,
         )

+        ret.compute_positions(batch)
+
+        ret.compute_extend_infos(batch)
+
+        if (
+            forward_mode != ForwardMode.DECODE
+            or model_runner.server_args.disable_flashinfer
+        ):
+            ret.total_num_tokens = int(torch.sum(ret.seq_lens))
+
+        if forward_mode != ForwardMode.DECODE:
+            ret.init_multimuldal_info(batch)
+
+        prefix_lens = None
+        if forward_mode != ForwardMode.DECODE:
+            prefix_lens = torch.tensor(
+                [len(r.prefix_indices) for r in batch.reqs], device="cuda"
+            )
+
         if model_runner.server_args.disable_flashinfer:
-            (
-
-
-
-
-
+            ret.init_triton_args(batch, prefix_lens)
+
+        flashinfer_use_ragged = False
+        if not model_runner.server_args.disable_flashinfer:
+            if (
+                forward_mode != ForwardMode.DECODE
+                and int(torch.sum(ret.seq_lens)) > 4096
+                and model_runner.sliding_window_size is None
+            ):
+                flashinfer_use_ragged = True
+            ret.init_flashinfer_handlers(
+                model_runner, prefix_lens, flashinfer_use_ragged
+            )

         return ret

+    def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
+        """Init auxiliary variables for triton attention backend."""
+        self.triton_max_seq_len = int(torch.max(self.seq_lens))
+        self.triton_prefix_lens = prefix_lens
+        self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
+        self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
+
+        if self.forward_mode == ForwardMode.DECODE:
+            self.triton_max_extend_len = None
+        else:
+            extend_seq_lens = self.seq_lens - prefix_lens
+            self.triton_max_extend_len = int(torch.max(extend_seq_lens))
+
+    def init_flashinfer_handlers(
+        self,
+        model_runner,
+        prefix_lens,
+        flashinfer_use_ragged,
+    ):
+        update_flashinfer_indices(
+            self.forward_mode,
+            model_runner,
+            self.req_pool_indices,
+            self.seq_lens,
+            prefix_lens,
+            flashinfer_use_ragged=flashinfer_use_ragged,
+        )

-
+        (
+            self.flashinfer_prefill_wrapper_ragged,
+            self.flashinfer_prefill_wrapper_paged,
+            self.flashinfer_decode_wrapper,
+            self.flashinfer_use_ragged,
+        ) = (
+            model_runner.flashinfer_prefill_wrapper_ragged,
+            model_runner.flashinfer_prefill_wrapper_paged,
+            model_runner.flashinfer_decode_wrapper,
+            flashinfer_use_ragged,
+        )
+
+
+def update_flashinfer_indices(
     forward_mode,
     model_runner,
     req_pool_indices,
     seq_lens,
     prefix_lens,
-    flashinfer_decode_wrapper,
+    flashinfer_decode_wrapper=None,
     flashinfer_use_ragged=False,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
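The net effect of this hunk: instead of a classmethod that takes the unpacked tensors (its old signature sits in the removed rows above: out_cache_loc, top_logprobs_nums=None, return_logprob=False, skip_flashinfer_init=False, ...), InputMetadata is now built from the ScheduleBatch itself, and the helpers above derive positions, extend info, and the attention-backend state. A rough sketch of what a call site in the forward path might look like after the change (the surrounding caller code is not part of this diff):

# batch is a ScheduleBatch, model_runner a ModelRunner
input_metadata = InputMetadata.from_schedule_batch(
    model_runner,
    batch,
    ForwardMode.EXTEND,  # or ForwardMode.DECODE
)
# hypothetical downstream use of the metadata
out = model.forward(input_ids, input_metadata.positions, input_metadata)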
@@ -178,79 +258,136 @@ def init_flashinfer_args(
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim
     batch_size = len(req_pool_indices)
-    total_num_tokens = int(torch.sum(seq_lens))
-
-    if flashinfer_use_ragged:
-        paged_kernel_lens = prefix_lens
-    else:
-        paged_kernel_lens = seq_lens
-
-    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-    kv_indices = torch.cat(
-        [
-            model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-            ]
-            for i in range(batch_size)
-        ],
-        dim=0,
-    ).contiguous()
-    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-
-    if forward_mode == ForwardMode.DECODE:
-        flashinfer_decode_wrapper.end_forward()
-        flashinfer_decode_wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            1,
-        )
-    else:
-        # extend part
-        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

+    if model_runner.sliding_window_size is None:
         if flashinfer_use_ragged:
-
-
-
+            paged_kernel_lens = prefix_lens
+        else:
+            paged_kernel_lens = seq_lens
+
+        kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+        kv_indices = torch.cat(
+            [
+                model_runner.req_to_token_pool.req_to_token[
+                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+                ]
+                for i in range(batch_size)
+            ],
+            dim=0,
+        ).contiguous()
+        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+
+        if forward_mode == ForwardMode.DECODE:
+            # CUDA graph uses different flashinfer_decode_wrapper
+            if flashinfer_decode_wrapper is None:
+                flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
+
+            flashinfer_decode_wrapper.end_forward()
+            flashinfer_decode_wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1,
+            )
+        else:
+            # extend part
+            qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+            qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+            if flashinfer_use_ragged:
+                model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+                model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+                    qo_indptr,
+                    qo_indptr,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim,
+                )
+
+            # cached part
+            model_runner.flashinfer_prefill_wrapper_paged.end_forward()
+            model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
                 qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
                 num_qo_heads,
                 num_kv_heads,
                 head_dim,
+                1,
             )
-
-        # cached part
-        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
-        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
-            qo_indptr,
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            1,
-        )
-
-
-def init_triton_args(forward_mode, seq_lens, prefix_lens):
-    """Init auxiliary variables for triton attention backend."""
-    batch_size = len(seq_lens)
-    max_seq_len = int(torch.max(seq_lens))
-    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-
-    if forward_mode == ForwardMode.DECODE:
-        max_extend_len = None
     else:
-
-
-
-
+        # window attention use paged only
+        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                if forward_mode == ForwardMode.DECODE:
+                    paged_kernel_lens = torch.minimum(
+                        seq_lens, torch.tensor(model_runner.sliding_window_size + 1)
+                    )
+                else:
+                    paged_kernel_lens = torch.minimum(
+                        seq_lens,
+                        torch.tensor(model_runner.sliding_window_size)
+                        + seq_lens
+                        - prefix_lens,
+                    )
+            else:
+                paged_kernel_lens = seq_lens
+
+            kv_start_idx = seq_lens - paged_kernel_lens
+
+            kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+            kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+            req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+            paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+            kv_indices = torch.cat(
+                [
+                    model_runner.req_to_token_pool.req_to_token[
+                        req_pool_indices_cpu[i],
+                        kv_start_idx[i] : kv_start_idx[i] + paged_kernel_lens_cpu[i],
+                    ]
+                    for i in range(batch_size)
+                ],
+                dim=0,
+            ).contiguous()
+
+            if forward_mode == ForwardMode.DECODE:
+                # CUDA graph uses different flashinfer_decode_wrapper
+                if flashinfer_decode_wrapper is None:
+                    flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
+
+                flashinfer_decode_wrapper[wrapper_id].end_forward()
+                flashinfer_decode_wrapper[wrapper_id].begin_forward(
+                    kv_indptr,
+                    kv_indices,
+                    kv_last_page_len,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim,
+                    1,
+                )
+            else:
+                # extend part
+                qo_indptr = torch.zeros(
+                    (batch_size + 1,), dtype=torch.int32, device="cuda"
+                )
+                qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].end_forward()
+                model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].begin_forward(
+                    qo_indptr,
+                    kv_indptr,
+                    kv_indices,
+                    kv_last_page_len,
+                    num_qo_heads,
+                    num_kv_heads,
+                    head_dim,
+                    1,
+                )