sglang 0.4.5.post2__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
- sglang/bench_serving.py +3 -2
- sglang/compile_deep_gemm.py +136 -0
- sglang/lang/backend/openai.py +5 -1
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/srt/configs/model_config.py +4 -1
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/disaggregation/decode.py +43 -0
- sglang/srt/disaggregation/mini_lb.py +69 -8
- sglang/srt/disaggregation/mooncake/conn.py +1 -1
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +100 -16
- sglang/srt/disaggregation/utils.py +17 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +3 -7
- sglang/srt/function_call_parser.py +60 -0
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +781 -150
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +19 -4
- sglang/srt/layers/moe/ep_moe/layer.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/deep_gemm.py +378 -0
- sglang/srt/layers/quantization/fp8_kernel.py +7 -38
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/w8a8_int8.py +3 -3
- sglang/srt/layers/rotary_embedding.py +6 -6
- sglang/srt/layers/sampler.py +2 -2
- sglang/srt/managers/data_parallel_controller.py +7 -1
- sglang/srt/managers/io_struct.py +14 -3
- sglang/srt/managers/schedule_batch.py +13 -0
- sglang/srt/managers/scheduler.py +16 -6
- sglang/srt/managers/tokenizer_manager.py +115 -29
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +40 -32
- sglang/srt/mem_cache/memory_pool.py +31 -13
- sglang/srt/model_executor/cuda_graph_runner.py +13 -8
- sglang/srt/model_executor/model_runner.py +19 -4
- sglang/srt/models/deepseek_v2.py +9 -6
- sglang/srt/models/minicpm3.py +2 -2
- sglang/srt/models/minicpmo.py +17 -6
- sglang/srt/openai_api/adapter.py +71 -4
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/server_args.py +52 -40
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +2 -7
- sglang/srt/utils.py +46 -5
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +3 -3
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +62 -57
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashattention_backend.py

@@ -16,6 +16,7 @@ if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+from sgl_kernel import merge_state_v2
 from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
 
 
@@ -30,7 +31,7 @@ class FlashAttentionMetadata:
     # Sequence lengths for the forward batch
     cache_seqlens_int32: torch.Tensor = None
     # Maximum sequence length for query
-    max_seq_len_q: int = 0
+    max_seq_len_q: int = 1
     # Maximum sequence length for key
     max_seq_len_k: int = 0
     # Cumulative sequence lengths for query
@@ -267,6 +268,12 @@ def cdiv(a: int, b: int) -> int:
     return -(a // -b)
 
 
+# TODO(hebiao064): remove this once we have a better way to handle the merge_state_v2 torch.compile issue
+@torch._dynamo.disable()
+def merge_state_v2_wrapper(o, s_a, o_exp, s_b):
+    return merge_state_v2(o, s_a, o_exp, s_b)
+
+
 class FlashAttentionBackend(AttentionBackend):
     """FlashAttention backend implementation.
 
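The wrapper above hides the sgl_kernel merge_state_v2 CUDA op from torch.compile; the same op is what the cascade-attention paths later in this diff use to combine a "prefix" attention pass with an "expand" pass. As background only, here is a minimal PyTorch sketch of what merging two partial attention results amounts to, assuming merge_state_v2 follows the standard log-sum-exp rescaling (the real kernel's exact signature and memory layout live in sgl_kernel and may differ):

    import torch

    def merge_partial_attention(o_a, lse_a, o_b, lse_b):
        # o_a, o_b: (tokens, heads, head_dim) partial attention outputs
        # lse_a, lse_b: (tokens, heads) log-sum-exp of the attention logits of each pass
        lse = torch.logaddexp(lse_a, lse_b)          # combined softmax normalizer
        w_a = torch.exp(lse_a - lse).unsqueeze(-1)   # weight of pass A
        w_b = torch.exp(lse_b - lse).unsqueeze(-1)   # weight of pass B
        return o_a * w_a + o_b * w_b, lse            # merged output and merged LSE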
@@ -301,6 +308,8 @@ class FlashAttentionBackend(AttentionBackend):
         ), "Sliding window and cross attention are not supported together"
 
         self.forward_metadata: FlashAttentionMetadata = None
+        # extra metdata for handling speculative decoding topk > 1, extended draft decode and verify
+        self.forward_metadata_spec_decode_expand: FlashAttentionMetadata = None
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
         self.decode_cuda_graph_metadata = {}
@@ -311,8 +320,7 @@ class FlashAttentionBackend(AttentionBackend):
         self.page_size = model_runner.page_size
         self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         self.skip_prefill = skip_prefill
-
-        self.topk = topk
+        self.topk = model_runner.server_args.speculative_eagle_topk or 0
         self.speculative_num_steps = speculative_num_steps
         self.speculative_num_draft_tokens = (
             model_runner.server_args.speculative_num_draft_tokens
@@ -336,14 +344,107 @@ class FlashAttentionBackend(AttentionBackend):
         if forward_batch.forward_mode.is_decode_or_idle():
             # Draft Decode
             if forward_batch.spec_info is not None:
+                if self.topk <= 1:
+                    metadata.cache_seqlens_int32 = (
+                        seqlens_in_batch + (self.speculative_step_id + 1)
+                    ).to(torch.int32)
+                    metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                        self.speculative_step_id + 1
+                    )
+                    metadata.cu_seqlens_q = torch.arange(
+                        0, batch_size + 1, dtype=torch.int32, device=device
+                    )
+                    metadata.cu_seqlens_k = torch.nn.functional.pad(
+                        torch.cumsum(
+                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                        ),
+                        (1, 0),
+                    )
+                    metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                        forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                    ]
+                else:
+                    metadata.cache_seqlens_int32 = (seqlens_in_batch).to(torch.int32)
+                    metadata.max_seq_len_q = self.topk
+                    metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                    metadata.cu_seqlens_q = torch.arange(
+                        0,
+                        batch_size * self.topk + 1,
+                        step=self.topk,
+                        dtype=torch.int32,
+                        device=device,
+                    )
+                    metadata.cu_seqlens_k = torch.nn.functional.pad(
+                        torch.cumsum(
+                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                        ),
+                        (1, 0),
+                    )
+                    metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                        forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                    ]
+
+                    metadata_expand = FlashAttentionMetadata()
+                    decode_length = self.speculative_step_id + 1
+                    metadata_expand.cache_seqlens_int32 = torch.full(
+                        (seqlens_in_batch.numel() * self.topk,),
+                        decode_length,
+                        device=device,
+                        dtype=torch.int32,
+                    )
+                    metadata_expand.max_seq_len_q = 1
+                    metadata_expand.max_seq_len_k = self.speculative_step_id + 1
+                    metadata_expand.cu_seqlens_q = torch.arange(
+                        0,
+                        metadata_expand.cache_seqlens_int32.numel() + 1,
+                        dtype=torch.int32,
+                        device=device,
+                    )
+                    metadata_expand.cu_seqlens_k = torch.arange(
+                        0,
+                        metadata_expand.cache_seqlens_int32.numel() * decode_length + 1,
+                        step=decode_length,
+                        dtype=torch.int32,
+                        device=device,
+                    )
+                    cache_loc = forward_batch.out_cache_loc.view(
+                        self.speculative_num_steps, -1
+                    ).T.contiguous()
+                    metadata_expand.page_table = (
+                        cache_loc[:, :decode_length].contiguous().to(torch.int32)
+                    )
+                    self.forward_metadata_spec_decode_expand = metadata_expand
+            else:
+                # Normal Decode
+                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+                )
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+            # TODO: we need to test this part for llama 4 eagle case
+            self._init_local_attn_metadata(metadata, device)
+        elif forward_batch.forward_mode.is_target_verify():
+            if self.topk <= 1:
                 metadata.cache_seqlens_int32 = (
-                    seqlens_in_batch + (self.speculative_step_id + 1)
+                    forward_batch.seq_lens + self.speculative_num_draft_tokens
                 ).to(torch.int32)
-                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
-                    self.speculative_step_id + 1
+                metadata.max_seq_len_q = self.speculative_num_draft_tokens
+                metadata.max_seq_len_k = (
+                    forward_batch.seq_lens_cpu.max().item()
+                    + self.speculative_num_draft_tokens
                 )
                 metadata.cu_seqlens_q = torch.arange(
-                    0, batch_size + 1, dtype=torch.int32, device=device
+                    0,
+                    batch_size * self.speculative_num_draft_tokens + 1,
+                    self.speculative_num_draft_tokens,
+                    dtype=torch.int32,
+                    device=device,
                 )
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
                     torch.cumsum(
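For readers not steeped in varlen-attention metadata: the cu_seqlens_q tensors built above are simply strided prefix sums over per-request query counts (1 for normal decode, topk for the expanded draft-decode pass, speculative_num_draft_tokens for verify). A tiny illustrative snippet with made-up sizes:

    import torch

    batch_size, topk = 4, 2  # hypothetical values
    cu_seqlens_q = torch.arange(0, batch_size * topk + 1, step=topk, dtype=torch.int32)
    # -> tensor([0, 2, 4, 6, 8], dtype=torch.int32): request i owns query rows [2*i, 2*i + 2)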
@@ -357,44 +458,101 @@ class FlashAttentionBackend(AttentionBackend):
 
                 self._init_local_attn_metadata(metadata, device)
             else:
-                # Normal Decode
-                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
+                metadata.max_seq_len_q = self.speculative_num_draft_tokens
                 metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
                 metadata.cu_seqlens_q = torch.arange(
-                    0, batch_size + 1, dtype=torch.int32, device=device
+                    0,
+                    batch_size * self.speculative_num_draft_tokens + 1,
+                    step=self.speculative_num_draft_tokens,
+                    dtype=torch.int32,
+                    device=device,
                 )
                 metadata.cu_seqlens_k = torch.nn.functional.pad(
-                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
                 )
                 metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                     forward_batch.req_pool_indices, : metadata.max_seq_len_k
                 ]
 
-                self._init_local_attn_metadata(metadata, device)
-        elif forward_batch.forward_mode.is_target_verify():
-            metadata.cache_seqlens_int32 = (
-                forward_batch.seq_lens + self.speculative_num_draft_tokens
-            ).to(torch.int32)
-            metadata.max_seq_len_q = self.speculative_num_draft_tokens
-            metadata.max_seq_len_k = (
-                forward_batch.seq_lens_cpu.max().item()
-                + self.speculative_num_draft_tokens
-            )
-            metadata.cu_seqlens_q = torch.arange(
-                0,
-                batch_size * self.speculative_num_draft_tokens + 1,
-                self.speculative_num_draft_tokens,
-                dtype=torch.int32,
-                device=device,
-            )
-            metadata.cu_seqlens_k = torch.nn.functional.pad(
-                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
-                (1, 0),
-            )
-            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
-                forward_batch.req_pool_indices, : metadata.max_seq_len_k
-            ]
+                metadata_expand = FlashAttentionMetadata()
 
+                metadata_expand.max_seq_len_q = 1
+                metadata_expand.cu_seqlens_q = torch.arange(
+                    0,
+                    forward_batch.seq_lens.numel() * self.speculative_num_draft_tokens
+                    + 1,
+                    dtype=torch.int32,
+                    device=device,
+                )
+
+                # create expand page table
+                offsets = torch.arange(
+                    self.speculative_num_draft_tokens, device=device
+                ).unsqueeze(
+                    0
+                )  # shape: (1, self.speculative_num_draft_tokens)
+                cols = offsets.expand(
+                    forward_batch.seq_lens.numel(), -1
+                ) + forward_batch.seq_lens.unsqueeze(1)
+                cum_len = torch.nn.functional.pad(
+                    torch.cumsum(
+                        (
+                            forward_batch.seq_lens + self.speculative_num_draft_tokens
+                        ).repeat_interleave(self.speculative_num_draft_tokens),
+                        dim=0,
+                    ),
+                    (1, 0),
+                )[:-1]
+                mask_extraction_indices = (
+                    cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+                    + cum_len[:, None]
+                ).view(1, -1)
+                mask = forward_batch.spec_info.custom_mask[
+                    mask_extraction_indices
+                ].view(
+                    -1, self.speculative_num_draft_tokens
+                )  # (bsz * draft_num, draft_num)
+
+                # shift table indices to avoid padding
+                # non_masked_page_table [[8, 9, 10],   mask (display with int format) [[1, 0, 0],
+                #                        [8, 9, 10],                                   [1, 1, 0],
+                #                        [8, 9, 10]]                                   [1, 0, 1]]
+                # if masked with padding [[8, 0, 0],   our mask without padding [[8, 9, 10],
+                #                         [8, 9, 0],                             [8, 9, 10],
+                #                         [8, 0, 10]]                            [8, 10, 9]]
+                # note here cache_seqlens_int32 is [1, 2, 2] so extra page indices will be ignored in each row
+                col_indices = offsets.expand(
+                    mask.shape[0], self.speculative_num_draft_tokens
+                )
+                # Build keys: if an entry is valid (mask==True), keep its original index;
+                # if not, add self.speculative_num_draft_tokens so that it sorts after all valid entries.
+                keys = torch.where(
+                    mask, col_indices, col_indices + self.speculative_num_draft_tokens
+                )
+                _, sort_order = torch.sort(keys, dim=1)
+                non_masked_page_table = (
+                    forward_batch.req_to_token_pool.req_to_token[
+                        forward_batch.req_pool_indices, :
+                    ]
+                    .gather(1, cols)
+                    .repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+                )  # (bsz, draft_num)
+                metadata_expand.page_table = non_masked_page_table.gather(1, sort_order)
+                metadata_expand.cache_seqlens_int32 = mask.sum(dim=1).to(torch.int32)
+                metadata_expand.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata_expand.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata_expand.max_seq_len_k = (
+                    metadata_expand.cache_seqlens_int32.max().item()
+                )
+                self.forward_metadata_spec_decode_expand = metadata_expand
         elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
             metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
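The block above encodes the verify-step custom mask by reordering each row of the page table so that the allowed draft slots come first, then truncating per row via cache_seqlens. The sort trick can be reproduced on the toy example from the code comment; the values below are the ones in that comment, not real KV-cache indices:

    import torch

    # toy illustration of the mask-driven page-table compaction used above
    page_table = torch.tensor([[8, 9, 10], [8, 9, 10], [8, 9, 10]])
    mask = torch.tensor([[1, 0, 0], [1, 1, 0], [1, 0, 1]], dtype=torch.bool)
    draft_num = mask.shape[1]

    col_indices = torch.arange(draft_num).expand_as(mask)
    # invalid entries get a large key so they sort behind every valid entry in their row
    keys = torch.where(mask, col_indices, col_indices + draft_num)
    _, sort_order = torch.sort(keys, dim=1)
    compacted = page_table.gather(1, sort_order)
    # compacted -> [[8, 9, 10], [8, 9, 10], [8, 10, 9]]
    # mask.sum(dim=1) -> [1, 2, 2] gives per-row cache_seqlens, so trailing entries are ignored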
@@ -514,6 +672,11 @@ class FlashAttentionBackend(AttentionBackend):
             and (hasattr(layer, "use_irope") and layer.use_irope)
         )
 
+        # We do cascade attention for Target Verify with topk > 1
+        use_cascade_attn = (
+            forward_batch.forward_mode.is_target_verify() and self.topk > 1
+        )
+
         # Get the appropriate page table based on whether we're using local attention
         if use_local_attn:
             local_metadata = metadata.local_attn_metadata
@@ -548,7 +711,7 @@ class FlashAttentionBackend(AttentionBackend):
                 cu_seqlens_k = metadata.encoder_cu_seqlens_k
                 window_size = (-1, -1)
 
-            o = flash_attn_with_kvcache(
+            result = flash_attn_with_kvcache(
                 q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                 k_cache=key_cache,
                 v_cache=value_cache,
@@ -558,13 +721,41 @@ class FlashAttentionBackend(AttentionBackend):
                 cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
                 max_seqlen_q=max_seqlen_q,
                 softmax_scale=layer.scaling,
-                causal=causal,
+                causal=False if use_cascade_attn else causal,
                 window_size=window_size,
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                return_softmax_lse=use_cascade_attn,
             )
-
+
+            if use_cascade_attn:
+                o, softmax_lse, *rest = result
+                o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
+                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k_cache=key_cache,
+                    v_cache=value_cache,
+                    page_table=self.forward_metadata_spec_decode_expand.page_table,
+                    cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
+                    cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
+                    cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
+                    max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
+                    softmax_scale=layer.scaling,
+                    causal=False,
+                    window_size=window_size,
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                    return_softmax_lse=True,
+                )
+                o, _ = merge_state_v2_wrapper(
+                    o,
+                    softmax_lse.T.contiguous(),
+                    o_expand,
+                    softmax_lse_expand.T.contiguous(),
+                )
+            else:
+                o = result
         else:
             if (
                 not global_server_args_dict["disable_chunked_prefix_cache"]
@@ -627,7 +818,8 @@ class FlashAttentionBackend(AttentionBackend):
                 q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
                 q_nope = q_all[:, :, : layer.v_head_dim]
                 q_rope = q_all[:, :, layer.v_head_dim :]
-                o = flash_attn_with_kvcache(
+
+                result = flash_attn_with_kvcache(
                     q=q_rope,
                     k_cache=k_rope_cache,
                     v_cache=c_kv_cache,
@@ -638,13 +830,44 @@ class FlashAttentionBackend(AttentionBackend):
                     cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
                     max_seqlen_q=max_seqlen_q,
                     softmax_scale=layer.scaling,
-                    causal=
+                    causal=False if use_cascade_attn else causal,
                     softcap=layer.logit_cap,
                     k_descale=k_descale,
                     v_descale=v_descale,
+                    return_softmax_lse=use_cascade_attn,
                 )
+                if use_cascade_attn:
+                    o, softmax_lse, *rest = result
+                    o_expand, softmax_lse_expand, *rest_expand = (
+                        flash_attn_with_kvcache(
+                            q=q_rope,
+                            k_cache=k_rope_cache,
+                            v_cache=c_kv_cache,
+                            qv=q_nope,
+                            page_table=self.forward_metadata_spec_decode_expand.page_table,
+                            cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
+                            cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
+                            cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
+                            max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
+                            softmax_scale=layer.scaling,
+                            causal=False,
+                            window_size=window_size,
+                            softcap=layer.logit_cap,
+                            k_descale=k_descale,
+                            v_descale=v_descale,
+                            return_softmax_lse=True,
+                        )
+                    )
+                    o, _ = merge_state_v2_wrapper(
+                        o,
+                        softmax_lse.T.contiguous(),
+                        o_expand,
+                        softmax_lse_expand.T.contiguous(),
+                    )
+                else:
+                    o = result
 
-
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
     def forward_decode(
         self,
@@ -681,6 +904,8 @@ class FlashAttentionBackend(AttentionBackend):
         use_local_attention = (
             self.attention_chunk_size is not None and local_attn_metadata is not None
         )
+        # We do cascade attention for Draft Decode with topk > 1
+        use_cascade_attn = self.topk > 1
 
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
@@ -752,23 +977,61 @@ class FlashAttentionBackend(AttentionBackend):
                     v_descale=v_descale,
                 )
             else:
+                page_table = metadata.page_table
+                cache_seqlens = metadata.cache_seqlens_int32
+                cu_seqlens_k = metadata.cu_seqlens_k
+                max_seqlen_q = metadata.max_seq_len_q
+                q_reshaped = q.contiguous().view(
+                    -1, layer.tp_q_head_num, layer.head_dim
+                )
+
                 # Default: single-token self-attention
-                o = flash_attn_with_kvcache(
-                    q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                result = flash_attn_with_kvcache(
+                    q=q_reshaped,
                     k_cache=key_cache,
                     v_cache=value_cache,
-                    page_table=metadata.page_table,
-                    cache_seqlens=metadata.cache_seqlens_int32,
+                    page_table=page_table,
+                    cache_seqlens=cache_seqlens,
                     cu_seqlens_q=metadata.cu_seqlens_q,
-                    cu_seqlens_k_new=metadata.cu_seqlens_k,
-                    max_seqlen_q=
+                    cu_seqlens_k_new=cu_seqlens_k,
+                    max_seqlen_q=max_seqlen_q,
                     softmax_scale=layer.scaling,
-                    causal=
+                    causal=False if use_cascade_attn else causal,
                     window_size=window_size,
                     softcap=layer.logit_cap,
                     k_descale=k_descale,
                     v_descale=v_descale,
+                    return_softmax_lse=use_cascade_attn,
                 )
+                if use_cascade_attn:
+                    o, softmax_lse, *rest = result
+                    o_expand, softmax_lse_expand, *rest_expand = (
+                        flash_attn_with_kvcache(
+                            q=q_reshaped,
+                            k_cache=key_cache,
+                            v_cache=value_cache,
+                            page_table=self.forward_metadata_spec_decode_expand.page_table,
+                            cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
+                            cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
+                            cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
+                            max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
+                            softmax_scale=layer.scaling,
+                            causal=False,
+                            window_size=window_size,
+                            softcap=layer.logit_cap,
+                            k_descale=k_descale,
+                            v_descale=v_descale,
+                            return_softmax_lse=True,
+                        )
+                    )
+                    o, _ = merge_state_v2(
+                        o,
+                        softmax_lse.T.contiguous(),
+                        o_expand,
+                        softmax_lse_expand.T.contiguous(),
+                    )
+                else:
+                    o = result
         else:
             # Do absorbed multi-latent attention
             kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
@@ -787,8 +1050,9 @@ class FlashAttentionBackend(AttentionBackend):
             q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
             q_nope = q_all[:, :, : layer.v_head_dim]
             q_rope = q_all[:, :, layer.v_head_dim :]
+            max_seqlen_q = metadata.max_seq_len_q
 
-            o = flash_attn_with_kvcache(
+            result = flash_attn_with_kvcache(
                 q=q_rope,
                 k_cache=k_rope_cache,
                 v_cache=c_kv_cache,
@@ -797,13 +1061,43 @@ class FlashAttentionBackend(AttentionBackend):
                 cache_seqlens=metadata.cache_seqlens_int32,
                 cu_seqlens_q=metadata.cu_seqlens_q,
                 cu_seqlens_k_new=metadata.cu_seqlens_k,
-                max_seqlen_q=
+                max_seqlen_q=max_seqlen_q,
                 softmax_scale=layer.scaling,
-                causal=
+                causal=False if use_cascade_attn else causal,
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                return_softmax_lse=use_cascade_attn,  # softmax_lse is needed for merge states
             )
+            if use_cascade_attn:
+                o, softmax_lse, *rest = result
+                o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
+                    q=q_rope,
+                    k_cache=k_rope_cache,
+                    v_cache=c_kv_cache,
+                    qv=q_nope,
+                    page_table=self.forward_metadata_spec_decode_expand.page_table,
+                    cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
+                    cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
+                    cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
+                    max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
+                    softmax_scale=layer.scaling,
+                    causal=False,
+                    window_size=window_size,
+                    softcap=layer.logit_cap,
+                    k_descale=k_descale,
+                    v_descale=v_descale,
+                    return_softmax_lse=True,
+                )
+                o, _ = merge_state_v2(
+                    o,
+                    softmax_lse.T.contiguous(),
+                    o_expand,
+                    softmax_lse_expand.T.contiguous(),
+                )
+            else:
+                o = result
+
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
     def init_cuda_graph_state(self, max_bs: int):
@@ -815,6 +1109,8 @@ class FlashAttentionBackend(AttentionBackend):
         This creates fixed-size tensors that will be reused during CUDA graph replay
         to avoid memory allocations.
         """
+
+        # This is being used by normal decode and draft decode when topk == 1
         self.decode_cuda_graph_metadata = {
             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
             "cu_seqlens_q": torch.arange(
@@ -840,24 +1136,136 @@ class FlashAttentionBackend(AttentionBackend):
             ),
         }
 
-        [18 lines removed here: the previous target_verify_metadata initialization; its original text is not preserved in this diff view]
+        # This is used by draft decode's first half of metadata when topk > 1
+        if self.topk > 1:
+            self.draft_decode_metadata_topk_normal = {
+                "cache_seqlens": torch.zeros(
+                    max_bs, dtype=torch.int32, device=self.device
+                ),
+                "cu_seqlens_q": torch.arange(
+                    0,
+                    max_bs * self.topk + 1,
+                    step=self.topk,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.zeros(
+                    max_bs + 1, dtype=torch.int32, device=self.device
+                ),
+                "page_table": torch.zeros(
+                    max_bs,
+                    self.max_context_len,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+            }
+
+            # This is used by draft decode's second half of metadata when topk > 1
+            decode_length = self.speculative_step_id + 1
+            self.draft_decode_metadata_topk_expand = {
+                "cache_seqlens": torch.full(
+                    (max_bs * self.topk,),
+                    decode_length,
+                    device=self.device,
+                    dtype=torch.int32,
+                ),
+                "cu_seqlens_q": torch.arange(
+                    0,
+                    max_bs * self.topk + 1,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.arange(
+                    0,
+                    max_bs * self.topk * decode_length + 1,
+                    step=decode_length,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "page_table": torch.zeros(
+                    max_bs * self.topk,
+                    decode_length,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+            }
+
+        if (
+            self.speculative_num_draft_tokens is not None
+            and self.speculative_num_draft_tokens > 0
+        ):
+            self.target_verify_metadata = {
+                "cache_seqlens": torch.zeros(
+                    max_bs, dtype=torch.int32, device=self.device
+                ),
+                "cu_seqlens_q": torch.arange(
+                    0,
+                    max_bs * self.speculative_num_draft_tokens + 1,
+                    step=self.speculative_num_draft_tokens,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.zeros(
+                    max_bs + 1, dtype=torch.int32, device=self.device
+                ),
+                "page_table": torch.zeros(
+                    max_bs,
+                    (self.max_context_len + self.page_size - 1) // self.page_size,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "strided_indices": torch.arange(
+                    0, self.max_context_len, self.page_size, device=self.device
+                ),
+            }
+
+            if self.topk > 1:
+                self.target_verify_metadata_topk_normal = {
+                    "cache_seqlens": torch.zeros(
+                        max_bs, dtype=torch.int32, device=self.device
+                    ),
+                    "cu_seqlens_q": torch.arange(
+                        0,
+                        max_bs * self.speculative_num_draft_tokens + 1,
+                        step=self.speculative_num_draft_tokens,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                    "cu_seqlens_k": torch.zeros(
+                        max_bs + 1, dtype=torch.int32, device=self.device
+                    ),
+                    "page_table": torch.zeros(
+                        max_bs,
+                        self.max_context_len,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                }
+
+                self.target_verify_metadata_topk_expand = {
+                    "cache_seqlens": torch.zeros(
+                        max_bs * self.speculative_num_draft_tokens,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                    "cu_seqlens_k": torch.zeros(
+                        max_bs * self.speculative_num_draft_tokens + 1,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                    "cu_seqlens_q": torch.arange(
+                        0,
+                        max_bs * self.speculative_num_draft_tokens + 1,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                    "page_table": torch.zeros(
+                        max_bs * self.speculative_num_draft_tokens,
+                        self.speculative_num_draft_tokens,
+                        dtype=torch.int32,
+                        device=self.device,
+                    ),
+                }
 
         self.encoder_metadata = {
             "encoder_page_table": torch.zeros(
@@ -886,28 +1294,78 @@ class FlashAttentionBackend(AttentionBackend):
     ):
         """Initialize forward metadata for capturing CUDA graph."""
         metadata = FlashAttentionMetadata()
+
+        # metadata_expand is needed for Spec Decoding when top k > 1
+        metadata_expand = FlashAttentionMetadata()
+
         device = seq_lens.device
         if forward_mode.is_decode_or_idle():
             if spec_info is not None:
                 # Draft Decode
-                metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
-                    "cache_seqlens"
-                ][:bs]
-                metadata.max_seq_len_k = seq_lens.max().item() + (
-                    self.speculative_step_id + 1
-                )
-                metadata.cu_seqlens_q = self.decode_cuda_graph_metadata[
-                    "cu_seqlens_q"
-                ][: bs + 1]
-                metadata.cu_seqlens_k = torch.nn.functional.pad(
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    ),
-                    (1, 0),
-                )
-                metadata.page_table = self.decode_cuda_graph_metadata[
-                    "page_table_draft_decode"
-                ][req_pool_indices, :]
+                if self.topk <= 1:
+                    # When topk = 1, we use the normal decode metadata
+                    metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
+                        "cache_seqlens"
+                    ][:bs]
+                    metadata.max_seq_len_k = seq_lens.max().item() + (
+                        self.speculative_step_id + 1
+                    )
+                    metadata.cu_seqlens_q = self.decode_cuda_graph_metadata[
+                        "cu_seqlens_q"
+                    ][: bs + 1]
+                    metadata.cu_seqlens_k = torch.nn.functional.pad(
+                        torch.cumsum(
+                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                        ),
+                        (1, 0),
+                    )
+                    metadata.page_table = self.decode_cuda_graph_metadata[
+                        "page_table_draft_decode"
+                    ][req_pool_indices, :]
+                    self.decode_cuda_graph_metadata[bs] = metadata
+                else:
+                    # When top k > 1, we need two specific draft decode metadata, and then merge states
+                    # 1. The first half of metadata for prefix tokens
+                    metadata.cache_seqlens_int32 = (
+                        self.draft_decode_metadata_topk_normal["cache_seqlens"][:bs]
+                    )
+                    metadata.max_seq_len_q = self.topk
+                    metadata.max_seq_len_k = seq_lens.max().item()
+                    metadata.cu_seqlens_q = self.draft_decode_metadata_topk_normal[
+                        "cu_seqlens_q"
+                    ][: bs + 1]
+                    metadata.cu_seqlens_k = self.draft_decode_metadata_topk_normal[
+                        "cu_seqlens_k"
+                    ][: bs + 1]
+                    metadata.page_table = self.draft_decode_metadata_topk_normal[
+                        "page_table"
+                    ][req_pool_indices, :]
+
+                    # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
+                    metadata_expand.cache_seqlens_int32 = (
+                        self.draft_decode_metadata_topk_expand["cache_seqlens"][
+                            : bs * self.topk
+                        ]
+                    )
+                    metadata_expand.max_seq_len_q = 1
+                    metadata_expand.max_seq_len_k = (
+                        self.speculative_step_id + 1
+                    )  # , do this in replay
+                    metadata_expand.cu_seqlens_q = (
+                        self.draft_decode_metadata_topk_expand["cu_seqlens_q"][
+                            : bs * self.topk + 1
+                        ]
+                    )
+                    metadata_expand.cu_seqlens_k = (
+                        self.draft_decode_metadata_topk_expand["cu_seqlens_k"][
+                            : bs * self.topk + 1
+                        ]
+                    )
+                    metadata_expand.page_table = self.draft_decode_metadata_topk_expand[
+                        "page_table"
+                    ][: bs * self.topk]
+                    self.draft_decode_metadata_topk_normal[bs] = metadata
+                    self.draft_decode_metadata_topk_expand[bs] = metadata_expand
             else:
                 # Normal Decode
                 # Get sequence information
@@ -927,37 +1385,77 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.cu_seqlens_q = torch.arange(
                     0, batch_size + 1, dtype=torch.int32, device=device
                 )
-
+                self.decode_cuda_graph_metadata[bs] = metadata
+
         elif forward_mode.is_target_verify():
-            metadata.cache_seqlens_int32 = self.target_verify_metadata[
-                "cache_seqlens"
-            ][:bs]
-            metadata.cache_seqlens_int32.copy_(
-                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
-            )
+            if self.topk <= 1:
+                metadata.cache_seqlens_int32 = self.target_verify_metadata[
+                    "cache_seqlens"
+                ][:bs]
+                metadata.cache_seqlens_int32.copy_(
+                    (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+                )
 
-            metadata.max_seq_len_q = self.speculative_num_draft_tokens
-            metadata.max_seq_len_k = (
-                seq_lens.max().item() + self.speculative_num_draft_tokens
-            )
+                metadata.max_seq_len_q = self.speculative_num_draft_tokens
+                metadata.max_seq_len_k = (
+                    seq_lens.max().item() + self.speculative_num_draft_tokens
+                )
 
-            metadata.cu_seqlens_q = torch.arange(
-                0,
-                bs * self.speculative_num_draft_tokens + 1,
-                self.speculative_num_draft_tokens,
-                dtype=torch.int32,
-                device=device,
-            )
+                metadata.cu_seqlens_q = torch.arange(
+                    0,
+                    bs * self.speculative_num_draft_tokens + 1,
+                    self.speculative_num_draft_tokens,
+                    dtype=torch.int32,
+                    device=device,
+                )
 
-            metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
-                : (bs + 1)
-            ]
+                metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
+                    : (bs + 1)
+                ]
 
-            metadata.page_table = self.target_verify_metadata["page_table"][
-                req_pool_indices, :
-            ]
+                metadata.page_table = self.target_verify_metadata["page_table"][
+                    req_pool_indices, :
+                ]
 
-            self.target_verify_metadata[bs] = metadata
+                self.target_verify_metadata[bs] = metadata
+            else:
+                # When topk > 1, we need two specific target verify metadata, and then merge states
+                # 1. The first half of metadata for prefix tokens
+                metadata.cache_seqlens_int32 = self.target_verify_metadata_topk_normal[
+                    "cache_seqlens"
+                ][:bs]
+                metadata.max_seq_len_q = self.speculative_num_draft_tokens
+                # metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item(), do this in replay
+                metadata.cu_seqlens_q = self.target_verify_metadata_topk_normal[
+                    "cu_seqlens_q"
+                ][: bs + 1]
+                metadata.cu_seqlens_k = self.target_verify_metadata_topk_normal[
+                    "cu_seqlens_k"
+                ][: bs + 1]
+                metadata.page_table = self.target_verify_metadata_topk_normal[
+                    "page_table"
+                ][req_pool_indices, :]
+
+                # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
+                metadata_expand.cache_seqlens_int32 = (
+                    self.target_verify_metadata_topk_expand["cache_seqlens"][
+                        : bs * self.speculative_num_draft_tokens
+                    ]
+                )
+                metadata_expand.max_seq_len_q = 1
+                metadata_expand.cu_seqlens_q = self.target_verify_metadata_topk_expand[
+                    "cu_seqlens_q"
+                ][: bs * self.speculative_num_draft_tokens + 1]
+                metadata_expand.cu_seqlens_k = self.target_verify_metadata_topk_expand[
+                    "cu_seqlens_k"
+                ][: bs * self.speculative_num_draft_tokens + 1]
+
+                metadata_expand.page_table = self.target_verify_metadata_topk_expand[
+                    "page_table"
+                ][: bs * self.speculative_num_draft_tokens]
+
+                self.target_verify_metadata_topk_normal[bs] = metadata
+                self.target_verify_metadata_topk_expand[bs] = metadata_expand
 
         if encoder_lens is not None:
             encoder_bs = encoder_lens.numel()
@@ -973,6 +1471,7 @@ class FlashAttentionBackend(AttentionBackend):
             ]
 
         self.forward_metadata = metadata
+        self.forward_metadata_spec_decode_expand = metadata_expand
 
     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -986,41 +1485,85 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens_cpu: Optional[torch.Tensor],
         out_cache_loc: torch.Tensor = None,
     ):
-
+        """Initialize forward metadata for replaying CUDA graph."""
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]
         device = seq_lens.device
+        metadata = None
+        metadata_expand = None
 
         if forward_mode.is_decode_or_idle():
-            metadata = self.decode_cuda_graph_metadata[bs]
 
             if spec_info is not None:
                 # Draft Decode
-                metadata.cache_seqlens_int32.copy_(
-                    (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
-                )
+                if self.topk <= 1:
+                    metadata = self.decode_cuda_graph_metadata[bs]
+                    # When topk = 1, we use the normal decode metadata
+                    metadata.cache_seqlens_int32.copy_(
+                        (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
+                    )
 
-                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
-                    self.speculative_step_id + 1
-                )
-                metadata.cu_seqlens_k.copy_(
-                    torch.nn.functional.pad(
-                        torch.cumsum(
-                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                        ),
-                        (1, 0),
+                    metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
+                        self.speculative_step_id + 1
+                    )
+                    metadata.cu_seqlens_k.copy_(
+                        torch.nn.functional.pad(
+                            torch.cumsum(
+                                metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                            ),
+                            (1, 0),
+                        )
                     )
-                )
 
-                page_table = self.req_to_token[
-                    req_pool_indices, : metadata.max_seq_len_k
-                ]
+                    max_seq_pages = (
+                        metadata.max_seq_len_k + self.page_size - 1
+                    ) // self.page_size
+                    page_indices = self.req_to_token[
+                        req_pool_indices[:, None],
+                        self.decode_cuda_graph_metadata["strided_indices"][
+                            :max_seq_pages
+                        ],
+                    ]
+
+                    page_indices //= self.page_size
+                    metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+                else:
+                    # When top k > 1, we need two specific draft decode metadata, and then merge states
+                    # 1. The first half of metadata for prefix tokens
+                    metadata = self.draft_decode_metadata_topk_normal[bs]
+                    metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+                    # metadata.max_seq_len_q = self.topk, already set in capture
+                    metadata.max_seq_len_k = seq_lens_cpu.max().item()
+                    # metadata.cu_seqlens_q already set in capture
+                    metadata.cu_seqlens_k.copy_(
+                        torch.nn.functional.pad(
+                            torch.cumsum(
+                                metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                            ),
+                            (1, 0),
+                        )
+                    )
 
-                metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+                    page_table = self.req_to_token[
+                        req_pool_indices, : metadata.max_seq_len_k
+                    ]
+
+                    metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
 
+                    # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
+                    metadata_expand = self.draft_decode_metadata_topk_expand[bs]
+                    decode_length = self.speculative_step_id + 1
+                    cache_loc = out_cache_loc.view(
+                        self.speculative_num_steps, -1
+                    ).T.contiguous()
+                    metadata_expand.page_table[: cache_loc.shape[0]].copy_(
+                        cache_loc[:, :decode_length].contiguous().to(torch.int32)
+                    )
+                # TODO: we need to test this part for llama 4 eagle case
                 self._init_local_attn_metadata(metadata, device)
             else:
+                metadata = self.decode_cuda_graph_metadata[bs]
                 # Normal Decode
                 max_len = seq_lens_cpu.max().item()
                 metadata.max_seq_len_k = max_len
@@ -1045,24 +1588,117 @@ class FlashAttentionBackend(AttentionBackend):
 
             self._init_local_attn_metadata(metadata, device)
         elif forward_mode.is_target_verify():
-            metadata = self.target_verify_metadata[bs]
-            metadata.cache_seqlens_int32.copy_(
-                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
-            )
+            if self.topk <= 1:
+                metadata = self.target_verify_metadata[bs]
+                metadata.cache_seqlens_int32.copy_(
+                    (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+                )
 
-            metadata.max_seq_len_k = (
-                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
-            )
-            metadata.cu_seqlens_k.copy_(
-                torch.nn.functional.pad(
+                metadata.max_seq_len_k = (
+                    seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
+                )
+                metadata.cu_seqlens_k.copy_(
+                    torch.nn.functional.pad(
+                        torch.cumsum(
+                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                        ),
+                        (1, 0),
+                    )
+                )
+                max_seq_pages = (
+                    metadata.max_seq_len_k + self.page_size - 1
+                ) // self.page_size
+                page_indices = self.req_to_token[
+                    req_pool_indices[:, None],
+                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+                ]
+                page_indices //= self.page_size
+                metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+            else:
+                # When topk > 1, we need two specific target verify metadata, and then merge states
+                # 1. The first half of metadata for prefix tokens
+                metadata = self.target_verify_metadata_topk_normal[bs]
+                metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+                # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
+                metadata.max_seq_len_k = seq_lens_cpu.max().item()
+                # metadata.cu_seqlens_q already set in capture
+                metadata.cu_seqlens_k.copy_(
+                    torch.nn.functional.pad(
+                        torch.cumsum(
+                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                        ),
+                        (1, 0),
+                    )
+                )
+                page_table = self.req_to_token[
+                    req_pool_indices, : metadata.max_seq_len_k
+                ]
+                metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
+                # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
+                metadata_expand = self.target_verify_metadata_topk_expand[bs]
+                # metadata_expand.max_seq_len_q = 1, already set in capture
+                # metadata_expand.cu_seqlens_q already set in capture
+
+                offsets = torch.arange(
+                    self.speculative_num_draft_tokens, device=device
+                ).unsqueeze(
+                    0
+                )  # shape: (1, self.speculative_num_draft_tokens)
+                cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
+                cum_len = torch.nn.functional.pad(
                     torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                        (
+                            seq_lens + self.speculative_num_draft_tokens
+                        ).repeat_interleave(self.speculative_num_draft_tokens),
+                        dim=0,
                     ),
                     (1, 0),
+                )[:-1]
+                mask_extraction_indices = (
+                    cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+                    + cum_len[:, None]
+                ).view(1, -1)
+                # avoid extracting padded seq indices which will be out of boundary
+                mask_extraction_indices[
+                    :, spec_info.positions.numel() * self.speculative_num_draft_tokens :
+                ].fill_(0)
+
+                mask = spec_info.custom_mask[mask_extraction_indices].view(
+                    -1, self.speculative_num_draft_tokens
+                )  # (bsz * draft_num, draft_num)
+                col_indices = offsets.expand(
+                    mask.shape[0], self.speculative_num_draft_tokens
+                )
+                keys = torch.where(
+                    mask, col_indices, col_indices + self.speculative_num_draft_tokens
+                )
+                _, sort_order = torch.sort(keys, dim=1)
+
+                non_masked_page_table = (
+                    self.req_to_token[req_pool_indices, :]
+                    .gather(1, cols)
+                    .repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+                )  # (bsz, draft_num)
+                metadata_expand.page_table.copy_(
+                    non_masked_page_table.gather(1, sort_order)
+                )
+                metadata_expand.cache_seqlens_int32.copy_(
+                    mask.sum(dim=1).to(torch.int32)
+                )
+                metadata_expand.cu_seqlens_k.copy_(
+                    torch.nn.functional.pad(
+                        torch.cumsum(
+                            metadata_expand.cache_seqlens_int32,
+                            dim=0,
+                            dtype=torch.int32,
+                        ),
+                        (1, 0),
+                    )
+                )
+                metadata_expand.max_seq_len_k = (
+                    metadata_expand.cache_seqlens_int32.max().item()
                 )
-            )
-            page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
-            metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
 
         if encoder_lens is not None:
             # Only support encoder size 1 for now
@@ -1089,6 +1725,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
 
         self.forward_metadata = metadata
+        self.forward_metadata_spec_decode_expand = metadata_expand
 
     def get_cuda_graph_seq_len_fill_value(self):
         """Get the fill value for sequence length in CUDA graph."""
@@ -1139,12 +1776,6 @@ class FlashAttentionMultiStepBackend:
         self.model_runner = model_runner
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
-
-        # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
-        assert (
-            self.topk == 1
-        ), "speculative_eagle_topk must be 1 for FlashAttentionMultiStepBackend"
-
         self.attn_backends = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(