sglang 0.3.4__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_latency.py +2 -1
- sglang/lang/chat_template.py +17 -0
- sglang/launch_server_llavavid.py +1 -1
- sglang/srt/configs/__init__.py +3 -0
- sglang/srt/configs/model_config.py +27 -2
- sglang/srt/configs/qwen2vl.py +133 -0
- sglang/srt/constrained/fsm_cache.py +10 -3
- sglang/srt/conversation.py +27 -0
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/__init__.py +16 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -6
- sglang/srt/layers/attention/flashinfer_backend.py +174 -54
- sglang/srt/layers/attention/triton_backend.py +22 -6
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +26 -4
- sglang/srt/layers/linear.py +89 -63
- sglang/srt/layers/logits_processor.py +5 -5
- sglang/srt/layers/rotary_embedding.py +112 -0
- sglang/srt/layers/sampler.py +51 -39
- sglang/srt/lora/lora.py +3 -1
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +4 -0
- sglang/srt/managers/image_processor.py +186 -13
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/schedule_batch.py +238 -68
- sglang/srt/managers/scheduler.py +69 -50
- sglang/srt/managers/tokenizer_manager.py +24 -4
- sglang/srt/managers/tp_worker.py +26 -111
- sglang/srt/managers/tp_worker_overlap_thread.py +209 -0
- sglang/srt/mem_cache/memory_pool.py +56 -10
- sglang/srt/mem_cache/radix_cache.py +4 -3
- sglang/srt/model_executor/cuda_graph_runner.py +87 -28
- sglang/srt/model_executor/forward_batch_info.py +83 -3
- sglang/srt/model_executor/model_runner.py +32 -11
- sglang/srt/models/chatglm.py +3 -3
- sglang/srt/models/deepseek_v2.py +2 -2
- sglang/srt/models/mllama.py +1004 -0
- sglang/srt/models/qwen2_vl.py +724 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
- sglang/srt/sampling/sampling_batch_info.py +13 -3
- sglang/srt/sampling/sampling_params.py +5 -7
- sglang/srt/server.py +12 -0
- sglang/srt/server_args.py +10 -0
- sglang/srt/utils.py +22 -0
- sglang/test/run_eval.py +2 -0
- sglang/test/runners.py +20 -1
- sglang/test/srt/sampling/penaltylib/utils.py +1 -0
- sglang/test/test_utils.py +100 -3
- sglang/version.py +1 -1
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +17 -18
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +53 -48
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0

sglang/srt/layers/attention/flashinfer_backend.py

@@ -11,7 +11,6 @@ from enum import Enum, auto
 from typing import TYPE_CHECKING

 import torch
-import torch.nn as nn
 import triton
 import triton.language as tl

@@ -21,6 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import is_flashinfer_available

 if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

 if is_flashinfer_available():
@@ -56,13 +56,13 @@ class FlashInferAttnBackend(AttentionBackend):

         assert not (
             model_runner.sliding_window_size is not None
-            and model_runner.
+            and model_runner.model_config.is_encoder_decoder
         ), "Sliding window and cross attention are not supported together"

         if model_runner.sliding_window_size is not None:
             self.num_wrappers = 2
             self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
-        elif model_runner.
+        elif model_runner.model_config.is_encoder_decoder:
             self.num_wrappers = 2
             self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
         else:
@@ -127,6 +127,9 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                decode_wrappers=None,
+                encoder_lens=forward_batch.encoder_lens,
             )
             self.forward_metadata = (self.decode_wrappers,)
         else:
@@ -134,10 +137,7 @@ class FlashInferAttnBackend(AttentionBackend):

             # Some heuristics to check whether to use ragged forward
             use_ragged = False
-            if (
-                torch.sum(forward_batch.seq_lens).item() >= 4096
-                and self.num_wrappers == 1
-            ):
+            if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
                 use_ragged = True

             extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()
@@ -146,13 +146,11 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 prefix_lens,
-                use_ragged,
+                use_ragged=use_ragged,
+                encoder_lens=forward_batch.encoder_lens,
             )

-            self.forward_metadata = (
-                use_ragged,
-                extend_no_prefix,
-            )
+            self.forward_metadata = (use_ragged, extend_no_prefix)

     def init_cuda_graph_state(self, max_bs: int):
         cuda_graph_kv_indices = torch.zeros(
@@ -165,7 +163,11 @@ class FlashInferAttnBackend(AttentionBackend):
         ]

     def init_forward_metadata_capture_cuda_graph(
-        self,
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: torch.Tensor = None,
     ):
         decode_wrappers = []
         for i in range(self.num_wrappers):
@@ -181,37 +183,59 @@ class FlashInferAttnBackend(AttentionBackend):
                 )
             )

-
+        seq_lens_sum = seq_lens.sum().item()
+        self.indices_updater_decode.update(
+            req_pool_indices,
+            seq_lens,
+            seq_lens_sum,
+            decode_wrappers=decode_wrappers,
+            encoder_lens=encoder_lens,
+        )
         self.cuda_graph_metadata[bs] = decode_wrappers
         self.forward_metadata = (decode_wrappers,)

     def init_forward_metadata_replay_cuda_graph(
-        self,
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: torch.Tensor = None,
     ):
         self.indices_updater_decode.update(
-            req_pool_indices[:bs],
+            req_pool_indices[:bs],
+            seq_lens[:bs],
+            seq_lens_sum,
+            decode_wrappers=self.cuda_graph_metadata[bs],
+            encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
         )

     def get_cuda_graph_seq_len_fill_value(self):
         return 0

-    def forward_extend(
+    def forward_extend(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         prefill_wrapper_paged = self.prefill_wrappers_paged[
             self._get_wrapper_idx(layer)
         ]

         use_ragged, extend_no_prefix = self.forward_metadata
+        cache_loc = (
+            forward_batch.out_cache_loc
+            if not layer.is_cross_attention
+            else forward_batch.encoder_out_cache_loc
+        )

         if not use_ragged:
             if k is not None:
                 assert v is not None
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer.layer_id, forward_batch.out_cache_loc, k, v
-                )
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+
             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                 forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                causal=
+                causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=layer.logit_cap,
@@ -239,20 +263,23 @@ class FlashInferAttnBackend(AttentionBackend):

                 o, _ = merge_state(o1, s1, o2, s2)

-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer.layer_id, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

-    def forward_decode(
+    def forward_decode(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
+        cache_loc = (
+            forward_batch.out_cache_loc
+            if not layer.is_cross_attention
+            else forward_batch.encoder_out_cache_loc
+        )

         if k is not None:
             assert v is not None
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer.layer_id, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -263,7 +290,7 @@ class FlashInferAttnBackend(AttentionBackend):

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

-    def _get_wrapper_idx(self, layer:
+    def _get_wrapper_idx(self, layer: RadixAttention):
         if self.num_wrappers == 1:
             return 0

@@ -290,6 +317,8 @@ class FlashInferIndicesUpdaterDecode:
         self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
         self.sliding_window_size = model_runner.sliding_window_size

+        self.attn_backend = attn_backend
+
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
         self.kv_last_page_len = attn_backend.kv_last_page_len
@@ -297,55 +326,117 @@ class FlashInferIndicesUpdaterDecode:
         self.decode_wrappers = attn_backend.decode_wrappers

         # Dispatch
-        if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+        if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
             self.update = self.update_sliding_window
-        elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+        elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
             self.update = self.update_cross_attention
         else:
-            assert attn_backend.num_wrappers == 1
+            assert self.attn_backend.num_wrappers == 1
             self.update = self.update_single_wrapper

-    def
+    def update(
+        self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
+    ):
+        # Keep the signature for type checking. It will be assigned during runtime.
+        raise NotImplementedError()
+
+    def update_single_wrapper(
+        self,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        decode_wrappers=None,
+        encoder_lens=None,
+    ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(
-            decode_wrappers[0],
+            decode_wrappers[0],
+            req_pool_indices,
+            seq_lens,
+            seq_lens_sum,
+            self.kv_indptr[0],
+            None,
         )

-    def update_sliding_window(
+    def update_sliding_window(
+        self,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        decode_wrappers=None,
+        encoder_lens=None,
+    ):
         decode_wrappers = decode_wrappers or self.decode_wrappers

         for wrapper_id in range(2):
             if wrapper_id == 0:
                 # Sliding window attention
-
+                paged_kernel_lens_tmp = torch.minimum(  # TODO: replace this with clamp
                     seq_lens,
                     torch.tensor(self.sliding_window_size + 1),
                 )
+                paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item()
+                kv_start_idx_tmp = seq_lens - paged_kernel_lens_tmp
             else:
                 # Full attention
-
+                paged_kernel_lens_tmp = seq_lens
+                paged_kernel_lens_sum_tmp = seq_lens_sum
+                kv_start_idx_tmp = None

-
+            self.call_begin_forward(
+                decode_wrappers[wrapper_id],
+                req_pool_indices,
+                paged_kernel_lens_tmp,
+                paged_kernel_lens_sum_tmp,
+                self.kv_indptr[wrapper_id],
+                kv_start_idx_tmp,
+            )
+
+    def update_cross_attention(
+        self,
+        req_pool_indices,
+        seq_lens,
+        seq_lens_sum,
+        decode_wrappers=None,
+        encoder_lens=None,
+    ):
+        decode_wrappers = decode_wrappers or self.decode_wrappers
+
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                # Normal attention
+                paged_kernel_lens = seq_lens
+                kv_start_idx = encoder_lens
+            else:
+                # Cross attention
+                paged_kernel_lens = encoder_lens
+                kv_start_idx = torch.zeros_like(encoder_lens)
+                seq_lens_sum = encoder_lens.sum().item()

             self.call_begin_forward(
                 decode_wrappers[wrapper_id],
                 req_pool_indices,
                 paged_kernel_lens,
+                seq_lens_sum,
                 self.kv_indptr[wrapper_id],
                 kv_start_idx,
             )

-    def update_cross_attention(self):
-        raise NotImplementedError()
-
     def call_begin_forward(
-        self,
+        self,
+        wrapper,
+        req_pool_indices,
+        paged_kernel_lens,
+        paged_kernel_lens_sum,
+        kv_indptr,
+        kv_start_idx,
     ):
         bs = len(req_pool_indices)
+        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indptr = kv_indptr[: bs + 1]
-
-
-
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+        )

         create_flashinfer_kv_indices_triton[(bs,)](
             self.req_to_token,
@@ -386,6 +477,8 @@ class FlashInferIndicesUpdaterPrefill:
         self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
         self.sliding_window_size = model_runner.sliding_window_size

+        self.attn_backend = attn_backend
+
         # Buffers and wrappers
         self.kv_indptr = attn_backend.kv_indptr
         self.kv_last_page_len = attn_backend.kv_last_page_len
@@ -395,16 +488,20 @@ class FlashInferIndicesUpdaterPrefill:
         self.wrappers_paged = attn_backend.prefill_wrappers_paged

         # Dispatch
-        if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+        if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
             self.update = self.update_sliding_window
-        elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+        elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
             self.update = self.update_cross_attention
         else:
-            assert attn_backend.num_wrappers == 1
+            assert self.attn_backend.num_wrappers == 1
             self.update = self.update_single_wrapper

+    def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
+        # Keep the signature for type checking. It will be assigned during runtime.
+        raise NotImplementedError()
+
     def update_single_wrapper(
-        self, req_pool_indices, seq_lens, prefix_lens, use_ragged
+        self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens
@@ -425,7 +522,7 @@ class FlashInferIndicesUpdaterPrefill:
         )

     def update_sliding_window(
-        self, req_pool_indices, seq_lens, prefix_lens, use_ragged
+        self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -452,8 +549,31 @@ class FlashInferIndicesUpdaterPrefill:
                 use_ragged,
             )

-    def update_cross_attention(
-
+    def update_cross_attention(
+        self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
+    ):
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                # normal attention
+                paged_kernel_lens = seq_lens
+                kv_start_idx = encoder_lens
+            else:
+                # cross attention
+                paged_kernel_lens = encoder_lens
+                kv_start_idx = torch.zeros_like(encoder_lens)
+
+            self.call_begin_forward(
+                self.wrapper_ragged,
+                self.wrappers_paged[wrapper_id],
+                req_pool_indices,
+                paged_kernel_lens,
+                seq_lens,
+                prefix_lens,
+                kv_start_idx,
+                self.kv_indptr[wrapper_id],
+                self.qo_indptr[wrapper_id],
+                use_ragged,
+            )

     def call_begin_forward(
         self,
@@ -469,8 +589,8 @@ class FlashInferIndicesUpdaterPrefill:
         use_ragged,
     ):
         bs = len(req_pool_indices)
+        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indptr = kv_indptr[: bs + 1]
-        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
         create_flashinfer_kv_indices_triton[(bs,)](
             self.req_to_token,
@@ -482,8 +602,8 @@ class FlashInferIndicesUpdaterPrefill:
             self.max_context_len,
         )

+        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
         qo_indptr = qo_indptr[: bs + 1]
-        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

         # extend part
         if use_ragged:
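
The decode-side changes above route cross-attention requests through a second wrapper and build its KV index pointers with the same cumulative-sum pattern as `call_begin_forward`. The sketch below (plain PyTorch, not sglang code; the helper name `build_decode_kv_indptr` and the example lengths are made up) mirrors that split: wrapper 0 covers the decoder tokens and starts reading KV after the encoder prefix, wrapper 1 covers the encoder tokens from position 0.

```python
# Hypothetical, self-contained sketch of the two-wrapper KV bookkeeping shown
# in update_cross_attention above. Only standard PyTorch is used.
import torch


def build_decode_kv_indptr(seq_lens: torch.Tensor, encoder_lens: torch.Tensor):
    """Return (kv_indptr, kv_start_idx) for wrapper 0 (self attn) and 1 (cross attn)."""
    bs = seq_lens.numel()
    plans = []
    for wrapper_id in range(2):
        if wrapper_id == 0:
            # Normal attention over the decoder tokens; their KV pages start
            # right after the encoder prefix of each request.
            paged_kernel_lens = seq_lens
            kv_start_idx = encoder_lens
        else:
            # Cross attention over the encoder tokens, starting at position 0.
            paged_kernel_lens = encoder_lens
            kv_start_idx = torch.zeros_like(encoder_lens)

        # Same cumsum pattern as call_begin_forward: kv_indptr[i] is where
        # request i's pages begin in the flattened kv_indices buffer.
        kv_indptr = torch.zeros(bs + 1, dtype=torch.int64)
        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
        plans.append((kv_indptr, kv_start_idx))
    return plans


if __name__ == "__main__":
    seq_lens = torch.tensor([3, 5])       # decoder tokens per request
    encoder_lens = torch.tensor([10, 7])  # encoder tokens per request
    for wrapper_id, (kv_indptr, kv_start_idx) in enumerate(
        build_decode_kv_indptr(seq_lens, encoder_lens)
    ):
        print(wrapper_id, kv_indptr.tolist(), kv_start_idx.tolist())
```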

sglang/srt/layers/attention/triton_backend.py

@@ -10,6 +10,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner


@@ -81,8 +82,13 @@ class TritonAttnBackend(AttentionBackend):
         )

     def init_forward_metadata_capture_cuda_graph(
-        self,
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens=None,
     ):
+        # NOTE: encoder_lens expected to be zeros or None
         self.forward_metadata = (
             self.cuda_graph_start_loc,
             self.cuda_graph_attn_logits,
@@ -91,15 +97,23 @@ class TritonAttnBackend(AttentionBackend):
         )

     def init_forward_metadata_replay_cuda_graph(
-        self,
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens=None,
     ):
+        # NOTE: encoder_lens expected to be zeros or None
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

     def get_cuda_graph_seq_len_fill_value(self):
         return 1

-    def forward_extend(
+    def forward_extend(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
@@ -107,7 +121,7 @@ class TritonAttnBackend(AttentionBackend):
             o = torch.empty_like(q)

         forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer
+            layer, forward_batch.out_cache_loc, k, v
         )

         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
@@ -129,7 +143,9 @@ class TritonAttnBackend(AttentionBackend):
         )
         return o

-    def forward_decode(
+    def forward_decode(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
         q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
@@ -143,7 +159,7 @@ class TritonAttnBackend(AttentionBackend):
         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata

         forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer
+            layer, forward_batch.out_cache_loc, k, v
         )

         self.decode_attention_fwd(
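
For the CUDA-graph replay hook kept above, `cuda_graph_start_loc` is just an exclusive prefix sum of the padded sequence lengths. A minimal standalone sketch (plain PyTorch; `max_bs`, `bs`, and the lengths are assumed example values, not taken from sglang):

```python
import torch

max_bs = 8                                          # capture-time batch size (assumed)
cuda_graph_start_loc = torch.zeros(max_bs, dtype=torch.int64)

bs = 3                                              # batch size at replay time
seq_lens = torch.tensor([4, 2, 5, 0, 0, 0, 0, 0])   # padded to max_bs

# Same update as init_forward_metadata_replay_cuda_graph: request i's tokens
# start at the sum of the lengths of requests 0..i-1.
cuda_graph_start_loc.zero_()
cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

print(cuda_graph_start_loc[:bs].tolist())           # [0, 4, 6]
```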

sglang/srt/layers/attention/triton_ops/prefill_attention.py

@@ -50,6 +50,7 @@ def _fwd_kernel(
     BLOCK_M: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    IS_CAUSAL: tl.constexpr,
     Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
@@ -78,7 +79,9 @@ def _fwd_kernel(
     mask_d = offs_d < Lk

     q = tl.load(
-        Q + off_q,
+        Q + off_q,
+        mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]),
+        other=0.0,
     )

     k_ptrs = K + off_k
@@ -91,7 +94,12 @@ def _fwd_kernel(

     block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)

-
+    end_n = (
+        cur_batch_seq_len
+        if not IS_CAUSAL
+        else tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)
+    )
+    for start_n in range(0, block_mask * end_n, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
         k = tl.load(
@@ -104,7 +112,18 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
         qk *= sm_scale
-
+
+        if IS_CAUSAL:
+            qk += tl.where(
+                (start_n + offs_n[None, :] < cur_batch_seq_len)
+                & (offs_m[:, None] >= (start_n + offs_n[None, :])),
+                0,
+                float("-inf"),
+            )
+        else:
+            qk += tl.where(
+                (start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf")
+            )

         # -- compute m_ij, p, l_ij
         m_ij = tl.max(qk, 1)
@@ -146,7 +165,9 @@ def _fwd_kernel(
     )


-def context_attention_fwd(
+def context_attention_fwd(
+    q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
+):
     if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
         BLOCK = 128
     else:
@@ -181,6 +202,7 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
         BLOCK_M=BLOCK,
         BLOCK_DMODEL=triton.next_power_of_2(Lk),
         BLOCK_N=BLOCK,
+        IS_CAUSAL=is_causal,
         num_warps=num_warps,
         num_stages=1,
         Lk=Lk,
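
The `IS_CAUSAL` flag added above restricts each query block to keys at or before its own position. As a reference, the sketch below reproduces that masking rule for a single head of a single sequence in plain PyTorch (it is not the Triton kernel; the function name `causal_softmax_scores` is made up, and the kernel additionally masks padded key positions that a dense reference does not need).

```python
import torch


def causal_softmax_scores(q: torch.Tensor, k: torch.Tensor, is_causal: bool = True):
    # q, k: [seq_len, head_dim] for one head of one sequence.
    seq_len = q.shape[0]
    scores = (q @ k.T) * (q.shape[-1] ** -0.5)
    if is_causal:
        # Same condition as the IS_CAUSAL branch: query position m may only
        # attend to key positions n <= m.
        n = torch.arange(seq_len)
        scores = scores.masked_fill(n[None, :] > n[:, None], float("-inf"))
    return torch.softmax(scores, dim=-1)


probs = causal_softmax_scores(torch.randn(4, 8), torch.randn(4, 8))
print(probs)  # entries above the diagonal are exactly zero
```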