sglang 0.4.5.post2__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. sglang/bench_serving.py +3 -2
  2. sglang/compile_deep_gemm.py +136 -0
  3. sglang/lang/backend/openai.py +5 -1
  4. sglang/lang/backend/runtime_endpoint.py +5 -1
  5. sglang/srt/configs/model_config.py +4 -1
  6. sglang/srt/constrained/xgrammar_backend.py +1 -0
  7. sglang/srt/disaggregation/decode.py +43 -0
  8. sglang/srt/disaggregation/mini_lb.py +69 -8
  9. sglang/srt/disaggregation/mooncake/conn.py +1 -1
  10. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  11. sglang/srt/disaggregation/nixl/conn.py +622 -0
  12. sglang/srt/disaggregation/prefill.py +100 -16
  13. sglang/srt/disaggregation/utils.py +17 -0
  14. sglang/srt/entrypoints/engine.py +4 -0
  15. sglang/srt/entrypoints/http_server.py +3 -7
  16. sglang/srt/function_call_parser.py +60 -0
  17. sglang/srt/layers/activation.py +2 -2
  18. sglang/srt/layers/attention/flashattention_backend.py +781 -150
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
  21. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  22. sglang/srt/layers/dp_attention.py +1 -1
  23. sglang/srt/layers/layernorm.py +19 -4
  24. sglang/srt/layers/moe/ep_moe/layer.py +2 -0
  25. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  26. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  27. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  28. sglang/srt/layers/quantization/deep_gemm.py +378 -0
  29. sglang/srt/layers/quantization/fp8_kernel.py +7 -38
  30. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  31. sglang/srt/layers/quantization/gptq.py +13 -7
  32. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  33. sglang/srt/layers/quantization/w8a8_int8.py +3 -3
  34. sglang/srt/layers/rotary_embedding.py +6 -6
  35. sglang/srt/layers/sampler.py +2 -2
  36. sglang/srt/managers/data_parallel_controller.py +7 -1
  37. sglang/srt/managers/io_struct.py +14 -3
  38. sglang/srt/managers/schedule_batch.py +13 -0
  39. sglang/srt/managers/scheduler.py +16 -6
  40. sglang/srt/managers/tokenizer_manager.py +115 -29
  41. sglang/srt/managers/tp_worker.py +1 -0
  42. sglang/srt/mem_cache/hiradix_cache.py +40 -32
  43. sglang/srt/mem_cache/memory_pool.py +31 -13
  44. sglang/srt/model_executor/cuda_graph_runner.py +13 -8
  45. sglang/srt/model_executor/model_runner.py +19 -4
  46. sglang/srt/models/deepseek_v2.py +9 -6
  47. sglang/srt/models/minicpm3.py +2 -2
  48. sglang/srt/models/minicpmo.py +17 -6
  49. sglang/srt/openai_api/adapter.py +71 -4
  50. sglang/srt/openai_api/protocol.py +6 -1
  51. sglang/srt/server_args.py +52 -40
  52. sglang/srt/speculative/build_eagle_tree.py +2 -2
  53. sglang/srt/speculative/eagle_utils.py +2 -2
  54. sglang/srt/speculative/eagle_worker.py +2 -7
  55. sglang/srt/utils.py +46 -5
  56. sglang/test/test_utils.py +3 -1
  57. sglang/version.py +1 -1
  58. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +3 -3
  59. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +62 -57
  60. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +0 -0
  61. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
  62. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashattention_backend.py

@@ -16,6 +16,7 @@ if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.model_runner import ModelRunner

+ from sgl_kernel import merge_state_v2
  from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache


@@ -30,7 +31,7 @@ class FlashAttentionMetadata:
  # Sequence lengths for the forward batch
  cache_seqlens_int32: torch.Tensor = None
  # Maximum sequence length for query
- max_seq_len_q: int = 0
+ max_seq_len_q: int = 1
  # Maximum sequence length for key
  max_seq_len_k: int = 0
  # Cumulative sequence lengths for query
@@ -267,6 +268,12 @@ def cdiv(a: int, b: int) -> int:
  return -(a // -b)


+ # TODO(hebiao064): remove this once we have a better way to handle the merge_state_v2 torch.compile issue
+ @torch._dynamo.disable()
+ def merge_state_v2_wrapper(o, s_a, o_exp, s_b):
+ return merge_state_v2(o, s_a, o_exp, s_b)
+
+
  class FlashAttentionBackend(AttentionBackend):
  """FlashAttention backend implementation.

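Note on the wrapper added above: it exists only to keep the sgl_kernel op out of torch.compile graph capture, so callers that are compiled fall back to eager mode for this one call. A minimal sketch of the same pattern, using a hypothetical stand-in op rather than the real kernel:

    import torch

    # Hypothetical stand-in for an extension op that Dynamo cannot trace.
    def _opaque_kernel(x: torch.Tensor) -> torch.Tensor:
        return x * 2

    @torch._dynamo.disable()
    def _opaque_kernel_wrapper(x: torch.Tensor) -> torch.Tensor:
        # Runs eagerly instead of being traced, avoiding graph breaks
        # or capture errors inside compiled callers.
        return _opaque_kernel(x)

    @torch.compile
    def compiled_fn(x: torch.Tensor) -> torch.Tensor:
        return _opaque_kernel_wrapper(x) + 1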
@@ -301,6 +308,8 @@ class FlashAttentionBackend(AttentionBackend):
  ), "Sliding window and cross attention are not supported together"

  self.forward_metadata: FlashAttentionMetadata = None
+ # extra metdata for handling speculative decoding topk > 1, extended draft decode and verify
+ self.forward_metadata_spec_decode_expand: FlashAttentionMetadata = None
  self.max_context_len = model_runner.model_config.context_len
  self.device = model_runner.device
  self.decode_cuda_graph_metadata = {}
@@ -311,8 +320,7 @@ class FlashAttentionBackend(AttentionBackend):
  self.page_size = model_runner.page_size
  self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
  self.skip_prefill = skip_prefill
-
- self.topk = topk
+ self.topk = model_runner.server_args.speculative_eagle_topk or 0
  self.speculative_num_steps = speculative_num_steps
  self.speculative_num_draft_tokens = (
  model_runner.server_args.speculative_num_draft_tokens
@@ -336,14 +344,107 @@ class FlashAttentionBackend(AttentionBackend):
  if forward_batch.forward_mode.is_decode_or_idle():
  # Draft Decode
  if forward_batch.spec_info is not None:
+ if self.topk <= 1:
+ metadata.cache_seqlens_int32 = (
+ seqlens_in_batch + (self.speculative_step_id + 1)
+ ).to(torch.int32)
+ metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+ self.speculative_step_id + 1
+ )
+ metadata.cu_seqlens_q = torch.arange(
+ 0, batch_size + 1, dtype=torch.int32, device=device
+ )
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+ forward_batch.req_pool_indices, : metadata.max_seq_len_k
+ ]
+ else:
+ metadata.cache_seqlens_int32 = (seqlens_in_batch).to(torch.int32)
+ metadata.max_seq_len_q = self.topk
+ metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+ metadata.cu_seqlens_q = torch.arange(
+ 0,
+ batch_size * self.topk + 1,
+ step=self.topk,
+ dtype=torch.int32,
+ device=device,
+ )
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+ forward_batch.req_pool_indices, : metadata.max_seq_len_k
+ ]
+
+ metadata_expand = FlashAttentionMetadata()
+ decode_length = self.speculative_step_id + 1
+ metadata_expand.cache_seqlens_int32 = torch.full(
+ (seqlens_in_batch.numel() * self.topk,),
+ decode_length,
+ device=device,
+ dtype=torch.int32,
+ )
+ metadata_expand.max_seq_len_q = 1
+ metadata_expand.max_seq_len_k = self.speculative_step_id + 1
+ metadata_expand.cu_seqlens_q = torch.arange(
+ 0,
+ metadata_expand.cache_seqlens_int32.numel() + 1,
+ dtype=torch.int32,
+ device=device,
+ )
+ metadata_expand.cu_seqlens_k = torch.arange(
+ 0,
+ metadata_expand.cache_seqlens_int32.numel() * decode_length + 1,
+ step=decode_length,
+ dtype=torch.int32,
+ device=device,
+ )
+ cache_loc = forward_batch.out_cache_loc.view(
+ self.speculative_num_steps, -1
+ ).T.contiguous()
+ metadata_expand.page_table = (
+ cache_loc[:, :decode_length].contiguous().to(torch.int32)
+ )
+ self.forward_metadata_spec_decode_expand = metadata_expand
+ else:
+ # Normal Decode
+ metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+ metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+ metadata.cu_seqlens_q = torch.arange(
+ 0, batch_size + 1, dtype=torch.int32, device=device
+ )
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+ )
+ metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+ forward_batch.req_pool_indices, : metadata.max_seq_len_k
+ ]
+ # TODO: we need to test this part for llama 4 eagle case
+ self._init_local_attn_metadata(metadata, device)
+ elif forward_batch.forward_mode.is_target_verify():
+ if self.topk <= 1:
  metadata.cache_seqlens_int32 = (
- seqlens_in_batch + (self.speculative_step_id + 1)
+ forward_batch.seq_lens + self.speculative_num_draft_tokens
  ).to(torch.int32)
- metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
- self.speculative_step_id + 1
+ metadata.max_seq_len_q = self.speculative_num_draft_tokens
+ metadata.max_seq_len_k = (
+ forward_batch.seq_lens_cpu.max().item()
+ + self.speculative_num_draft_tokens
  )
  metadata.cu_seqlens_q = torch.arange(
- 0, batch_size + 1, dtype=torch.int32, device=device
+ 0,
+ batch_size * self.speculative_num_draft_tokens + 1,
+ self.speculative_num_draft_tokens,
+ dtype=torch.int32,
+ device=device,
  )
  metadata.cu_seqlens_k = torch.nn.functional.pad(
  torch.cumsum(
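For intuition, a tiny worked example (hypothetical sizes, not taken from the diff) of the varlen bookkeeping the topk > 1 draft-decode branch above builds: the prefix pass packs topk query rows per request against the shared prefix, while the expand pass gives each of the batch_size * topk draft branches its own short KV run of length speculative_step_id + 1.

    import torch

    batch_size, topk, step_id = 2, 4, 0          # made-up values
    seq_lens = torch.tensor([7, 5])              # prefix length per request
    decode_length = step_id + 1                  # KV entries per draft branch

    # Prefix pass: topk query tokens per request attend to that request's prefix.
    cu_seqlens_q = torch.arange(0, batch_size * topk + 1, step=topk, dtype=torch.int32)
    cu_seqlens_k = torch.nn.functional.pad(
        torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))
    print(cu_seqlens_q.tolist())   # [0, 4, 8]
    print(cu_seqlens_k.tolist())   # [0, 7, 12]

    # Expand pass: one query row per draft branch, each with its own tiny KV segment.
    expand_seqlens = torch.full((batch_size * topk,), decode_length, dtype=torch.int32)
    expand_cu_q = torch.arange(0, batch_size * topk + 1, dtype=torch.int32)
    expand_cu_k = torch.arange(0, batch_size * topk * decode_length + 1,
                               step=decode_length, dtype=torch.int32)
    print(expand_cu_q.tolist())    # [0, 1, 2, 3, 4, 5, 6, 7, 8]
    print(expand_cu_k.tolist())    # [0, 1, 2, 3, 4, 5, 6, 7, 8]  (decode_length == 1)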
@@ -357,44 +458,101 @@ class FlashAttentionBackend(AttentionBackend):

  self._init_local_attn_metadata(metadata, device)
  else:
- # Normal Decode
- metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+ metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
+ metadata.max_seq_len_q = self.speculative_num_draft_tokens
  metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
  metadata.cu_seqlens_q = torch.arange(
- 0, batch_size + 1, dtype=torch.int32, device=device
+ 0,
+ batch_size * self.speculative_num_draft_tokens + 1,
+ step=self.speculative_num_draft_tokens,
+ dtype=torch.int32,
+ device=device,
  )
  metadata.cu_seqlens_k = torch.nn.functional.pad(
- torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
  )
  metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
  forward_batch.req_pool_indices, : metadata.max_seq_len_k
  ]

- self._init_local_attn_metadata(metadata, device)
- elif forward_batch.forward_mode.is_target_verify():
- metadata.cache_seqlens_int32 = (
- forward_batch.seq_lens + self.speculative_num_draft_tokens
- ).to(torch.int32)
- metadata.max_seq_len_q = self.speculative_num_draft_tokens
- metadata.max_seq_len_k = (
- forward_batch.seq_lens_cpu.max().item()
- + self.speculative_num_draft_tokens
- )
- metadata.cu_seqlens_q = torch.arange(
- 0,
- batch_size * self.speculative_num_draft_tokens + 1,
- self.speculative_num_draft_tokens,
- dtype=torch.int32,
- device=device,
- )
- metadata.cu_seqlens_k = torch.nn.functional.pad(
- torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
- (1, 0),
- )
- metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
- forward_batch.req_pool_indices, : metadata.max_seq_len_k
- ]
+ metadata_expand = FlashAttentionMetadata()

+ metadata_expand.max_seq_len_q = 1
+ metadata_expand.cu_seqlens_q = torch.arange(
+ 0,
+ forward_batch.seq_lens.numel() * self.speculative_num_draft_tokens
+ + 1,
+ dtype=torch.int32,
+ device=device,
+ )
+
+ # create expand page table
+ offsets = torch.arange(
+ self.speculative_num_draft_tokens, device=device
+ ).unsqueeze(
+ 0
+ ) # shape: (1, self.speculative_num_draft_tokens)
+ cols = offsets.expand(
+ forward_batch.seq_lens.numel(), -1
+ ) + forward_batch.seq_lens.unsqueeze(1)
+ cum_len = torch.nn.functional.pad(
+ torch.cumsum(
+ (
+ forward_batch.seq_lens + self.speculative_num_draft_tokens
+ ).repeat_interleave(self.speculative_num_draft_tokens),
+ dim=0,
+ ),
+ (1, 0),
+ )[:-1]
+ mask_extraction_indices = (
+ cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+ + cum_len[:, None]
+ ).view(1, -1)
+ mask = forward_batch.spec_info.custom_mask[
+ mask_extraction_indices
+ ].view(
+ -1, self.speculative_num_draft_tokens
+ ) # (bsz * draft_num, draft_num)
+
+ # shift table indices to avoid padding
+ # non_masked_page_table [[8, 9, 10], mask (display with int format) [[1, 0, 0],
+ # [8, 9, 10], [1, 1, 0],
+ # [8, 9, 10]] [1, 0, 1]]
+ # if masked with padding [[8, 0, 0], our mask without padding [[8, 9, 10],
+ # [8, 9, 0], [8, 9, 10],
+ # [8, 0, 10]] [8, 10, 9]]
+ # note here cache_seqlens_int32 is [1, 2, 2] so extra page indices will be ignored in each row
+ col_indices = offsets.expand(
+ mask.shape[0], self.speculative_num_draft_tokens
+ )
+ # Build keys: if an entry is valid (mask==True), keep its original index;
+ # if not, add self.speculative_num_draft_tokens so that it sorts after all valid entries.
+ keys = torch.where(
+ mask, col_indices, col_indices + self.speculative_num_draft_tokens
+ )
+ _, sort_order = torch.sort(keys, dim=1)
+ non_masked_page_table = (
+ forward_batch.req_to_token_pool.req_to_token[
+ forward_batch.req_pool_indices, :
+ ]
+ .gather(1, cols)
+ .repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+ ) # (bsz, draft_num)
+ metadata_expand.page_table = non_masked_page_table.gather(1, sort_order)
+ metadata_expand.cache_seqlens_int32 = mask.sum(dim=1).to(torch.int32)
+ metadata_expand.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(
+ metadata_expand.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ metadata_expand.max_seq_len_k = (
+ metadata_expand.cache_seqlens_int32.max().item()
+ )
+ self.forward_metadata_spec_decode_expand = metadata_expand
  elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
  metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
  metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
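The "shift table indices to avoid padding" trick in the hunk above compacts each row's valid page indices to the front instead of leaving gaps, so cache_seqlens can simply be the per-row count of valid entries. A small standalone sketch with made-up page ids [8, 9, 10], reproducing the values shown in the diff's own comment:

    import torch

    draft_num = 3
    # Hypothetical per-row tree mask for one request (True = this draft token
    # may attend to that draft position) and its repeated page indices.
    mask = torch.tensor([[True, False, False],
                         [True, True,  False],
                         [True, False, True ]])
    pages = torch.tensor([[8, 9, 10]] * draft_num)

    col = torch.arange(draft_num).expand(draft_num, draft_num)
    # Invalid columns get +draft_num so they sort after every valid column.
    keys = torch.where(mask, col, col + draft_num)
    _, order = torch.sort(keys, dim=1)

    print(pages.gather(1, order).tolist())  # [[8, 9, 10], [8, 9, 10], [8, 10, 9]]
    print(mask.sum(dim=1).tolist())         # [1, 2, 2] -> cache_seqlens per row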
@@ -514,6 +672,11 @@ class FlashAttentionBackend(AttentionBackend):
  and (hasattr(layer, "use_irope") and layer.use_irope)
  )

+ # We do cascade attention for Target Verify with topk > 1
+ use_cascade_attn = (
+ forward_batch.forward_mode.is_target_verify() and self.topk > 1
+ )
+
  # Get the appropriate page table based on whether we're using local attention
  if use_local_attn:
  local_metadata = metadata.local_attn_metadata
@@ -548,7 +711,7 @@ class FlashAttentionBackend(AttentionBackend):
  cu_seqlens_k = metadata.encoder_cu_seqlens_k
  window_size = (-1, -1)

- o = flash_attn_with_kvcache(
+ result = flash_attn_with_kvcache(
  q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
  k_cache=key_cache,
  v_cache=value_cache,
@@ -558,13 +721,41 @@ class FlashAttentionBackend(AttentionBackend):
  cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
  max_seqlen_q=max_seqlen_q,
  softmax_scale=layer.scaling,
- causal=causal,
+ causal=False if use_cascade_attn else causal,
  window_size=window_size,
  softcap=layer.logit_cap,
  k_descale=k_descale,
  v_descale=v_descale,
+ return_softmax_lse=use_cascade_attn,
  )
- return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+ if use_cascade_attn:
+ o, softmax_lse, *rest = result
+ o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
+ q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+ k_cache=key_cache,
+ v_cache=value_cache,
+ page_table=self.forward_metadata_spec_decode_expand.page_table,
+ cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
+ cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
+ cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
+ max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
+ softmax_scale=layer.scaling,
+ causal=False,
+ window_size=window_size,
+ softcap=layer.logit_cap,
+ k_descale=k_descale,
+ v_descale=v_descale,
+ return_softmax_lse=True,
+ )
+ o, _ = merge_state_v2_wrapper(
+ o,
+ softmax_lse.T.contiguous(),
+ o_expand,
+ softmax_lse_expand.T.contiguous(),
+ )
+ else:
+ o = result
  else:
  if (
  not global_server_args_dict["disable_chunked_prefix_cache"]
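merge_state_v2 (from sgl_kernel) combines the two partial attention results above using their log-sum-exp values, which is why both calls request return_softmax_lse. A rough pure-PyTorch reference of the math involved, not the actual kernel; tensor layout here is simplified to [tokens, heads, head_dim] with lse as [tokens, heads]:

    import torch

    def merge_attention_states(o_a, lse_a, o_b, lse_b):
        # Merge two attention outputs computed over disjoint KV sets.
        # o_*:   [num_tokens, num_heads, head_dim] partial outputs
        # lse_*: [num_tokens, num_heads] log-sum-exp of the attention logits
        max_lse = torch.maximum(lse_a, lse_b)
        w_a = torch.exp(lse_a - max_lse)      # unnormalized weights, numerically stable
        w_b = torch.exp(lse_b - max_lse)
        denom = (w_a + w_b).unsqueeze(-1)
        merged = (o_a * w_a.unsqueeze(-1) + o_b * w_b.unsqueeze(-1)) / denom
        merged_lse = max_lse + torch.log(w_a + w_b)
        return merged, merged_lse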
@@ -627,7 +818,8 @@ class FlashAttentionBackend(AttentionBackend):
  q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
  q_nope = q_all[:, :, : layer.v_head_dim]
  q_rope = q_all[:, :, layer.v_head_dim :]
- o = flash_attn_with_kvcache(
+
+ result = flash_attn_with_kvcache(
  q=q_rope,
  k_cache=k_rope_cache,
  v_cache=c_kv_cache,
@@ -638,13 +830,44 @@ class FlashAttentionBackend(AttentionBackend):
  cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
  max_seqlen_q=max_seqlen_q,
  softmax_scale=layer.scaling,
- causal=True,
+ causal=False if use_cascade_attn else causal,
  softcap=layer.logit_cap,
  k_descale=k_descale,
  v_descale=v_descale,
+ return_softmax_lse=use_cascade_attn,
  )
+ if use_cascade_attn:
+ o, softmax_lse, *rest = result
+ o_expand, softmax_lse_expand, *rest_expand = (
+ flash_attn_with_kvcache(
+ q=q_rope,
+ k_cache=k_rope_cache,
+ v_cache=c_kv_cache,
+ qv=q_nope,
+ page_table=self.forward_metadata_spec_decode_expand.page_table,
+ cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
+ cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
+ cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
+ max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
+ softmax_scale=layer.scaling,
+ causal=False,
+ window_size=window_size,
+ softcap=layer.logit_cap,
+ k_descale=k_descale,
+ v_descale=v_descale,
+ return_softmax_lse=True,
+ )
+ )
+ o, _ = merge_state_v2_wrapper(
+ o,
+ softmax_lse.T.contiguous(),
+ o_expand,
+ softmax_lse_expand.T.contiguous(),
+ )
+ else:
+ o = result

- return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+ return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

  def forward_decode(
  self,
@@ -681,6 +904,8 @@ class FlashAttentionBackend(AttentionBackend):
  use_local_attention = (
  self.attention_chunk_size is not None and local_attn_metadata is not None
  )
+ # We do cascade attention for Draft Decode with topk > 1
+ use_cascade_attn = self.topk > 1

  # Calculate window size (can be moved to metadata if layer properties don't change)
  # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
@@ -752,23 +977,61 @@ class FlashAttentionBackend(AttentionBackend):
  v_descale=v_descale,
  )
  else:
+ page_table = metadata.page_table
+ cache_seqlens = metadata.cache_seqlens_int32
+ cu_seqlens_k = metadata.cu_seqlens_k
+ max_seqlen_q = metadata.max_seq_len_q
+ q_reshaped = q.contiguous().view(
+ -1, layer.tp_q_head_num, layer.head_dim
+ )
+
  # Default: single-token self-attention
- o = flash_attn_with_kvcache(
- q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+ result = flash_attn_with_kvcache(
+ q=q_reshaped,
  k_cache=key_cache,
  v_cache=value_cache,
- page_table=metadata.page_table,
- cache_seqlens=metadata.cache_seqlens_int32,
+ page_table=page_table,
+ cache_seqlens=cache_seqlens,
  cu_seqlens_q=metadata.cu_seqlens_q,
- cu_seqlens_k_new=metadata.cu_seqlens_k,
- max_seqlen_q=1,
+ cu_seqlens_k_new=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_q,
  softmax_scale=layer.scaling,
- causal=True,
+ causal=False if use_cascade_attn else causal,
  window_size=window_size,
  softcap=layer.logit_cap,
  k_descale=k_descale,
  v_descale=v_descale,
+ return_softmax_lse=use_cascade_attn,
  )
+ if use_cascade_attn:
+ o, softmax_lse, *rest = result
+ o_expand, softmax_lse_expand, *rest_expand = (
+ flash_attn_with_kvcache(
+ q=q_reshaped,
+ k_cache=key_cache,
+ v_cache=value_cache,
+ page_table=self.forward_metadata_spec_decode_expand.page_table,
+ cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
+ cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
+ cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
+ max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
+ softmax_scale=layer.scaling,
+ causal=False,
+ window_size=window_size,
+ softcap=layer.logit_cap,
+ k_descale=k_descale,
+ v_descale=v_descale,
+ return_softmax_lse=True,
+ )
+ )
+ o, _ = merge_state_v2(
+ o,
+ softmax_lse.T.contiguous(),
+ o_expand,
+ softmax_lse_expand.T.contiguous(),
+ )
+ else:
+ o = result
  else:
  # Do absorbed multi-latent attention
  kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
@@ -787,8 +1050,9 @@ class FlashAttentionBackend(AttentionBackend):
  q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
  q_nope = q_all[:, :, : layer.v_head_dim]
  q_rope = q_all[:, :, layer.v_head_dim :]
+ max_seqlen_q = metadata.max_seq_len_q

- o = flash_attn_with_kvcache(
+ result = flash_attn_with_kvcache(
  q=q_rope,
  k_cache=k_rope_cache,
  v_cache=c_kv_cache,
@@ -797,13 +1061,43 @@ class FlashAttentionBackend(AttentionBackend):
  cache_seqlens=metadata.cache_seqlens_int32,
  cu_seqlens_q=metadata.cu_seqlens_q,
  cu_seqlens_k_new=metadata.cu_seqlens_k,
- max_seqlen_q=1,
+ max_seqlen_q=max_seqlen_q,
  softmax_scale=layer.scaling,
- causal=True,
+ causal=False if use_cascade_attn else causal,
  softcap=layer.logit_cap,
  k_descale=k_descale,
  v_descale=v_descale,
+ return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states
  )
+ if use_cascade_attn:
+ o, softmax_lse, *rest = result
+ o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
+ q=q_rope,
+ k_cache=k_rope_cache,
+ v_cache=c_kv_cache,
+ qv=q_nope,
+ page_table=self.forward_metadata_spec_decode_expand.page_table,
+ cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
+ cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
+ cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
+ max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
+ softmax_scale=layer.scaling,
+ causal=False,
+ window_size=window_size,
+ softcap=layer.logit_cap,
+ k_descale=k_descale,
+ v_descale=v_descale,
+ return_softmax_lse=True,
+ )
+ o, _ = merge_state_v2(
+ o,
+ softmax_lse.T.contiguous(),
+ o_expand,
+ softmax_lse_expand.T.contiguous(),
+ )
+ else:
+ o = result
+
  return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

  def init_cuda_graph_state(self, max_bs: int):
@@ -815,6 +1109,8 @@ class FlashAttentionBackend(AttentionBackend):
  This creates fixed-size tensors that will be reused during CUDA graph replay
  to avoid memory allocations.
  """
+
+ # This is being used by normal decode and draft decode when topk == 1
  self.decode_cuda_graph_metadata = {
  "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
  "cu_seqlens_q": torch.arange(
@@ -840,24 +1136,136 @@ class FlashAttentionBackend(AttentionBackend):
  ),
  }

- self.target_verify_metadata = {
- "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
- "cu_seqlens_q": torch.zeros(
- max_bs + 1, dtype=torch.int32, device=self.device
- ),
- "cu_seqlens_k": torch.zeros(
- max_bs + 1, dtype=torch.int32, device=self.device
- ),
- "page_table": torch.zeros(
- max_bs,
- (self.max_context_len + self.page_size - 1) // self.page_size,
- dtype=torch.int32,
- device=self.device,
- ),
- "strided_indices": torch.arange(
- 0, self.max_context_len, self.page_size, device=self.device
- ),
- }
+ # This is used by draft decode's first half of metadata when topk > 1
+ if self.topk > 1:
+ self.draft_decode_metadata_topk_normal = {
+ "cache_seqlens": torch.zeros(
+ max_bs, dtype=torch.int32, device=self.device
+ ),
+ "cu_seqlens_q": torch.arange(
+ 0,
+ max_bs * self.topk + 1,
+ step=self.topk,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "cu_seqlens_k": torch.zeros(
+ max_bs + 1, dtype=torch.int32, device=self.device
+ ),
+ "page_table": torch.zeros(
+ max_bs,
+ self.max_context_len,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ }
+
+ # This is used by draft decode's second half of metadata when topk > 1
+ decode_length = self.speculative_step_id + 1
+ self.draft_decode_metadata_topk_expand = {
+ "cache_seqlens": torch.full(
+ (max_bs * self.topk,),
+ decode_length,
+ device=self.device,
+ dtype=torch.int32,
+ ),
+ "cu_seqlens_q": torch.arange(
+ 0,
+ max_bs * self.topk + 1,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "cu_seqlens_k": torch.arange(
+ 0,
+ max_bs * self.topk * decode_length + 1,
+ step=decode_length,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "page_table": torch.zeros(
+ max_bs * self.topk,
+ decode_length,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ }
+
+ if (
+ self.speculative_num_draft_tokens is not None
+ and self.speculative_num_draft_tokens > 0
+ ):
+ self.target_verify_metadata = {
+ "cache_seqlens": torch.zeros(
+ max_bs, dtype=torch.int32, device=self.device
+ ),
+ "cu_seqlens_q": torch.arange(
+ 0,
+ max_bs * self.speculative_num_draft_tokens + 1,
+ step=self.speculative_num_draft_tokens,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "cu_seqlens_k": torch.zeros(
+ max_bs + 1, dtype=torch.int32, device=self.device
+ ),
+ "page_table": torch.zeros(
+ max_bs,
+ (self.max_context_len + self.page_size - 1) // self.page_size,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "strided_indices": torch.arange(
+ 0, self.max_context_len, self.page_size, device=self.device
+ ),
+ }
+
+ if self.topk > 1:
+ self.target_verify_metadata_topk_normal = {
+ "cache_seqlens": torch.zeros(
+ max_bs, dtype=torch.int32, device=self.device
+ ),
+ "cu_seqlens_q": torch.arange(
+ 0,
+ max_bs * self.speculative_num_draft_tokens + 1,
+ step=self.speculative_num_draft_tokens,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "cu_seqlens_k": torch.zeros(
+ max_bs + 1, dtype=torch.int32, device=self.device
+ ),
+ "page_table": torch.zeros(
+ max_bs,
+ self.max_context_len,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ }
+
+ self.target_verify_metadata_topk_expand = {
+ "cache_seqlens": torch.zeros(
+ max_bs * self.speculative_num_draft_tokens,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "cu_seqlens_k": torch.zeros(
+ max_bs * self.speculative_num_draft_tokens + 1,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "cu_seqlens_q": torch.arange(
+ 0,
+ max_bs * self.speculative_num_draft_tokens + 1,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "page_table": torch.zeros(
+ max_bs * self.speculative_num_draft_tokens,
+ self.speculative_num_draft_tokens,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ }

  self.encoder_metadata = {
  "encoder_page_table": torch.zeros(
@@ -886,28 +1294,78 @@ class FlashAttentionBackend(AttentionBackend):
  ):
  """Initialize forward metadata for capturing CUDA graph."""
  metadata = FlashAttentionMetadata()
+
+ # metadata_expand is needed for Spec Decoding when top k > 1
+ metadata_expand = FlashAttentionMetadata()
+
  device = seq_lens.device
  if forward_mode.is_decode_or_idle():
  if spec_info is not None:
  # Draft Decode
- metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
- "cache_seqlens"
- ][:bs]
- metadata.max_seq_len_k = seq_lens.max().item() + (
- self.speculative_step_id + 1
- )
- metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
- : bs + 1
- ]
- metadata.cu_seqlens_k = torch.nn.functional.pad(
- torch.cumsum(
- metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
- ),
- (1, 0),
- )
- metadata.page_table = self.decode_cuda_graph_metadata[
- "page_table_draft_decode"
- ][req_pool_indices, :]
+ if self.topk <= 1:
+ # When topk = 1, we use the normal decode metadata
+ metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
+ "cache_seqlens"
+ ][:bs]
+ metadata.max_seq_len_k = seq_lens.max().item() + (
+ self.speculative_step_id + 1
+ )
+ metadata.cu_seqlens_q = self.decode_cuda_graph_metadata[
+ "cu_seqlens_q"
+ ][: bs + 1]
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ metadata.page_table = self.decode_cuda_graph_metadata[
+ "page_table_draft_decode"
+ ][req_pool_indices, :]
+ self.decode_cuda_graph_metadata[bs] = metadata
+ else:
+ # When top k > 1, we need two specific draft decode metadata, and then merge states
+ # 1. The first half of metadata for prefix tokens
+ metadata.cache_seqlens_int32 = (
+ self.draft_decode_metadata_topk_normal["cache_seqlens"][:bs]
+ )
+ metadata.max_seq_len_q = self.topk
+ metadata.max_seq_len_k = seq_lens.max().item()
+ metadata.cu_seqlens_q = self.draft_decode_metadata_topk_normal[
+ "cu_seqlens_q"
+ ][: bs + 1]
+ metadata.cu_seqlens_k = self.draft_decode_metadata_topk_normal[
+ "cu_seqlens_k"
+ ][: bs + 1]
+ metadata.page_table = self.draft_decode_metadata_topk_normal[
+ "page_table"
+ ][req_pool_indices, :]
+
+ # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
+ metadata_expand.cache_seqlens_int32 = (
+ self.draft_decode_metadata_topk_expand["cache_seqlens"][
+ : bs * self.topk
+ ]
+ )
+ metadata_expand.max_seq_len_q = 1
+ metadata_expand.max_seq_len_k = (
+ self.speculative_step_id + 1
+ ) # , do this in replay
+ metadata_expand.cu_seqlens_q = (
+ self.draft_decode_metadata_topk_expand["cu_seqlens_q"][
+ : bs * self.topk + 1
+ ]
+ )
+ metadata_expand.cu_seqlens_k = (
+ self.draft_decode_metadata_topk_expand["cu_seqlens_k"][
+ : bs * self.topk + 1
+ ]
+ )
+ metadata_expand.page_table = self.draft_decode_metadata_topk_expand[
+ "page_table"
+ ][: bs * self.topk]
+ self.draft_decode_metadata_topk_normal[bs] = metadata
+ self.draft_decode_metadata_topk_expand[bs] = metadata_expand
  else:
  # Normal Decode
  # Get sequence information
@@ -927,37 +1385,77 @@ class FlashAttentionBackend(AttentionBackend):
  metadata.cu_seqlens_q = torch.arange(
  0, batch_size + 1, dtype=torch.int32, device=device
  )
- self.decode_cuda_graph_metadata[bs] = metadata
+ self.decode_cuda_graph_metadata[bs] = metadata
+
  elif forward_mode.is_target_verify():
- metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
- :bs
- ]
- metadata.cache_seqlens_int32.copy_(
- (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
- )
+ if self.topk <= 1:
+ metadata.cache_seqlens_int32 = self.target_verify_metadata[
+ "cache_seqlens"
+ ][:bs]
+ metadata.cache_seqlens_int32.copy_(
+ (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+ )

- metadata.max_seq_len_q = self.speculative_num_draft_tokens
- metadata.max_seq_len_k = (
- seq_lens.max().item() + self.speculative_num_draft_tokens
- )
+ metadata.max_seq_len_q = self.speculative_num_draft_tokens
+ metadata.max_seq_len_k = (
+ seq_lens.max().item() + self.speculative_num_draft_tokens
+ )

- metadata.cu_seqlens_q = torch.arange(
- 0,
- bs * self.speculative_num_draft_tokens + 1,
- self.speculative_num_draft_tokens,
- dtype=torch.int32,
- device=device,
- )
+ metadata.cu_seqlens_q = torch.arange(
+ 0,
+ bs * self.speculative_num_draft_tokens + 1,
+ self.speculative_num_draft_tokens,
+ dtype=torch.int32,
+ device=device,
+ )

- metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
- : (bs + 1)
- ]
+ metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
+ : (bs + 1)
+ ]

- metadata.page_table = self.target_verify_metadata["page_table"][
- req_pool_indices, :
- ]
+ metadata.page_table = self.target_verify_metadata["page_table"][
+ req_pool_indices, :
+ ]

- self.target_verify_metadata[bs] = metadata
+ self.target_verify_metadata[bs] = metadata
+ else:
+ # When topk > 1, we need two specific target verify metadata, and then merge states
+ # 1. The first half of metadata for prefix tokens
+ metadata.cache_seqlens_int32 = self.target_verify_metadata_topk_normal[
+ "cache_seqlens"
+ ][:bs]
+ metadata.max_seq_len_q = self.speculative_num_draft_tokens
+ # metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item(), do this in replay
+ metadata.cu_seqlens_q = self.target_verify_metadata_topk_normal[
+ "cu_seqlens_q"
+ ][: bs + 1]
+ metadata.cu_seqlens_k = self.target_verify_metadata_topk_normal[
+ "cu_seqlens_k"
+ ][: bs + 1]
+ metadata.page_table = self.target_verify_metadata_topk_normal[
+ "page_table"
+ ][req_pool_indices, :]
+
+ # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
+ metadata_expand.cache_seqlens_int32 = (
+ self.target_verify_metadata_topk_expand["cache_seqlens"][
+ : bs * self.speculative_num_draft_tokens
+ ]
+ )
+ metadata_expand.max_seq_len_q = 1
+ metadata_expand.cu_seqlens_q = self.target_verify_metadata_topk_expand[
+ "cu_seqlens_q"
+ ][: bs * self.speculative_num_draft_tokens + 1]
+ metadata_expand.cu_seqlens_k = self.target_verify_metadata_topk_expand[
+ "cu_seqlens_k"
+ ][: bs * self.speculative_num_draft_tokens + 1]
+
+ metadata_expand.page_table = self.target_verify_metadata_topk_expand[
+ "page_table"
+ ][: bs * self.speculative_num_draft_tokens]
+
+ self.target_verify_metadata_topk_normal[bs] = metadata
+ self.target_verify_metadata_topk_expand[bs] = metadata_expand

  if encoder_lens is not None:
  encoder_bs = encoder_lens.numel()
@@ -973,6 +1471,7 @@ class FlashAttentionBackend(AttentionBackend):
  ]

  self.forward_metadata = metadata
+ self.forward_metadata_spec_decode_expand = metadata_expand

  def init_forward_metadata_replay_cuda_graph(
  self,
@@ -986,41 +1485,85 @@ class FlashAttentionBackend(AttentionBackend):
  seq_lens_cpu: Optional[torch.Tensor],
  out_cache_loc: torch.Tensor = None,
  ):
- # """Initialize forward metadata for replaying CUDA graph."""
+ """Initialize forward metadata for replaying CUDA graph."""
  seq_lens = seq_lens[:bs]
  seq_lens_cpu = seq_lens_cpu[:bs]
  req_pool_indices = req_pool_indices[:bs]
  device = seq_lens.device
+ metadata = None
+ metadata_expand = None

  if forward_mode.is_decode_or_idle():
- metadata = self.decode_cuda_graph_metadata[bs]

  if spec_info is not None:
  # Draft Decode
- metadata.cache_seqlens_int32.copy_(
- (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
- )
+ if self.topk <= 1:
+ metadata = self.decode_cuda_graph_metadata[bs]
+ # When topk = 1, we use the normal decode metadata
+ metadata.cache_seqlens_int32.copy_(
+ (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
+ )

- metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
- self.speculative_step_id + 1
- )
- metadata.cu_seqlens_k.copy_(
- torch.nn.functional.pad(
- torch.cumsum(
- metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
- ),
- (1, 0),
+ metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
+ self.speculative_step_id + 1
+ )
+ metadata.cu_seqlens_k.copy_(
+ torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
  )
- )

- page_table = self.req_to_token[
- req_pool_indices, : metadata.max_seq_len_k
- ]
+ max_seq_pages = (
+ metadata.max_seq_len_k + self.page_size - 1
+ ) // self.page_size
+ page_indices = self.req_to_token[
+ req_pool_indices[:, None],
+ self.decode_cuda_graph_metadata["strided_indices"][
+ :max_seq_pages
+ ],
+ ]
+
+ page_indices //= self.page_size
+ metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+ else:
+ # When top k > 1, we need two specific draft decode metadata, and then merge states
+ # 1. The first half of metadata for prefix tokens
+ metadata = self.draft_decode_metadata_topk_normal[bs]
+ metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+ # metadata.max_seq_len_q = self.topk, already set in capture
+ metadata.max_seq_len_k = seq_lens_cpu.max().item()
+ # metadata.cu_seqlens_q already set in capture
+ metadata.cu_seqlens_k.copy_(
+ torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ )

- metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+ page_table = self.req_to_token[
+ req_pool_indices, : metadata.max_seq_len_k
+ ]
+
+ metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

+ # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
+ metadata_expand = self.draft_decode_metadata_topk_expand[bs]
+ decode_length = self.speculative_step_id + 1
+ cache_loc = out_cache_loc.view(
+ self.speculative_num_steps, -1
+ ).T.contiguous()
+ metadata_expand.page_table[: cache_loc.shape[0]].copy_(
+ cache_loc[:, :decode_length].contiguous().to(torch.int32)
+ )
+ # TODO: we need to test this part for llama 4 eagle case
  self._init_local_attn_metadata(metadata, device)
  else:
+ metadata = self.decode_cuda_graph_metadata[bs]
  # Normal Decode
  max_len = seq_lens_cpu.max().item()
  metadata.max_seq_len_k = max_len
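The draft-decode replay path above rebuilds the expand page table from out_cache_loc, which is laid out step-major (all branches for step 0, then all branches for step 1, and so on); the view(...).T turns it into one row of cache slots per draft branch. A tiny worked example with made-up slot ids:

    import torch

    speculative_num_steps, batch_size, topk = 2, 1, 3
    # Hypothetical KV-cache slot ids, ordered step-major:
    # step 0 wrote slots [100, 101, 102], step 1 wrote slots [200, 201, 202].
    out_cache_loc = torch.tensor([100, 101, 102, 200, 201, 202])

    cache_loc = out_cache_loc.view(speculative_num_steps, -1).T.contiguous()
    print(cache_loc.tolist())
    # [[100, 200], [101, 201], [102, 202]] -> one row per draft branch

    decode_length = 1  # speculative_step_id + 1 on the first draft step
    page_table = cache_loc[:, :decode_length].contiguous().to(torch.int32)
    print(page_table.tolist())   # [[100], [101], [102]]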
@@ -1045,24 +1588,117 @@ class FlashAttentionBackend(AttentionBackend):

  self._init_local_attn_metadata(metadata, device)
  elif forward_mode.is_target_verify():
- metadata = self.target_verify_metadata[bs]
- metadata.cache_seqlens_int32.copy_(
- (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
- )
+ if self.topk <= 1:
+ metadata = self.target_verify_metadata[bs]
+ metadata.cache_seqlens_int32.copy_(
+ (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+ )

- metadata.max_seq_len_k = (
- seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
- )
- metadata.cu_seqlens_k.copy_(
- torch.nn.functional.pad(
+ metadata.max_seq_len_k = (
+ seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
+ )
+ metadata.cu_seqlens_k.copy_(
+ torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ )
+ max_seq_pages = (
+ metadata.max_seq_len_k + self.page_size - 1
+ ) // self.page_size
+ page_indices = self.req_to_token[
+ req_pool_indices[:, None],
+ self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+ ]
+ page_indices //= self.page_size
+ metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+ else:
+ # When topk > 1, we need two specific target verify metadata, and then merge states
+ # 1. The first half of metadata for prefix tokens
+ metadata = self.target_verify_metadata_topk_normal[bs]
+ metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+ # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
+ metadata.max_seq_len_k = seq_lens_cpu.max().item()
+ # metadata.cu_seqlens_q already set in capture
+ metadata.cu_seqlens_k.copy_(
+ torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ )
+ page_table = self.req_to_token[
+ req_pool_indices, : metadata.max_seq_len_k
+ ]
+ metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
+ # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
+ metadata_expand = self.target_verify_metadata_topk_expand[bs]
+ # metadata_expand.max_seq_len_q = 1, already set in capture
+ # metadata_expand.cu_seqlens_q already set in capture
+
+ offsets = torch.arange(
+ self.speculative_num_draft_tokens, device=device
+ ).unsqueeze(
+ 0
+ ) # shape: (1, self.speculative_num_draft_tokens)
+ cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
+ cum_len = torch.nn.functional.pad(
  torch.cumsum(
- metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ (
+ seq_lens + self.speculative_num_draft_tokens
+ ).repeat_interleave(self.speculative_num_draft_tokens),
+ dim=0,
  ),
  (1, 0),
+ )[:-1]
+ mask_extraction_indices = (
+ cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+ + cum_len[:, None]
+ ).view(1, -1)
+ # avoid extracting padded seq indices which will be out of boundary
+ mask_extraction_indices[
+ :, spec_info.positions.numel() * self.speculative_num_draft_tokens :
+ ].fill_(0)
+
+ mask = spec_info.custom_mask[mask_extraction_indices].view(
+ -1, self.speculative_num_draft_tokens
+ ) # (bsz * draft_num, draft_num)
+ col_indices = offsets.expand(
+ mask.shape[0], self.speculative_num_draft_tokens
+ )
+ keys = torch.where(
+ mask, col_indices, col_indices + self.speculative_num_draft_tokens
+ )
+ _, sort_order = torch.sort(keys, dim=1)
+
+ non_masked_page_table = (
+ self.req_to_token[req_pool_indices, :]
+ .gather(1, cols)
+ .repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+ ) # (bsz, draft_num)
+ metadata_expand.page_table.copy_(
+ non_masked_page_table.gather(1, sort_order)
+ )
+ metadata_expand.cache_seqlens_int32.copy_(
+ mask.sum(dim=1).to(torch.int32)
+ )
+ metadata_expand.cu_seqlens_k.copy_(
+ torch.nn.functional.pad(
+ torch.cumsum(
+ metadata_expand.cache_seqlens_int32,
+ dim=0,
+ dtype=torch.int32,
+ ),
+ (1, 0),
+ )
+ )
+ metadata_expand.max_seq_len_k = (
+ metadata_expand.cache_seqlens_int32.max().item()
  )
- )
- page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
- metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

  if encoder_lens is not None:
  # Only support encoder size 1 for now
@@ -1089,6 +1725,7 @@ class FlashAttentionBackend(AttentionBackend):
  metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

  self.forward_metadata = metadata
+ self.forward_metadata_spec_decode_expand = metadata_expand

  def get_cuda_graph_seq_len_fill_value(self):
  """Get the fill value for sequence length in CUDA graph."""
@@ -1139,12 +1776,6 @@ class FlashAttentionMultiStepBackend:
  self.model_runner = model_runner
  self.topk = topk
  self.speculative_num_steps = speculative_num_steps
-
- # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
- assert (
- self.topk == 1
- ), "speculative_eagle_topk must be 1 for FlashAttentionMultiStepBackend"
-
  self.attn_backends = []
  for i in range(self.speculative_num_steps):
  self.attn_backends.append(
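With the topk == 1 assertion removed above, FlashAttentionMultiStepBackend no longer rejects EAGLE speculative decoding with a wider draft tree. A hedged offline-engine sketch of how such a configuration might be requested; the keyword names mirror the ServerArgs fields this diff reads (speculative_eagle_topk, speculative_num_steps, speculative_num_draft_tokens), and the model paths are illustrative, so check them against your installed sglang version:

    import sglang as sgl

    # Hypothetical target/draft pair; any EAGLE-style combination would do.
    llm = sgl.Engine(
        model_path="meta-llama/Meta-Llama-3-8B-Instruct",
        speculative_algorithm="EAGLE",
        speculative_draft_model_path="lmsys/sglang-EAGLE-LLaMA3-Instruct-8B",
        speculative_eagle_topk=4,            # > 1 now exercises the cascade/merge path
        speculative_num_steps=3,
        speculative_num_draft_tokens=8,
        attention_backend="fa3",
    )
    print(llm.generate("Hello, my name is", {"max_new_tokens": 16}))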