sglang 0.4.5.post2__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff compares the publicly released contents of the two package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (99)
  1. sglang/bench_one_batch.py +19 -3
  2. sglang/bench_serving.py +8 -8
  3. sglang/compile_deep_gemm.py +177 -0
  4. sglang/lang/backend/openai.py +5 -1
  5. sglang/lang/backend/runtime_endpoint.py +5 -1
  6. sglang/srt/code_completion_parser.py +1 -1
  7. sglang/srt/configs/deepseekvl2.py +1 -1
  8. sglang/srt/configs/model_config.py +11 -2
  9. sglang/srt/constrained/llguidance_backend.py +78 -61
  10. sglang/srt/constrained/xgrammar_backend.py +1 -0
  11. sglang/srt/conversation.py +34 -1
  12. sglang/srt/disaggregation/decode.py +96 -5
  13. sglang/srt/disaggregation/mini_lb.py +113 -15
  14. sglang/srt/disaggregation/mooncake/conn.py +199 -32
  15. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  16. sglang/srt/disaggregation/nixl/conn.py +622 -0
  17. sglang/srt/disaggregation/prefill.py +119 -20
  18. sglang/srt/disaggregation/utils.py +17 -0
  19. sglang/srt/entrypoints/engine.py +4 -0
  20. sglang/srt/entrypoints/http_server.py +11 -9
  21. sglang/srt/function_call_parser.py +132 -0
  22. sglang/srt/layers/activation.py +2 -2
  23. sglang/srt/layers/attention/base_attn_backend.py +3 -0
  24. sglang/srt/layers/attention/flashattention_backend.py +809 -160
  25. sglang/srt/layers/attention/flashmla_backend.py +8 -11
  26. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
  28. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  29. sglang/srt/layers/attention/vision.py +2 -0
  30. sglang/srt/layers/dp_attention.py +1 -1
  31. sglang/srt/layers/layernorm.py +42 -5
  32. sglang/srt/layers/logits_processor.py +2 -2
  33. sglang/srt/layers/moe/ep_moe/layer.py +2 -0
  34. sglang/srt/layers/moe/fused_moe_native.py +2 -4
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
  38. sglang/srt/layers/pooler.py +6 -0
  39. sglang/srt/layers/quantization/awq.py +5 -1
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  41. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  42. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  43. sglang/srt/layers/quantization/deep_gemm.py +385 -0
  44. sglang/srt/layers/quantization/fp8_kernel.py +7 -38
  45. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  46. sglang/srt/layers/quantization/gptq.py +13 -7
  47. sglang/srt/layers/quantization/int8_kernel.py +32 -1
  48. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  49. sglang/srt/layers/quantization/w8a8_int8.py +3 -3
  50. sglang/srt/layers/radix_attention.py +13 -3
  51. sglang/srt/layers/rotary_embedding.py +176 -132
  52. sglang/srt/layers/sampler.py +2 -2
  53. sglang/srt/managers/data_parallel_controller.py +17 -4
  54. sglang/srt/managers/io_struct.py +21 -3
  55. sglang/srt/managers/mm_utils.py +85 -28
  56. sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
  57. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
  58. sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
  59. sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
  60. sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
  61. sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
  62. sglang/srt/managers/schedule_batch.py +42 -12
  63. sglang/srt/managers/scheduler.py +47 -26
  64. sglang/srt/managers/tokenizer_manager.py +120 -30
  65. sglang/srt/managers/tp_worker.py +1 -0
  66. sglang/srt/mem_cache/hiradix_cache.py +40 -32
  67. sglang/srt/mem_cache/memory_pool.py +118 -13
  68. sglang/srt/model_executor/cuda_graph_runner.py +16 -10
  69. sglang/srt/model_executor/forward_batch_info.py +51 -95
  70. sglang/srt/model_executor/model_runner.py +29 -27
  71. sglang/srt/models/deepseek.py +12 -2
  72. sglang/srt/models/deepseek_nextn.py +101 -6
  73. sglang/srt/models/deepseek_v2.py +153 -76
  74. sglang/srt/models/deepseek_vl2.py +9 -4
  75. sglang/srt/models/gemma3_causal.py +1 -1
  76. sglang/srt/models/llama4.py +0 -1
  77. sglang/srt/models/minicpm3.py +2 -2
  78. sglang/srt/models/minicpmo.py +22 -7
  79. sglang/srt/models/mllama4.py +2 -2
  80. sglang/srt/models/qwen2_5_vl.py +3 -6
  81. sglang/srt/models/qwen2_vl.py +3 -7
  82. sglang/srt/models/roberta.py +178 -0
  83. sglang/srt/openai_api/adapter.py +87 -10
  84. sglang/srt/openai_api/protocol.py +6 -1
  85. sglang/srt/server_args.py +65 -60
  86. sglang/srt/speculative/build_eagle_tree.py +2 -2
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  88. sglang/srt/speculative/eagle_utils.py +2 -2
  89. sglang/srt/speculative/eagle_worker.py +2 -7
  90. sglang/srt/torch_memory_saver_adapter.py +10 -1
  91. sglang/srt/utils.py +48 -6
  92. sglang/test/runners.py +6 -13
  93. sglang/test/test_utils.py +39 -19
  94. sglang/version.py +1 -1
  95. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/METADATA +6 -7
  96. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/RECORD +99 -92
  97. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
  98. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashattention_backend.py
@@ -16,6 +16,7 @@ if TYPE_CHECKING:
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.model_runner import ModelRunner

+ from sgl_kernel import merge_state_v2
  from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache


@@ -30,7 +31,7 @@ class FlashAttentionMetadata:
  # Sequence lengths for the forward batch
  cache_seqlens_int32: torch.Tensor = None
  # Maximum sequence length for query
- max_seq_len_q: int = 0
+ max_seq_len_q: int = 1
  # Maximum sequence length for key
  max_seq_len_k: int = 0
  # Cumulative sequence lengths for query
@@ -267,6 +268,12 @@ def cdiv(a: int, b: int) -> int:
  return -(a // -b)


+ # TODO(hebiao064): remove this once we have a better way to handle the merge_state_v2 torch.compile issue
+ @torch._dynamo.disable()
+ def merge_state_v2_wrapper(o, s_a, o_exp, s_b):
+ return merge_state_v2(o, s_a, o_exp, s_b)
+
+
  class FlashAttentionBackend(AttentionBackend):
  """FlashAttention backend implementation.

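The wrapper above exists only to keep the custom merge_state_v2 kernel out of torch.compile's graph (per the TODO); the kernel itself combines two attention results computed over disjoint key sets. Numerically this is the standard log-sum-exp merge. The snippet below is an illustrative PyTorch reference, not the sgl_kernel implementation, and the [tokens, heads, head_dim] / [tokens, heads] layouts are assumptions made for the example.

import torch

def merge_attn_states_reference(o_a, lse_a, o_b, lse_b):
    """Merge two partial attention outputs computed over disjoint key sets.

    o_a, o_b:     [num_tokens, num_heads, head_dim] partial outputs
    lse_a, lse_b: [num_tokens, num_heads] log-sum-exp of each part's attention logits
    Returns the merged output and merged log-sum-exp.
    """
    lse_max = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse_max)  # renormalized weight of part A
    w_b = torch.exp(lse_b - lse_max)  # renormalized weight of part B
    denom = (w_a + w_b).unsqueeze(-1)
    o = (o_a * w_a.unsqueeze(-1) + o_b * w_b.unsqueeze(-1)) / denom
    lse = lse_max + torch.log(w_a + w_b)
    return o, lse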
@@ -301,6 +308,8 @@ class FlashAttentionBackend(AttentionBackend):
  ), "Sliding window and cross attention are not supported together"

  self.forward_metadata: FlashAttentionMetadata = None
+ # extra metdata for handling speculative decoding topk > 1, extended draft decode and verify
+ self.forward_metadata_spec_decode_expand: FlashAttentionMetadata = None
  self.max_context_len = model_runner.model_config.context_len
  self.device = model_runner.device
  self.decode_cuda_graph_metadata = {}
@@ -311,8 +320,7 @@ class FlashAttentionBackend(AttentionBackend):
  self.page_size = model_runner.page_size
  self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
  self.skip_prefill = skip_prefill
-
- self.topk = topk
+ self.topk = model_runner.server_args.speculative_eagle_topk or 0
  self.speculative_num_steps = speculative_num_steps
  self.speculative_num_draft_tokens = (
  model_runner.server_args.speculative_num_draft_tokens
@@ -336,14 +344,107 @@ class FlashAttentionBackend(AttentionBackend):
  if forward_batch.forward_mode.is_decode_or_idle():
  # Draft Decode
  if forward_batch.spec_info is not None:
+ if self.topk <= 1:
+ metadata.cache_seqlens_int32 = (
+ seqlens_in_batch + (self.speculative_step_id + 1)
+ ).to(torch.int32)
+ metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+ self.speculative_step_id + 1
+ )
+ metadata.cu_seqlens_q = torch.arange(
+ 0, batch_size + 1, dtype=torch.int32, device=device
+ )
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+ forward_batch.req_pool_indices, : metadata.max_seq_len_k
+ ]
+ else:
+ metadata.cache_seqlens_int32 = (seqlens_in_batch).to(torch.int32)
+ metadata.max_seq_len_q = self.topk
+ metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+ metadata.cu_seqlens_q = torch.arange(
+ 0,
+ batch_size * self.topk + 1,
+ step=self.topk,
+ dtype=torch.int32,
+ device=device,
+ )
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+ forward_batch.req_pool_indices, : metadata.max_seq_len_k
+ ]
+
+ metadata_expand = FlashAttentionMetadata()
+ decode_length = self.speculative_step_id + 1
+ metadata_expand.cache_seqlens_int32 = torch.full(
+ (seqlens_in_batch.numel() * self.topk,),
+ decode_length,
+ device=device,
+ dtype=torch.int32,
+ )
+ metadata_expand.max_seq_len_q = 1
+ metadata_expand.max_seq_len_k = self.speculative_step_id + 1
+ metadata_expand.cu_seqlens_q = torch.arange(
+ 0,
+ metadata_expand.cache_seqlens_int32.numel() + 1,
+ dtype=torch.int32,
+ device=device,
+ )
+ metadata_expand.cu_seqlens_k = torch.arange(
+ 0,
+ metadata_expand.cache_seqlens_int32.numel() * decode_length + 1,
+ step=decode_length,
+ dtype=torch.int32,
+ device=device,
+ )
+ cache_loc = forward_batch.out_cache_loc.view(
+ self.speculative_num_steps, -1
+ ).T.contiguous()
+ metadata_expand.page_table = (
+ cache_loc[:, :decode_length].contiguous().to(torch.int32)
+ )
+ self.forward_metadata_spec_decode_expand = metadata_expand
+ else:
+ # Normal Decode
+ metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+ metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+ metadata.cu_seqlens_q = torch.arange(
+ 0, batch_size + 1, dtype=torch.int32, device=device
+ )
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+ )
+ metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+ forward_batch.req_pool_indices, : metadata.max_seq_len_k
+ ]
+ # TODO: we need to test this part for llama 4 eagle case
+ self._init_local_attn_metadata(metadata, device)
+ elif forward_batch.forward_mode.is_target_verify():
+ if self.topk <= 1:
  metadata.cache_seqlens_int32 = (
- seqlens_in_batch + (self.speculative_step_id + 1)
+ forward_batch.seq_lens + self.speculative_num_draft_tokens
  ).to(torch.int32)
- metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
- self.speculative_step_id + 1
+ metadata.max_seq_len_q = self.speculative_num_draft_tokens
+ metadata.max_seq_len_k = (
+ forward_batch.seq_lens_cpu.max().item()
+ + self.speculative_num_draft_tokens
  )
  metadata.cu_seqlens_q = torch.arange(
- 0, batch_size + 1, dtype=torch.int32, device=device
+ 0,
+ batch_size * self.speculative_num_draft_tokens + 1,
+ self.speculative_num_draft_tokens,
+ dtype=torch.int32,
+ device=device,
  )
  metadata.cu_seqlens_k = torch.nn.functional.pad(
  torch.cumsum(
@@ -357,44 +458,101 @@ class FlashAttentionBackend(AttentionBackend):

  self._init_local_attn_metadata(metadata, device)
  else:
- # Normal Decode
- metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+ metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
+ metadata.max_seq_len_q = self.speculative_num_draft_tokens
  metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
  metadata.cu_seqlens_q = torch.arange(
- 0, batch_size + 1, dtype=torch.int32, device=device
+ 0,
+ batch_size * self.speculative_num_draft_tokens + 1,
+ step=self.speculative_num_draft_tokens,
+ dtype=torch.int32,
+ device=device,
  )
  metadata.cu_seqlens_k = torch.nn.functional.pad(
- torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
  )
  metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
  forward_batch.req_pool_indices, : metadata.max_seq_len_k
  ]

- self._init_local_attn_metadata(metadata, device)
- elif forward_batch.forward_mode.is_target_verify():
- metadata.cache_seqlens_int32 = (
- forward_batch.seq_lens + self.speculative_num_draft_tokens
- ).to(torch.int32)
- metadata.max_seq_len_q = self.speculative_num_draft_tokens
- metadata.max_seq_len_k = (
- forward_batch.seq_lens_cpu.max().item()
- + self.speculative_num_draft_tokens
- )
- metadata.cu_seqlens_q = torch.arange(
- 0,
- batch_size * self.speculative_num_draft_tokens + 1,
- self.speculative_num_draft_tokens,
- dtype=torch.int32,
- device=device,
- )
- metadata.cu_seqlens_k = torch.nn.functional.pad(
- torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
- (1, 0),
- )
- metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
- forward_batch.req_pool_indices, : metadata.max_seq_len_k
- ]
+ metadata_expand = FlashAttentionMetadata()

+ metadata_expand.max_seq_len_q = 1
+ metadata_expand.cu_seqlens_q = torch.arange(
+ 0,
+ forward_batch.seq_lens.numel() * self.speculative_num_draft_tokens
+ + 1,
+ dtype=torch.int32,
+ device=device,
+ )
+
+ # create expand page table
+ offsets = torch.arange(
+ self.speculative_num_draft_tokens, device=device
+ ).unsqueeze(
+ 0
+ ) # shape: (1, self.speculative_num_draft_tokens)
+ cols = offsets.expand(
+ forward_batch.seq_lens.numel(), -1
+ ) + forward_batch.seq_lens.unsqueeze(1)
+ cum_len = torch.nn.functional.pad(
+ torch.cumsum(
+ (
+ forward_batch.seq_lens + self.speculative_num_draft_tokens
+ ).repeat_interleave(self.speculative_num_draft_tokens),
+ dim=0,
+ ),
+ (1, 0),
+ )[:-1]
+ mask_extraction_indices = (
+ cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+ + cum_len[:, None]
+ ).view(1, -1)
+ mask = forward_batch.spec_info.custom_mask[
+ mask_extraction_indices
+ ].view(
+ -1, self.speculative_num_draft_tokens
+ ) # (bsz * draft_num, draft_num)
+
+ # shift table indices to avoid padding
+ # non_masked_page_table [[8, 9, 10], mask (display with int format) [[1, 0, 0],
+ # [8, 9, 10], [1, 1, 0],
+ # [8, 9, 10]] [1, 0, 1]]
+ # if masked with padding [[8, 0, 0], our mask without padding [[8, 9, 10],
+ # [8, 9, 0], [8, 9, 10],
+ # [8, 0, 10]] [8, 10, 9]]
+ # note here cache_seqlens_int32 is [1, 2, 2] so extra page indices will be ignored in each row
+ col_indices = offsets.expand(
+ mask.shape[0], self.speculative_num_draft_tokens
+ )
+ # Build keys: if an entry is valid (mask==True), keep its original index;
+ # if not, add self.speculative_num_draft_tokens so that it sorts after all valid entries.
+ keys = torch.where(
+ mask, col_indices, col_indices + self.speculative_num_draft_tokens
+ )
+ _, sort_order = torch.sort(keys, dim=1)
+ non_masked_page_table = (
+ forward_batch.req_to_token_pool.req_to_token[
+ forward_batch.req_pool_indices, :
+ ]
+ .gather(1, cols)
+ .repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+ ) # (bsz, draft_num)
+ metadata_expand.page_table = non_masked_page_table.gather(1, sort_order)
+ metadata_expand.cache_seqlens_int32 = mask.sum(dim=1).to(torch.int32)
+ metadata_expand.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(
+ metadata_expand.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ metadata_expand.max_seq_len_k = (
+ metadata_expand.cache_seqlens_int32.max().item()
+ )
+ self.forward_metadata_spec_decode_expand = metadata_expand
  elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
  metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
  metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
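The torch.where + torch.sort step above packs the unmasked draft-token page indices to the front of each row, so the per-row cache_seqlens can simply truncate the tail instead of passing an explicit mask into the kernel. A minimal, self-contained illustration of the same trick (the numbers mirror the example in the comments; shapes are made up for the demo):

import torch

draft_num = 3
# One row per (request, draft token); True marks the draft positions it may attend to.
mask = torch.tensor([[True, False, False],
                     [True, True, False],
                     [True, False, True]])
# Page indices of this request's draft tokens, repeated for each row.
non_masked_page_table = torch.tensor([[8, 9, 10]] * 3)

col_indices = torch.arange(draft_num).expand(mask.shape[0], draft_num)
# Valid entries keep their column index; invalid ones sort after all valid ones.
keys = torch.where(mask, col_indices, col_indices + draft_num)
_, sort_order = torch.sort(keys, dim=1)

page_table = non_masked_page_table.gather(1, sort_order)   # [[8, 9, 10], [8, 9, 10], [8, 10, 9]]
cache_seqlens = mask.sum(dim=1).to(torch.int32)            # [1, 2, 2]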
@@ -465,6 +623,9 @@ class FlashAttentionBackend(AttentionBackend):
  layer: RadixAttention,
  forward_batch: ForwardBatch,
  save_kv_cache=True,
+ # For multi-head latent attention
+ q_rope: Optional[torch.Tensor] = None,
+ k_rope: Optional[torch.Tensor] = None,
  ):
  if k is not None:
  assert v is not None
@@ -479,11 +640,11 @@ class FlashAttentionBackend(AttentionBackend):
  layer, cache_loc, k, v, layer.k_scale, layer.v_scale
  )
  else:
- forward_batch.token_to_kv_pool.set_kv_buffer(
+ forward_batch.token_to_kv_pool.set_mla_kv_buffer(
  layer,
  cache_loc,
  k,
- v,
+ k_rope,
  )

  # Use precomputed metadata across all layers
@@ -514,6 +675,11 @@ class FlashAttentionBackend(AttentionBackend):
  and (hasattr(layer, "use_irope") and layer.use_irope)
  )

+ # We do cascade attention for Target Verify with topk > 1
+ use_cascade_attn = (
+ forward_batch.forward_mode.is_target_verify() and self.topk > 1
+ )
+
  # Get the appropriate page table based on whether we're using local attention
  if use_local_attn:
  local_metadata = metadata.local_attn_metadata
@@ -548,7 +714,7 @@ class FlashAttentionBackend(AttentionBackend):
  cu_seqlens_k = metadata.encoder_cu_seqlens_k
  window_size = (-1, -1)

- o = flash_attn_with_kvcache(
+ result = flash_attn_with_kvcache(
  q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
  k_cache=key_cache,
  v_cache=value_cache,
@@ -558,13 +724,41 @@ class FlashAttentionBackend(AttentionBackend):
  cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
  max_seqlen_q=max_seqlen_q,
  softmax_scale=layer.scaling,
- causal=causal,
+ causal=False if use_cascade_attn else causal,
  window_size=window_size,
  softcap=layer.logit_cap,
  k_descale=k_descale,
  v_descale=v_descale,
+ return_softmax_lse=use_cascade_attn,
  )
- return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+ if use_cascade_attn:
+ o, softmax_lse, *rest = result
+ o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
+ q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+ k_cache=key_cache,
+ v_cache=value_cache,
+ page_table=self.forward_metadata_spec_decode_expand.page_table,
+ cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
+ cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
+ cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
+ max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
+ softmax_scale=layer.scaling,
+ causal=False,
+ window_size=window_size,
+ softcap=layer.logit_cap,
+ k_descale=k_descale,
+ v_descale=v_descale,
+ return_softmax_lse=True,
+ )
+ o, _ = merge_state_v2_wrapper(
+ o,
+ softmax_lse.T.contiguous(),
+ o_expand,
+ softmax_lse_expand.T.contiguous(),
+ )
+ else:
+ o = result
  else:
  if (
  not global_server_args_dict["disable_chunked_prefix_cache"]
@@ -624,10 +818,17 @@ class FlashAttentionBackend(AttentionBackend):
  c_kv_cache = c_kv.view(
  -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
  )
- q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
- q_nope = q_all[:, :, : layer.v_head_dim]
- q_rope = q_all[:, :, layer.v_head_dim :]
- o = flash_attn_with_kvcache(
+ if q_rope is not None:
+ q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+ q_rope = q_rope.view(
+ -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+ )
+ else:
+ q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+ q_nope = q_all[:, :, : layer.v_head_dim]
+ q_rope = q_all[:, :, layer.v_head_dim :]
+
+ result = flash_attn_with_kvcache(
  q=q_rope,
  k_cache=k_rope_cache,
  v_cache=c_kv_cache,
@@ -638,13 +839,44 @@ class FlashAttentionBackend(AttentionBackend):
  cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
  max_seqlen_q=max_seqlen_q,
  softmax_scale=layer.scaling,
- causal=True,
+ causal=False if use_cascade_attn else causal,
  softcap=layer.logit_cap,
  k_descale=k_descale,
  v_descale=v_descale,
+ return_softmax_lse=use_cascade_attn,
  )
+ if use_cascade_attn:
+ o, softmax_lse, *rest = result
+ o_expand, softmax_lse_expand, *rest_expand = (
+ flash_attn_with_kvcache(
+ q=q_rope,
+ k_cache=k_rope_cache,
+ v_cache=c_kv_cache,
+ qv=q_nope,
+ page_table=self.forward_metadata_spec_decode_expand.page_table,
+ cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
+ cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
+ cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
+ max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
+ softmax_scale=layer.scaling,
+ causal=False,
+ window_size=window_size,
+ softcap=layer.logit_cap,
+ k_descale=k_descale,
+ v_descale=v_descale,
+ return_softmax_lse=True,
+ )
+ )
+ o, _ = merge_state_v2_wrapper(
+ o,
+ softmax_lse.T.contiguous(),
+ o_expand,
+ softmax_lse_expand.T.contiguous(),
+ )
+ else:
+ o = result

- return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+ return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

  def forward_decode(
  self,
@@ -654,6 +886,9 @@ class FlashAttentionBackend(AttentionBackend):
  layer: RadixAttention,
  forward_batch: ForwardBatch,
  save_kv_cache=True,
+ # For multi-head latent attention
+ q_rope: Optional[torch.Tensor] = None,
+ k_rope: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
  if k is not None:
  assert v is not None
@@ -668,11 +903,11 @@ class FlashAttentionBackend(AttentionBackend):
  layer, cache_loc, k, v, layer.k_scale, layer.v_scale
  )
  else:
- forward_batch.token_to_kv_pool.set_kv_buffer(
+ forward_batch.token_to_kv_pool.set_mla_kv_buffer(
  layer,
  cache_loc,
  k,
- v,
+ k_rope,
  )

  # Use precomputed metadata across all layers
@@ -681,6 +916,8 @@ class FlashAttentionBackend(AttentionBackend):
  use_local_attention = (
  self.attention_chunk_size is not None and local_attn_metadata is not None
  )
+ # We do cascade attention for Draft Decode with topk > 1
+ use_cascade_attn = self.topk > 1

  # Calculate window size (can be moved to metadata if layer properties don't change)
  # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
@@ -752,23 +989,61 @@ class FlashAttentionBackend(AttentionBackend):
  v_descale=v_descale,
  )
  else:
+ page_table = metadata.page_table
+ cache_seqlens = metadata.cache_seqlens_int32
+ cu_seqlens_k = metadata.cu_seqlens_k
+ max_seqlen_q = metadata.max_seq_len_q
+ q_reshaped = q.contiguous().view(
+ -1, layer.tp_q_head_num, layer.head_dim
+ )
+
  # Default: single-token self-attention
- o = flash_attn_with_kvcache(
- q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+ result = flash_attn_with_kvcache(
+ q=q_reshaped,
  k_cache=key_cache,
  v_cache=value_cache,
- page_table=metadata.page_table,
- cache_seqlens=metadata.cache_seqlens_int32,
+ page_table=page_table,
+ cache_seqlens=cache_seqlens,
  cu_seqlens_q=metadata.cu_seqlens_q,
- cu_seqlens_k_new=metadata.cu_seqlens_k,
- max_seqlen_q=1,
+ cu_seqlens_k_new=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_q,
  softmax_scale=layer.scaling,
- causal=True,
+ causal=False if use_cascade_attn else causal,
  window_size=window_size,
  softcap=layer.logit_cap,
  k_descale=k_descale,
  v_descale=v_descale,
+ return_softmax_lse=use_cascade_attn,
  )
+ if use_cascade_attn:
+ o, softmax_lse, *rest = result
+ o_expand, softmax_lse_expand, *rest_expand = (
+ flash_attn_with_kvcache(
+ q=q_reshaped,
+ k_cache=key_cache,
+ v_cache=value_cache,
+ page_table=self.forward_metadata_spec_decode_expand.page_table,
+ cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
+ cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
+ cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
+ max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
+ softmax_scale=layer.scaling,
+ causal=False,
+ window_size=window_size,
+ softcap=layer.logit_cap,
+ k_descale=k_descale,
+ v_descale=v_descale,
+ return_softmax_lse=True,
+ )
+ )
+ o, _ = merge_state_v2(
+ o,
+ softmax_lse.T.contiguous(),
+ o_expand,
+ softmax_lse_expand.T.contiguous(),
+ )
+ else:
+ o = result
  else:
  # Do absorbed multi-latent attention
  kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
@@ -784,11 +1059,18 @@ class FlashAttentionBackend(AttentionBackend):
  -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
  )

- q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
- q_nope = q_all[:, :, : layer.v_head_dim]
- q_rope = q_all[:, :, layer.v_head_dim :]
+ if q_rope is not None:
+ q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+ q_rope = q_rope.view(
+ -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+ )
+ else:
+ q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+ q_nope = q_all[:, :, : layer.v_head_dim]
+ q_rope = q_all[:, :, layer.v_head_dim :]
+ max_seqlen_q = metadata.max_seq_len_q

- o = flash_attn_with_kvcache(
+ result = flash_attn_with_kvcache(
  q=q_rope,
  k_cache=k_rope_cache,
  v_cache=c_kv_cache,
@@ -797,13 +1079,43 @@ class FlashAttentionBackend(AttentionBackend):
  cache_seqlens=metadata.cache_seqlens_int32,
  cu_seqlens_q=metadata.cu_seqlens_q,
  cu_seqlens_k_new=metadata.cu_seqlens_k,
- max_seqlen_q=1,
+ max_seqlen_q=max_seqlen_q,
  softmax_scale=layer.scaling,
- causal=True,
+ causal=False if use_cascade_attn else causal,
  softcap=layer.logit_cap,
  k_descale=k_descale,
  v_descale=v_descale,
+ return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states
  )
+ if use_cascade_attn:
+ o, softmax_lse, *rest = result
+ o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
+ q=q_rope,
+ k_cache=k_rope_cache,
+ v_cache=c_kv_cache,
+ qv=q_nope,
+ page_table=self.forward_metadata_spec_decode_expand.page_table,
+ cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
+ cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
+ cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
+ max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
+ softmax_scale=layer.scaling,
+ causal=False,
+ window_size=window_size,
+ softcap=layer.logit_cap,
+ k_descale=k_descale,
+ v_descale=v_descale,
+ return_softmax_lse=True,
+ )
+ o, _ = merge_state_v2(
+ o,
+ softmax_lse.T.contiguous(),
+ o_expand,
+ softmax_lse_expand.T.contiguous(),
+ )
+ else:
+ o = result
+
  return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

  def init_cuda_graph_state(self, max_bs: int):
@@ -815,6 +1127,8 @@ class FlashAttentionBackend(AttentionBackend):
  This creates fixed-size tensors that will be reused during CUDA graph replay
  to avoid memory allocations.
  """
+
+ # This is being used by normal decode and draft decode when topk == 1
  self.decode_cuda_graph_metadata = {
  "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
  "cu_seqlens_q": torch.arange(
@@ -840,24 +1154,136 @@ class FlashAttentionBackend(AttentionBackend):
  ),
  }

- self.target_verify_metadata = {
- "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
- "cu_seqlens_q": torch.zeros(
- max_bs + 1, dtype=torch.int32, device=self.device
- ),
- "cu_seqlens_k": torch.zeros(
- max_bs + 1, dtype=torch.int32, device=self.device
- ),
- "page_table": torch.zeros(
- max_bs,
- (self.max_context_len + self.page_size - 1) // self.page_size,
- dtype=torch.int32,
- device=self.device,
- ),
- "strided_indices": torch.arange(
- 0, self.max_context_len, self.page_size, device=self.device
- ),
- }
+ # This is used by draft decode's first half of metadata when topk > 1
+ if self.topk > 1:
+ self.draft_decode_metadata_topk_normal = {
+ "cache_seqlens": torch.zeros(
+ max_bs, dtype=torch.int32, device=self.device
+ ),
+ "cu_seqlens_q": torch.arange(
+ 0,
+ max_bs * self.topk + 1,
+ step=self.topk,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "cu_seqlens_k": torch.zeros(
+ max_bs + 1, dtype=torch.int32, device=self.device
+ ),
+ "page_table": torch.zeros(
+ max_bs,
+ self.max_context_len,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ }
+
+ # This is used by draft decode's second half of metadata when topk > 1
+ decode_length = self.speculative_step_id + 1
+ self.draft_decode_metadata_topk_expand = {
+ "cache_seqlens": torch.full(
+ (max_bs * self.topk,),
+ decode_length,
+ device=self.device,
+ dtype=torch.int32,
+ ),
+ "cu_seqlens_q": torch.arange(
+ 0,
+ max_bs * self.topk + 1,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "cu_seqlens_k": torch.arange(
+ 0,
+ max_bs * self.topk * decode_length + 1,
+ step=decode_length,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "page_table": torch.zeros(
+ max_bs * self.topk,
+ decode_length,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ }
+
+ if (
+ self.speculative_num_draft_tokens is not None
+ and self.speculative_num_draft_tokens > 0
+ ):
+ self.target_verify_metadata = {
+ "cache_seqlens": torch.zeros(
+ max_bs, dtype=torch.int32, device=self.device
+ ),
+ "cu_seqlens_q": torch.arange(
+ 0,
+ max_bs * self.speculative_num_draft_tokens + 1,
+ step=self.speculative_num_draft_tokens,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "cu_seqlens_k": torch.zeros(
+ max_bs + 1, dtype=torch.int32, device=self.device
+ ),
+ "page_table": torch.zeros(
+ max_bs,
+ (self.max_context_len + self.page_size - 1) // self.page_size,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "strided_indices": torch.arange(
+ 0, self.max_context_len, self.page_size, device=self.device
+ ),
+ }
+
+ if self.topk > 1:
+ self.target_verify_metadata_topk_normal = {
+ "cache_seqlens": torch.zeros(
+ max_bs, dtype=torch.int32, device=self.device
+ ),
+ "cu_seqlens_q": torch.arange(
+ 0,
+ max_bs * self.speculative_num_draft_tokens + 1,
+ step=self.speculative_num_draft_tokens,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "cu_seqlens_k": torch.zeros(
+ max_bs + 1, dtype=torch.int32, device=self.device
+ ),
+ "page_table": torch.zeros(
+ max_bs,
+ self.max_context_len,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ }
+
+ self.target_verify_metadata_topk_expand = {
+ "cache_seqlens": torch.zeros(
+ max_bs * self.speculative_num_draft_tokens,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "cu_seqlens_k": torch.zeros(
+ max_bs * self.speculative_num_draft_tokens + 1,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "cu_seqlens_q": torch.arange(
+ 0,
+ max_bs * self.speculative_num_draft_tokens + 1,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ "page_table": torch.zeros(
+ max_bs * self.speculative_num_draft_tokens,
+ self.speculative_num_draft_tokens,
+ dtype=torch.int32,
+ device=self.device,
+ ),
+ }

  self.encoder_metadata = {
  "encoder_page_table": torch.zeros(
@@ -886,28 +1312,78 @@ class FlashAttentionBackend(AttentionBackend):
  ):
  """Initialize forward metadata for capturing CUDA graph."""
  metadata = FlashAttentionMetadata()
+
+ # metadata_expand is needed for Spec Decoding when top k > 1
+ metadata_expand = FlashAttentionMetadata()
+
  device = seq_lens.device
  if forward_mode.is_decode_or_idle():
  if spec_info is not None:
  # Draft Decode
- metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
- "cache_seqlens"
- ][:bs]
- metadata.max_seq_len_k = seq_lens.max().item() + (
- self.speculative_step_id + 1
- )
- metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
- : bs + 1
- ]
- metadata.cu_seqlens_k = torch.nn.functional.pad(
- torch.cumsum(
- metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
- ),
- (1, 0),
- )
- metadata.page_table = self.decode_cuda_graph_metadata[
- "page_table_draft_decode"
- ][req_pool_indices, :]
+ if self.topk <= 1:
+ # When topk = 1, we use the normal decode metadata
+ metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
+ "cache_seqlens"
+ ][:bs]
+ metadata.max_seq_len_k = seq_lens.max().item() + (
+ self.speculative_step_id + 1
+ )
+ metadata.cu_seqlens_q = self.decode_cuda_graph_metadata[
+ "cu_seqlens_q"
+ ][: bs + 1]
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ metadata.page_table = self.decode_cuda_graph_metadata[
+ "page_table_draft_decode"
+ ][req_pool_indices, :]
+ self.decode_cuda_graph_metadata[bs] = metadata
+ else:
+ # When top k > 1, we need two specific draft decode metadata, and then merge states
+ # 1. The first half of metadata for prefix tokens
+ metadata.cache_seqlens_int32 = (
+ self.draft_decode_metadata_topk_normal["cache_seqlens"][:bs]
+ )
+ metadata.max_seq_len_q = self.topk
+ metadata.max_seq_len_k = seq_lens.max().item()
+ metadata.cu_seqlens_q = self.draft_decode_metadata_topk_normal[
+ "cu_seqlens_q"
+ ][: bs + 1]
+ metadata.cu_seqlens_k = self.draft_decode_metadata_topk_normal[
+ "cu_seqlens_k"
+ ][: bs + 1]
+ metadata.page_table = self.draft_decode_metadata_topk_normal[
+ "page_table"
+ ][req_pool_indices, :]
+
+ # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
+ metadata_expand.cache_seqlens_int32 = (
+ self.draft_decode_metadata_topk_expand["cache_seqlens"][
+ : bs * self.topk
+ ]
+ )
+ metadata_expand.max_seq_len_q = 1
+ metadata_expand.max_seq_len_k = (
+ self.speculative_step_id + 1
+ ) # , do this in replay
+ metadata_expand.cu_seqlens_q = (
+ self.draft_decode_metadata_topk_expand["cu_seqlens_q"][
+ : bs * self.topk + 1
+ ]
+ )
+ metadata_expand.cu_seqlens_k = (
+ self.draft_decode_metadata_topk_expand["cu_seqlens_k"][
+ : bs * self.topk + 1
+ ]
+ )
+ metadata_expand.page_table = self.draft_decode_metadata_topk_expand[
+ "page_table"
+ ][: bs * self.topk]
+ self.draft_decode_metadata_topk_normal[bs] = metadata
+ self.draft_decode_metadata_topk_expand[bs] = metadata_expand
  else:
  # Normal Decode
  # Get sequence information
@@ -927,37 +1403,77 @@ class FlashAttentionBackend(AttentionBackend):
  metadata.cu_seqlens_q = torch.arange(
  0, batch_size + 1, dtype=torch.int32, device=device
  )
- self.decode_cuda_graph_metadata[bs] = metadata
+ self.decode_cuda_graph_metadata[bs] = metadata
+
  elif forward_mode.is_target_verify():
- metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
- :bs
- ]
- metadata.cache_seqlens_int32.copy_(
- (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
- )
+ if self.topk <= 1:
+ metadata.cache_seqlens_int32 = self.target_verify_metadata[
+ "cache_seqlens"
+ ][:bs]
+ metadata.cache_seqlens_int32.copy_(
+ (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+ )

- metadata.max_seq_len_q = self.speculative_num_draft_tokens
- metadata.max_seq_len_k = (
- seq_lens.max().item() + self.speculative_num_draft_tokens
- )
+ metadata.max_seq_len_q = self.speculative_num_draft_tokens
+ metadata.max_seq_len_k = (
+ seq_lens.max().item() + self.speculative_num_draft_tokens
+ )

- metadata.cu_seqlens_q = torch.arange(
- 0,
- bs * self.speculative_num_draft_tokens + 1,
- self.speculative_num_draft_tokens,
- dtype=torch.int32,
- device=device,
- )
+ metadata.cu_seqlens_q = torch.arange(
+ 0,
+ bs * self.speculative_num_draft_tokens + 1,
+ self.speculative_num_draft_tokens,
+ dtype=torch.int32,
+ device=device,
+ )

- metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
- : (bs + 1)
- ]
+ metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
+ : (bs + 1)
+ ]

- metadata.page_table = self.target_verify_metadata["page_table"][
- req_pool_indices, :
- ]
+ metadata.page_table = self.target_verify_metadata["page_table"][
+ req_pool_indices, :
+ ]
+
+ self.target_verify_metadata[bs] = metadata
+ else:
+ # When topk > 1, we need two specific target verify metadata, and then merge states
+ # 1. The first half of metadata for prefix tokens
+ metadata.cache_seqlens_int32 = self.target_verify_metadata_topk_normal[
+ "cache_seqlens"
+ ][:bs]
+ metadata.max_seq_len_q = self.speculative_num_draft_tokens
+ # metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item(), do this in replay
+ metadata.cu_seqlens_q = self.target_verify_metadata_topk_normal[
+ "cu_seqlens_q"
+ ][: bs + 1]
+ metadata.cu_seqlens_k = self.target_verify_metadata_topk_normal[
+ "cu_seqlens_k"
+ ][: bs + 1]
+ metadata.page_table = self.target_verify_metadata_topk_normal[
+ "page_table"
+ ][req_pool_indices, :]
+
+ # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
+ metadata_expand.cache_seqlens_int32 = (
+ self.target_verify_metadata_topk_expand["cache_seqlens"][
+ : bs * self.speculative_num_draft_tokens
+ ]
+ )
+ metadata_expand.max_seq_len_q = 1
+ metadata_expand.cu_seqlens_q = self.target_verify_metadata_topk_expand[
+ "cu_seqlens_q"
+ ][: bs * self.speculative_num_draft_tokens + 1]
+ metadata_expand.cu_seqlens_k = self.target_verify_metadata_topk_expand[
+ "cu_seqlens_k"
+ ][: bs * self.speculative_num_draft_tokens + 1]
+
+ metadata_expand.page_table = self.target_verify_metadata_topk_expand[
+ "page_table"
+ ][: bs * self.speculative_num_draft_tokens]

- self.target_verify_metadata[bs] = metadata
+ self.target_verify_metadata_topk_normal[bs] = metadata
+ self.target_verify_metadata_topk_expand[bs] = metadata_expand

  if encoder_lens is not None:
  encoder_bs = encoder_lens.numel()
@@ -973,6 +1489,7 @@ class FlashAttentionBackend(AttentionBackend):
  ]

  self.forward_metadata = metadata
+ self.forward_metadata_spec_decode_expand = metadata_expand

  def init_forward_metadata_replay_cuda_graph(
  self,
@@ -986,41 +1503,85 @@ class FlashAttentionBackend(AttentionBackend):
  seq_lens_cpu: Optional[torch.Tensor],
  out_cache_loc: torch.Tensor = None,
  ):
- # """Initialize forward metadata for replaying CUDA graph."""
+ """Initialize forward metadata for replaying CUDA graph."""
  seq_lens = seq_lens[:bs]
  seq_lens_cpu = seq_lens_cpu[:bs]
  req_pool_indices = req_pool_indices[:bs]
  device = seq_lens.device
+ metadata = None
+ metadata_expand = None

  if forward_mode.is_decode_or_idle():
- metadata = self.decode_cuda_graph_metadata[bs]

  if spec_info is not None:
  # Draft Decode
- metadata.cache_seqlens_int32.copy_(
- (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
- )
+ if self.topk <= 1:
+ metadata = self.decode_cuda_graph_metadata[bs]
+ # When topk = 1, we use the normal decode metadata
+ metadata.cache_seqlens_int32.copy_(
+ (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
+ )

- metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
- self.speculative_step_id + 1
- )
- metadata.cu_seqlens_k.copy_(
- torch.nn.functional.pad(
- torch.cumsum(
- metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
- ),
- (1, 0),
+ metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
+ self.speculative_step_id + 1
+ )
+ metadata.cu_seqlens_k.copy_(
+ torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
  )
- )

- page_table = self.req_to_token[
- req_pool_indices, : metadata.max_seq_len_k
- ]
+ max_seq_pages = (
+ metadata.max_seq_len_k + self.page_size - 1
+ ) // self.page_size
+ page_indices = self.req_to_token[
+ req_pool_indices[:, None],
+ self.decode_cuda_graph_metadata["strided_indices"][
+ :max_seq_pages
+ ],
+ ]
+
+ page_indices //= self.page_size
+ metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+ else:
+ # When top k > 1, we need two specific draft decode metadata, and then merge states
+ # 1. The first half of metadata for prefix tokens
+ metadata = self.draft_decode_metadata_topk_normal[bs]
+ metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+ # metadata.max_seq_len_q = self.topk, already set in capture
+ metadata.max_seq_len_k = seq_lens_cpu.max().item()
+ # metadata.cu_seqlens_q already set in capture
+ metadata.cu_seqlens_k.copy_(
+ torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ )

- metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+ page_table = self.req_to_token[
+ req_pool_indices, : metadata.max_seq_len_k
+ ]
+
+ metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

+ # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
+ metadata_expand = self.draft_decode_metadata_topk_expand[bs]
+ decode_length = self.speculative_step_id + 1
+ cache_loc = out_cache_loc.view(
+ self.speculative_num_steps, -1
+ ).T.contiguous()
+ metadata_expand.page_table[: cache_loc.shape[0]].copy_(
+ cache_loc[:, :decode_length].contiguous().to(torch.int32)
+ )
+ # TODO: we need to test this part for llama 4 eagle case
  self._init_local_attn_metadata(metadata, device)
  else:
+ metadata = self.decode_cuda_graph_metadata[bs]
  # Normal Decode
  max_len = seq_lens_cpu.max().item()
  metadata.max_seq_len_k = max_len
@@ -1045,24 +1606,117 @@ class FlashAttentionBackend(AttentionBackend):

  self._init_local_attn_metadata(metadata, device)
  elif forward_mode.is_target_verify():
- metadata = self.target_verify_metadata[bs]
- metadata.cache_seqlens_int32.copy_(
- (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
- )
+ if self.topk <= 1:
+ metadata = self.target_verify_metadata[bs]
+ metadata.cache_seqlens_int32.copy_(
+ (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+ )

- metadata.max_seq_len_k = (
- seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
- )
- metadata.cu_seqlens_k.copy_(
- torch.nn.functional.pad(
+ metadata.max_seq_len_k = (
+ seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
+ )
+ metadata.cu_seqlens_k.copy_(
+ torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ )
+ max_seq_pages = (
+ metadata.max_seq_len_k + self.page_size - 1
+ ) // self.page_size
+ page_indices = self.req_to_token[
+ req_pool_indices[:, None],
+ self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+ ]
+ page_indices //= self.page_size
+ metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+ else:
+ # When topk > 1, we need two specific target verify metadata, and then merge states
+ # 1. The first half of metadata for prefix tokens
+ metadata = self.target_verify_metadata_topk_normal[bs]
+ metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+ # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
+ metadata.max_seq_len_k = seq_lens_cpu.max().item()
+ # metadata.cu_seqlens_q already set in capture
+ metadata.cu_seqlens_k.copy_(
+ torch.nn.functional.pad(
+ torch.cumsum(
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ ),
+ (1, 0),
+ )
+ )
+ page_table = self.req_to_token[
+ req_pool_indices, : metadata.max_seq_len_k
+ ]
+ metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+
+ # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
+ metadata_expand = self.target_verify_metadata_topk_expand[bs]
+ # metadata_expand.max_seq_len_q = 1, already set in capture
+ # metadata_expand.cu_seqlens_q already set in capture
+
+ offsets = torch.arange(
+ self.speculative_num_draft_tokens, device=device
+ ).unsqueeze(
+ 0
+ ) # shape: (1, self.speculative_num_draft_tokens)
+ cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
+ cum_len = torch.nn.functional.pad(
  torch.cumsum(
- metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+ (
+ seq_lens + self.speculative_num_draft_tokens
+ ).repeat_interleave(self.speculative_num_draft_tokens),
+ dim=0,
  ),
  (1, 0),
+ )[:-1]
+ mask_extraction_indices = (
+ cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+ + cum_len[:, None]
+ ).view(1, -1)
+ # avoid extracting padded seq indices which will be out of boundary
+ mask_extraction_indices[
+ :, spec_info.positions.numel() * self.speculative_num_draft_tokens :
+ ].fill_(0)
+
+ mask = spec_info.custom_mask[mask_extraction_indices].view(
+ -1, self.speculative_num_draft_tokens
+ ) # (bsz * draft_num, draft_num)
+ col_indices = offsets.expand(
+ mask.shape[0], self.speculative_num_draft_tokens
+ )
+ keys = torch.where(
+ mask, col_indices, col_indices + self.speculative_num_draft_tokens
+ )
+ _, sort_order = torch.sort(keys, dim=1)
+
+ non_masked_page_table = (
+ self.req_to_token[req_pool_indices, :]
+ .gather(1, cols)
+ .repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+ ) # (bsz, draft_num)
+ metadata_expand.page_table.copy_(
+ non_masked_page_table.gather(1, sort_order)
+ )
+ metadata_expand.cache_seqlens_int32.copy_(
+ mask.sum(dim=1).to(torch.int32)
+ )
+ metadata_expand.cu_seqlens_k.copy_(
+ torch.nn.functional.pad(
+ torch.cumsum(
+ metadata_expand.cache_seqlens_int32,
+ dim=0,
+ dtype=torch.int32,
+ ),
+ (1, 0),
+ )
+ )
+ metadata_expand.max_seq_len_k = (
+ metadata_expand.cache_seqlens_int32.max().item()
  )
- )
- page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
- metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

  if encoder_lens is not None:
  # Only support encoder size 1 for now
@@ -1089,6 +1743,7 @@ class FlashAttentionBackend(AttentionBackend):
  metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

  self.forward_metadata = metadata
+ self.forward_metadata_spec_decode_expand = metadata_expand

  def get_cuda_graph_seq_len_fill_value(self):
  """Get the fill value for sequence length in CUDA graph."""
@@ -1139,12 +1794,6 @@ class FlashAttentionMultiStepBackend:
  self.model_runner = model_runner
  self.topk = topk
  self.speculative_num_steps = speculative_num_steps
-
- # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
- assert (
- self.topk == 1
- ), "speculative_eagle_topk must be 1 for FlashAttentionMultiStepBackend"
-
  self.attn_backends = []
  for i in range(self.speculative_num_steps):
  self.attn_backends.append(
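Taken together, the topk > 1 speculative paths above split each draft query's key set in two: the committed prefix (covered by the regular metadata and attended without a causal mask) and the handful of draft-token KV entries selected by the expand metadata; the two partial results are then combined with merge_state_v2. A condensed sketch of that control flow, where attend(...) and merge(...) are hypothetical stand-ins for flash_attn_with_kvcache(..., return_softmax_lse=True) and merge_state_v2:

def cascade_attention(q, prefix_kv, draft_kv, attend, merge):
    # Pass 1: every draft query attends to the full committed prefix.
    # All prefix keys precede all draft queries, so no causal mask is needed.
    o_prefix, lse_prefix = attend(q, prefix_kv, causal=False)
    # Pass 2: each draft query attends only to the draft keys chosen for it
    # by the expand metadata (reordered page_table, truncated cache_seqlens).
    o_draft, lse_draft = attend(q, draft_kv, causal=False)
    # The log-sum-exp merge of the two partial results equals full attention
    # over prefix + selected draft keys.
    o, _ = merge(o_prefix, lse_prefix, o_draft, lse_draft)
    return o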