sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +3 -6
  4. sglang/compile_deep_gemm.py +136 -0
  5. sglang/lang/backend/anthropic.py +0 -4
  6. sglang/lang/backend/base_backend.py +1 -1
  7. sglang/lang/backend/openai.py +6 -2
  8. sglang/lang/backend/runtime_endpoint.py +5 -1
  9. sglang/lang/backend/vertexai.py +0 -1
  10. sglang/lang/compiler.py +1 -7
  11. sglang/lang/tracer.py +3 -7
  12. sglang/srt/_custom_ops.py +0 -2
  13. sglang/srt/configs/model_config.py +4 -1
  14. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  15. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  16. sglang/srt/constrained/xgrammar_backend.py +27 -4
  17. sglang/srt/custom_op.py +0 -62
  18. sglang/srt/disaggregation/decode.py +105 -6
  19. sglang/srt/disaggregation/mini_lb.py +74 -9
  20. sglang/srt/disaggregation/mooncake/conn.py +33 -63
  21. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  22. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  23. sglang/srt/disaggregation/nixl/conn.py +622 -0
  24. sglang/srt/disaggregation/prefill.py +137 -17
  25. sglang/srt/disaggregation/utils.py +32 -0
  26. sglang/srt/entrypoints/engine.py +4 -0
  27. sglang/srt/entrypoints/http_server.py +3 -7
  28. sglang/srt/entrypoints/verl_engine.py +7 -5
  29. sglang/srt/function_call_parser.py +60 -0
  30. sglang/srt/layers/activation.py +6 -8
  31. sglang/srt/layers/attention/flashattention_backend.py +883 -209
  32. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  33. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  34. sglang/srt/layers/attention/triton_backend.py +6 -0
  35. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  36. sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
  37. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  38. sglang/srt/layers/dp_attention.py +1 -1
  39. sglang/srt/layers/layernorm.py +20 -5
  40. sglang/srt/layers/linear.py +17 -3
  41. sglang/srt/layers/moe/ep_moe/layer.py +17 -29
  42. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  45. sglang/srt/layers/moe/topk.py +27 -30
  46. sglang/srt/layers/parameter.py +0 -2
  47. sglang/srt/layers/quantization/__init__.py +1 -0
  48. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  49. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  52. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  53. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  54. sglang/srt/layers/quantization/deep_gemm.py +378 -0
  55. sglang/srt/layers/quantization/fp8.py +115 -132
  56. sglang/srt/layers/quantization/fp8_kernel.py +213 -88
  57. sglang/srt/layers/quantization/fp8_utils.py +189 -264
  58. sglang/srt/layers/quantization/gptq.py +13 -7
  59. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/utils.py +5 -11
  62. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  63. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  64. sglang/srt/layers/radix_attention.py +15 -0
  65. sglang/srt/layers/rotary_embedding.py +9 -8
  66. sglang/srt/layers/sampler.py +7 -12
  67. sglang/srt/lora/backend/base_backend.py +18 -2
  68. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  69. sglang/srt/lora/backend/triton_backend.py +1 -1
  70. sglang/srt/lora/layers.py +1 -1
  71. sglang/srt/lora/lora.py +1 -1
  72. sglang/srt/lora/lora_manager.py +1 -1
  73. sglang/srt/managers/data_parallel_controller.py +7 -1
  74. sglang/srt/managers/detokenizer_manager.py +0 -1
  75. sglang/srt/managers/io_struct.py +15 -3
  76. sglang/srt/managers/mm_utils.py +4 -3
  77. sglang/srt/managers/multimodal_processor.py +0 -2
  78. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  79. sglang/srt/managers/schedule_batch.py +15 -4
  80. sglang/srt/managers/scheduler.py +28 -77
  81. sglang/srt/managers/tokenizer_manager.py +116 -29
  82. sglang/srt/managers/tp_worker.py +1 -0
  83. sglang/srt/mem_cache/hiradix_cache.py +41 -29
  84. sglang/srt/mem_cache/memory_pool.py +38 -15
  85. sglang/srt/model_executor/cuda_graph_runner.py +15 -10
  86. sglang/srt/model_executor/model_runner.py +39 -31
  87. sglang/srt/models/bert.py +398 -0
  88. sglang/srt/models/deepseek.py +1 -1
  89. sglang/srt/models/deepseek_nextn.py +74 -70
  90. sglang/srt/models/deepseek_v2.py +292 -348
  91. sglang/srt/models/llama.py +5 -5
  92. sglang/srt/models/minicpm3.py +31 -203
  93. sglang/srt/models/minicpmo.py +17 -6
  94. sglang/srt/models/qwen2.py +4 -1
  95. sglang/srt/models/qwen2_moe.py +14 -13
  96. sglang/srt/models/qwen3.py +335 -0
  97. sglang/srt/models/qwen3_moe.py +423 -0
  98. sglang/srt/openai_api/adapter.py +71 -4
  99. sglang/srt/openai_api/protocol.py +6 -1
  100. sglang/srt/reasoning_parser.py +0 -1
  101. sglang/srt/sampling/sampling_batch_info.py +2 -3
  102. sglang/srt/server_args.py +86 -72
  103. sglang/srt/speculative/build_eagle_tree.py +2 -2
  104. sglang/srt/speculative/eagle_utils.py +2 -2
  105. sglang/srt/speculative/eagle_worker.py +6 -14
  106. sglang/srt/utils.py +62 -6
  107. sglang/test/runners.py +5 -1
  108. sglang/test/test_block_fp8.py +167 -0
  109. sglang/test/test_custom_ops.py +1 -1
  110. sglang/test/test_utils.py +3 -1
  111. sglang/version.py +1 -1
  112. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
  113. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
  114. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
  115. sglang/lang/__init__.py +0 -0
  116. sglang/srt/lora/backend/__init__.py +0 -25
  117. sglang/srt/server.py +0 -18
  118. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@ if TYPE_CHECKING:
16
16
  from sglang.srt.layers.radix_attention import RadixAttention
17
17
  from sglang.srt.model_executor.model_runner import ModelRunner
18
18
 
19
+ from sgl_kernel import merge_state_v2
19
20
  from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
20
21
 
21
22
 
@@ -30,7 +31,7 @@ class FlashAttentionMetadata:
30
31
  # Sequence lengths for the forward batch
31
32
  cache_seqlens_int32: torch.Tensor = None
32
33
  # Maximum sequence length for query
33
- max_seq_len_q: int = 0
34
+ max_seq_len_q: int = 1
34
35
  # Maximum sequence length for key
35
36
  max_seq_len_k: int = 0
36
37
  # Cumulative sequence lengths for query
@@ -142,6 +143,16 @@ def make_local_attention_virtual_batches(
142
143
  seqlens_k_local: Key sequence lengths for local attention
143
144
  block_table_local: Block table for local attention
144
145
  """
146
+ # Adjust attention_chunk_size based on the actual sequence length
147
+ # to avoid index out of bounds errors
148
+ max_seq_len = seq_lens_np.max()
149
+ effective_chunk_size = min(attn_chunk_size, max_seq_len)
150
+ # Make sure effective_chunk_size is divisible by page_size
151
+ effective_chunk_size = (effective_chunk_size // page_size) * page_size
152
+ if effective_chunk_size < page_size:
153
+ effective_chunk_size = page_size
154
+ attn_chunk_size = effective_chunk_size
155
+
145
156
  q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
146
157
  actual_batch_size = seq_lens_np.shape[0]
147
158
 
@@ -257,6 +268,12 @@ def cdiv(a: int, b: int) -> int:
257
268
  return -(a // -b)
258
269
 
259
270
 
271
+ # TODO(hebiao064): remove this once we have a better way to handle the merge_state_v2 torch.compile issue
272
+ @torch._dynamo.disable()
273
+ def merge_state_v2_wrapper(o, s_a, o_exp, s_b):
274
+ return merge_state_v2(o, s_a, o_exp, s_b)
275
+
276
+
260
277
  class FlashAttentionBackend(AttentionBackend):
261
278
  """FlashAttention backend implementation.
262
279
 
@@ -291,6 +308,8 @@ class FlashAttentionBackend(AttentionBackend):
291
308
  ), "Sliding window and cross attention are not supported together"
292
309
 
293
310
  self.forward_metadata: FlashAttentionMetadata = None
311
+ # extra metadata for handling speculative decoding topk > 1, extended draft decode and verify
312
+ self.forward_metadata_spec_decode_expand: FlashAttentionMetadata = None
294
313
  self.max_context_len = model_runner.model_config.context_len
295
314
  self.device = model_runner.device
296
315
  self.decode_cuda_graph_metadata = {}
@@ -299,12 +318,9 @@ class FlashAttentionBackend(AttentionBackend):
299
318
  self.kv_cache_dtype = model_runner.kv_cache_dtype
300
319
  self.kv_cache_dtype_str = model_runner.server_args.kv_cache_dtype
301
320
  self.page_size = model_runner.page_size
302
- self.use_mla = (
303
- model_runner.model_config.attention_arch == AttentionArch.MLA
304
- ) and (not global_server_args_dict["disable_mla"])
321
+ self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
305
322
  self.skip_prefill = skip_prefill
306
-
307
- self.topk = topk
323
+ self.topk = model_runner.server_args.speculative_eagle_topk or 0
308
324
  self.speculative_num_steps = speculative_num_steps
309
325
  self.speculative_num_draft_tokens = (
310
326
  model_runner.server_args.speculative_num_draft_tokens
@@ -328,14 +344,107 @@ class FlashAttentionBackend(AttentionBackend):
328
344
  if forward_batch.forward_mode.is_decode_or_idle():
329
345
  # Draft Decode
330
346
  if forward_batch.spec_info is not None:
347
+ if self.topk <= 1:
348
+ metadata.cache_seqlens_int32 = (
349
+ seqlens_in_batch + (self.speculative_step_id + 1)
350
+ ).to(torch.int32)
351
+ metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
352
+ self.speculative_step_id + 1
353
+ )
354
+ metadata.cu_seqlens_q = torch.arange(
355
+ 0, batch_size + 1, dtype=torch.int32, device=device
356
+ )
357
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
358
+ torch.cumsum(
359
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
360
+ ),
361
+ (1, 0),
362
+ )
363
+ metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
364
+ forward_batch.req_pool_indices, : metadata.max_seq_len_k
365
+ ]
366
+ else:
367
+ metadata.cache_seqlens_int32 = (seqlens_in_batch).to(torch.int32)
368
+ metadata.max_seq_len_q = self.topk
369
+ metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
370
+ metadata.cu_seqlens_q = torch.arange(
371
+ 0,
372
+ batch_size * self.topk + 1,
373
+ step=self.topk,
374
+ dtype=torch.int32,
375
+ device=device,
376
+ )
377
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
378
+ torch.cumsum(
379
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
380
+ ),
381
+ (1, 0),
382
+ )
383
+ metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
384
+ forward_batch.req_pool_indices, : metadata.max_seq_len_k
385
+ ]
386
+
387
+ metadata_expand = FlashAttentionMetadata()
388
+ decode_length = self.speculative_step_id + 1
389
+ metadata_expand.cache_seqlens_int32 = torch.full(
390
+ (seqlens_in_batch.numel() * self.topk,),
391
+ decode_length,
392
+ device=device,
393
+ dtype=torch.int32,
394
+ )
395
+ metadata_expand.max_seq_len_q = 1
396
+ metadata_expand.max_seq_len_k = self.speculative_step_id + 1
397
+ metadata_expand.cu_seqlens_q = torch.arange(
398
+ 0,
399
+ metadata_expand.cache_seqlens_int32.numel() + 1,
400
+ dtype=torch.int32,
401
+ device=device,
402
+ )
403
+ metadata_expand.cu_seqlens_k = torch.arange(
404
+ 0,
405
+ metadata_expand.cache_seqlens_int32.numel() * decode_length + 1,
406
+ step=decode_length,
407
+ dtype=torch.int32,
408
+ device=device,
409
+ )
410
+ cache_loc = forward_batch.out_cache_loc.view(
411
+ self.speculative_num_steps, -1
412
+ ).T.contiguous()
413
+ metadata_expand.page_table = (
414
+ cache_loc[:, :decode_length].contiguous().to(torch.int32)
415
+ )
416
+ self.forward_metadata_spec_decode_expand = metadata_expand
417
+ else:
418
+ # Normal Decode
419
+ metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
420
+ metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
421
+ metadata.cu_seqlens_q = torch.arange(
422
+ 0, batch_size + 1, dtype=torch.int32, device=device
423
+ )
424
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
425
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
426
+ )
427
+ metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
428
+ forward_batch.req_pool_indices, : metadata.max_seq_len_k
429
+ ]
430
+ # TODO: we need to test this part for llama 4 eagle case
431
+ self._init_local_attn_metadata(metadata, device)
432
+ elif forward_batch.forward_mode.is_target_verify():
433
+ if self.topk <= 1:
331
434
  metadata.cache_seqlens_int32 = (
332
- seqlens_in_batch + (self.speculative_step_id + 1)
435
+ forward_batch.seq_lens + self.speculative_num_draft_tokens
333
436
  ).to(torch.int32)
334
- metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
335
- self.speculative_step_id + 1
437
+ metadata.max_seq_len_q = self.speculative_num_draft_tokens
438
+ metadata.max_seq_len_k = (
439
+ forward_batch.seq_lens_cpu.max().item()
440
+ + self.speculative_num_draft_tokens
336
441
  )
337
442
  metadata.cu_seqlens_q = torch.arange(
338
- 0, batch_size + 1, dtype=torch.int32, device=device
443
+ 0,
444
+ batch_size * self.speculative_num_draft_tokens + 1,
445
+ self.speculative_num_draft_tokens,
446
+ dtype=torch.int32,
447
+ device=device,
339
448
  )
340
449
  metadata.cu_seqlens_k = torch.nn.functional.pad(
341
450
  torch.cumsum(
@@ -346,43 +455,104 @@ class FlashAttentionBackend(AttentionBackend):
346
455
  metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
347
456
  forward_batch.req_pool_indices, : metadata.max_seq_len_k
348
457
  ]
458
+
459
+ self._init_local_attn_metadata(metadata, device)
349
460
  else:
350
- # Normal Decode
351
- metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
461
+ metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
462
+ metadata.max_seq_len_q = self.speculative_num_draft_tokens
352
463
  metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
353
464
  metadata.cu_seqlens_q = torch.arange(
354
- 0, batch_size + 1, dtype=torch.int32, device=device
465
+ 0,
466
+ batch_size * self.speculative_num_draft_tokens + 1,
467
+ step=self.speculative_num_draft_tokens,
468
+ dtype=torch.int32,
469
+ device=device,
355
470
  )
356
471
  metadata.cu_seqlens_k = torch.nn.functional.pad(
357
- torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
472
+ torch.cumsum(
473
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
474
+ ),
475
+ (1, 0),
358
476
  )
359
477
  metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
360
478
  forward_batch.req_pool_indices, : metadata.max_seq_len_k
361
479
  ]
362
- elif forward_batch.forward_mode.is_target_verify():
363
- metadata.cache_seqlens_int32 = (
364
- forward_batch.seq_lens + self.speculative_num_draft_tokens
365
- ).to(torch.int32)
366
- metadata.max_seq_len_q = self.speculative_num_draft_tokens
367
- metadata.max_seq_len_k = (
368
- forward_batch.seq_lens_cpu.max().item()
369
- + self.speculative_num_draft_tokens
370
- )
371
- metadata.cu_seqlens_q = torch.arange(
372
- 0,
373
- batch_size * self.speculative_num_draft_tokens + 1,
374
- self.speculative_num_draft_tokens,
375
- dtype=torch.int32,
376
- device=device,
377
- )
378
- metadata.cu_seqlens_k = torch.nn.functional.pad(
379
- torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
380
- (1, 0),
381
- )
382
- metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
383
- forward_batch.req_pool_indices, : metadata.max_seq_len_k
384
- ]
385
480
 
481
+ metadata_expand = FlashAttentionMetadata()
482
+
483
+ metadata_expand.max_seq_len_q = 1
484
+ metadata_expand.cu_seqlens_q = torch.arange(
485
+ 0,
486
+ forward_batch.seq_lens.numel() * self.speculative_num_draft_tokens
487
+ + 1,
488
+ dtype=torch.int32,
489
+ device=device,
490
+ )
491
+
492
+ # create expand page table
493
+ offsets = torch.arange(
494
+ self.speculative_num_draft_tokens, device=device
495
+ ).unsqueeze(
496
+ 0
497
+ ) # shape: (1, self.speculative_num_draft_tokens)
498
+ cols = offsets.expand(
499
+ forward_batch.seq_lens.numel(), -1
500
+ ) + forward_batch.seq_lens.unsqueeze(1)
501
+ cum_len = torch.nn.functional.pad(
502
+ torch.cumsum(
503
+ (
504
+ forward_batch.seq_lens + self.speculative_num_draft_tokens
505
+ ).repeat_interleave(self.speculative_num_draft_tokens),
506
+ dim=0,
507
+ ),
508
+ (1, 0),
509
+ )[:-1]
510
+ mask_extraction_indices = (
511
+ cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
512
+ + cum_len[:, None]
513
+ ).view(1, -1)
514
+ mask = forward_batch.spec_info.custom_mask[
515
+ mask_extraction_indices
516
+ ].view(
517
+ -1, self.speculative_num_draft_tokens
518
+ ) # (bsz * draft_num, draft_num)
519
+
520
+ # shift table indices to avoid padding
521
+ # non_masked_page_table [[8, 9, 10], mask (display with int format) [[1, 0, 0],
522
+ # [8, 9, 10], [1, 1, 0],
523
+ # [8, 9, 10]] [1, 0, 1]]
524
+ # if masked with padding [[8, 0, 0], our mask without padding [[8, 9, 10],
525
+ # [8, 9, 0], [8, 9, 10],
526
+ # [8, 0, 10]] [8, 10, 9]]
527
+ # note here cache_seqlens_int32 is [1, 2, 2] so extra page indices will be ignored in each row
528
+ col_indices = offsets.expand(
529
+ mask.shape[0], self.speculative_num_draft_tokens
530
+ )
531
+ # Build keys: if an entry is valid (mask==True), keep its original index;
532
+ # if not, add self.speculative_num_draft_tokens so that it sorts after all valid entries.
533
+ keys = torch.where(
534
+ mask, col_indices, col_indices + self.speculative_num_draft_tokens
535
+ )
536
+ _, sort_order = torch.sort(keys, dim=1)
537
+ non_masked_page_table = (
538
+ forward_batch.req_to_token_pool.req_to_token[
539
+ forward_batch.req_pool_indices, :
540
+ ]
541
+ .gather(1, cols)
542
+ .repeat_interleave(self.speculative_num_draft_tokens, dim=0)
543
+ ) # (bsz * draft_num, draft_num)
544
+ metadata_expand.page_table = non_masked_page_table.gather(1, sort_order)
545
+ metadata_expand.cache_seqlens_int32 = mask.sum(dim=1).to(torch.int32)
546
+ metadata_expand.cu_seqlens_k = torch.nn.functional.pad(
547
+ torch.cumsum(
548
+ metadata_expand.cache_seqlens_int32, dim=0, dtype=torch.int32
549
+ ),
550
+ (1, 0),
551
+ )
552
+ metadata_expand.max_seq_len_k = (
553
+ metadata_expand.cache_seqlens_int32.max().item()
554
+ )
555
+ self.forward_metadata_spec_decode_expand = metadata_expand
386
556
  elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
387
557
  metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
388
558
  metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
@@ -407,49 +577,8 @@ class FlashAttentionBackend(AttentionBackend):
407
577
  metadata.cu_seqlens_q = metadata.cu_seqlens_k
408
578
 
409
579
  # Setup local attention if enabled
410
- if (
411
- self.attention_chunk_size is not None
412
- and forward_batch.forward_mode == ForwardMode.EXTEND
413
- ):
414
- # Convert tensors to numpy for local attention processing
415
- cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
416
- seq_lens_np = metadata.cache_seqlens_int32.cpu().numpy()
417
-
418
- # Adjust attention_chunk_size based on the actual sequence length
419
- # to avoid index out of bounds errors
420
- max_seq_len = seq_lens_np.max()
421
- effective_chunk_size = min(self.attention_chunk_size, max_seq_len)
422
- # Make sure effective_chunk_size is divisible by page_size
423
- effective_chunk_size = (
424
- effective_chunk_size // self.page_size
425
- ) * self.page_size
426
- if effective_chunk_size < self.page_size:
427
- effective_chunk_size = self.page_size
428
-
429
- # Create local attention metadata
430
- (
431
- seqlens_q_local_np,
432
- cu_seqlens_q_local_np,
433
- seqlens_k_local_np,
434
- block_table_local,
435
- ) = make_local_attention_virtual_batches(
436
- effective_chunk_size,
437
- cu_seqlens_q_np,
438
- seq_lens_np,
439
- metadata.page_table,
440
- self.page_size,
441
- )
442
-
443
- local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
444
- local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(
445
- device
446
- ),
447
- local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
448
- local_block_table=block_table_local,
449
- local_max_query_len=seqlens_q_local_np.max(),
450
- local_max_seq_len=seqlens_k_local_np.max(),
451
- )
452
- metadata.local_attn_metadata = local_metadata
580
+ if forward_batch.forward_mode == ForwardMode.EXTEND:
581
+ self._init_local_attn_metadata(metadata, device)
453
582
 
454
583
  # Encoder metadata for cross attention
455
584
  if forward_batch.encoder_lens is not None:
@@ -543,6 +672,11 @@ class FlashAttentionBackend(AttentionBackend):
543
672
  and (hasattr(layer, "use_irope") and layer.use_irope)
544
673
  )
545
674
 
675
+ # We do cascade attention for Target Verify with topk > 1
676
+ use_cascade_attn = (
677
+ forward_batch.forward_mode.is_target_verify() and self.topk > 1
678
+ )
679
+
546
680
  # Get the appropriate page table based on whether we're using local attention
547
681
  if use_local_attn:
548
682
  local_metadata = metadata.local_attn_metadata
@@ -577,7 +711,7 @@ class FlashAttentionBackend(AttentionBackend):
577
711
  cu_seqlens_k = metadata.encoder_cu_seqlens_k
578
712
  window_size = (-1, -1)
579
713
 
580
- o = flash_attn_with_kvcache(
714
+ result = flash_attn_with_kvcache(
581
715
  q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
582
716
  k_cache=key_cache,
583
717
  v_cache=value_cache,
@@ -587,13 +721,41 @@ class FlashAttentionBackend(AttentionBackend):
587
721
  cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
588
722
  max_seqlen_q=max_seqlen_q,
589
723
  softmax_scale=layer.scaling,
590
- causal=causal,
724
+ causal=False if use_cascade_attn else causal,
591
725
  window_size=window_size,
592
726
  softcap=layer.logit_cap,
593
727
  k_descale=k_descale,
594
728
  v_descale=v_descale,
729
+ return_softmax_lse=use_cascade_attn,
595
730
  )
596
- return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
731
+
732
+ if use_cascade_attn:
733
+ o, softmax_lse, *rest = result
734
+ o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
735
+ q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
736
+ k_cache=key_cache,
737
+ v_cache=value_cache,
738
+ page_table=self.forward_metadata_spec_decode_expand.page_table,
739
+ cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
740
+ cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
741
+ cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
742
+ max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
743
+ softmax_scale=layer.scaling,
744
+ causal=False,
745
+ window_size=window_size,
746
+ softcap=layer.logit_cap,
747
+ k_descale=k_descale,
748
+ v_descale=v_descale,
749
+ return_softmax_lse=True,
750
+ )
751
+ o, _ = merge_state_v2_wrapper(
752
+ o,
753
+ softmax_lse.T.contiguous(),
754
+ o_expand,
755
+ softmax_lse_expand.T.contiguous(),
756
+ )
757
+ else:
758
+ o = result
597
759
  else:
598
760
  if (
599
761
  not global_server_args_dict["disable_chunked_prefix_cache"]
@@ -656,7 +818,8 @@ class FlashAttentionBackend(AttentionBackend):
656
818
  q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
657
819
  q_nope = q_all[:, :, : layer.v_head_dim]
658
820
  q_rope = q_all[:, :, layer.v_head_dim :]
659
- o = flash_attn_with_kvcache(
821
+
822
+ result = flash_attn_with_kvcache(
660
823
  q=q_rope,
661
824
  k_cache=k_rope_cache,
662
825
  v_cache=c_kv_cache,
@@ -667,13 +830,44 @@ class FlashAttentionBackend(AttentionBackend):
667
830
  cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
668
831
  max_seqlen_q=max_seqlen_q,
669
832
  softmax_scale=layer.scaling,
670
- causal=True,
833
+ causal=False if use_cascade_attn else causal,
671
834
  softcap=layer.logit_cap,
672
835
  k_descale=k_descale,
673
836
  v_descale=v_descale,
837
+ return_softmax_lse=use_cascade_attn,
674
838
  )
839
+ if use_cascade_attn:
840
+ o, softmax_lse, *rest = result
841
+ o_expand, softmax_lse_expand, *rest_expand = (
842
+ flash_attn_with_kvcache(
843
+ q=q_rope,
844
+ k_cache=k_rope_cache,
845
+ v_cache=c_kv_cache,
846
+ qv=q_nope,
847
+ page_table=self.forward_metadata_spec_decode_expand.page_table,
848
+ cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
849
+ cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
850
+ cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
851
+ max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
852
+ softmax_scale=layer.scaling,
853
+ causal=False,
854
+ window_size=window_size,
855
+ softcap=layer.logit_cap,
856
+ k_descale=k_descale,
857
+ v_descale=v_descale,
858
+ return_softmax_lse=True,
859
+ )
860
+ )
861
+ o, _ = merge_state_v2_wrapper(
862
+ o,
863
+ softmax_lse.T.contiguous(),
864
+ o_expand,
865
+ softmax_lse_expand.T.contiguous(),
866
+ )
867
+ else:
868
+ o = result
675
869
 
676
- return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
870
+ return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
677
871
 
678
872
  def forward_decode(
679
873
  self,
@@ -706,6 +900,12 @@ class FlashAttentionBackend(AttentionBackend):
706
900
 
707
901
  # Use precomputed metadata across all layers
708
902
  metadata = self.forward_metadata
903
+ local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
904
+ use_local_attention = (
905
+ self.attention_chunk_size is not None and local_attn_metadata is not None
906
+ )
907
+ # We do cascade attention for Draft Decode with topk > 1
908
+ use_cascade_attn = self.topk > 1
709
909
 
710
910
  # Calculate window size (can be moved to metadata if layer properties don't change)
711
911
  # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
@@ -740,33 +940,98 @@ class FlashAttentionBackend(AttentionBackend):
740
940
  -1, self.page_size, layer.tp_v_head_num, layer.head_dim
741
941
  )
742
942
 
743
- q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
744
943
  if layer.is_cross_attention:
745
- page_table = metadata.encoder_page_table
746
- cache_seqlens = metadata.encoder_lens_int32
747
- cu_seqlens_k = metadata.encoder_cu_seqlens_k
748
- window_size = (-1, -1)
944
+ # Always use non-chunked logic for cross-attention
945
+ o = flash_attn_with_kvcache(
946
+ q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
947
+ k_cache=key_cache,
948
+ v_cache=value_cache,
949
+ page_table=metadata.encoder_page_table,
950
+ cache_seqlens=metadata.encoder_lens_int32,
951
+ cu_seqlens_q=metadata.cu_seqlens_q,
952
+ cu_seqlens_k_new=metadata.encoder_cu_seqlens_k,
953
+ max_seqlen_q=1,
954
+ softmax_scale=layer.scaling,
955
+ causal=False,
956
+ window_size=(-1, -1),
957
+ softcap=layer.logit_cap,
958
+ k_descale=k_descale,
959
+ v_descale=v_descale,
960
+ )
961
+ elif use_local_attention:
962
+ # Use chunked (local) attention batching for self-attention
963
+ o = flash_attn_with_kvcache(
964
+ q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
965
+ k_cache=key_cache,
966
+ v_cache=value_cache,
967
+ page_table=local_attn_metadata.local_block_table,
968
+ cache_seqlens=local_attn_metadata.local_seqused_k,
969
+ cu_seqlens_q=local_attn_metadata.local_query_start_loc,
970
+ cu_seqlens_k_new=metadata.cu_seqlens_k,
971
+ max_seqlen_q=local_attn_metadata.local_max_query_len,
972
+ softmax_scale=layer.scaling,
973
+ causal=True,
974
+ window_size=(-1, -1),
975
+ softcap=layer.logit_cap,
976
+ k_descale=k_descale,
977
+ v_descale=v_descale,
978
+ )
749
979
  else:
750
980
  page_table = metadata.page_table
751
981
  cache_seqlens = metadata.cache_seqlens_int32
752
982
  cu_seqlens_k = metadata.cu_seqlens_k
983
+ max_seqlen_q = metadata.max_seq_len_q
984
+ q_reshaped = q.contiguous().view(
985
+ -1, layer.tp_q_head_num, layer.head_dim
986
+ )
753
987
 
754
- o = flash_attn_with_kvcache(
755
- q=q_reshaped,
756
- k_cache=key_cache,
757
- v_cache=value_cache,
758
- page_table=page_table,
759
- cache_seqlens=cache_seqlens,
760
- cu_seqlens_q=metadata.cu_seqlens_q,
761
- cu_seqlens_k_new=cu_seqlens_k,
762
- max_seqlen_q=1,
763
- softmax_scale=layer.scaling,
764
- causal=causal,
765
- window_size=window_size,
766
- softcap=layer.logit_cap,
767
- k_descale=k_descale,
768
- v_descale=v_descale,
769
- )
988
+ # Default: single-token self-attention
989
+ result = flash_attn_with_kvcache(
990
+ q=q_reshaped,
991
+ k_cache=key_cache,
992
+ v_cache=value_cache,
993
+ page_table=page_table,
994
+ cache_seqlens=cache_seqlens,
995
+ cu_seqlens_q=metadata.cu_seqlens_q,
996
+ cu_seqlens_k_new=cu_seqlens_k,
997
+ max_seqlen_q=max_seqlen_q,
998
+ softmax_scale=layer.scaling,
999
+ causal=False if use_cascade_attn else causal,
1000
+ window_size=window_size,
1001
+ softcap=layer.logit_cap,
1002
+ k_descale=k_descale,
1003
+ v_descale=v_descale,
1004
+ return_softmax_lse=use_cascade_attn,
1005
+ )
1006
+ if use_cascade_attn:
1007
+ o, softmax_lse, *rest = result
1008
+ o_expand, softmax_lse_expand, *rest_expand = (
1009
+ flash_attn_with_kvcache(
1010
+ q=q_reshaped,
1011
+ k_cache=key_cache,
1012
+ v_cache=value_cache,
1013
+ page_table=self.forward_metadata_spec_decode_expand.page_table,
1014
+ cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
1015
+ cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
1016
+ cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
1017
+ max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
1018
+ softmax_scale=layer.scaling,
1019
+ causal=False,
1020
+ window_size=window_size,
1021
+ softcap=layer.logit_cap,
1022
+ k_descale=k_descale,
1023
+ v_descale=v_descale,
1024
+ return_softmax_lse=True,
1025
+ )
1026
+ )
1027
+ o, _ = merge_state_v2(
1028
+ o,
1029
+ softmax_lse.T.contiguous(),
1030
+ o_expand,
1031
+ softmax_lse_expand.T.contiguous(),
1032
+ )
1033
+ else:
1034
+ o = result
770
1035
  else:
771
1036
  # Do absorbed multi-latent attention
772
1037
  kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
@@ -785,8 +1050,9 @@ class FlashAttentionBackend(AttentionBackend):
785
1050
  q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
786
1051
  q_nope = q_all[:, :, : layer.v_head_dim]
787
1052
  q_rope = q_all[:, :, layer.v_head_dim :]
1053
+ max_seqlen_q = metadata.max_seq_len_q
788
1054
 
789
- o = flash_attn_with_kvcache(
1055
+ result = flash_attn_with_kvcache(
790
1056
  q=q_rope,
791
1057
  k_cache=k_rope_cache,
792
1058
  v_cache=c_kv_cache,
@@ -795,13 +1061,43 @@ class FlashAttentionBackend(AttentionBackend):
795
1061
  cache_seqlens=metadata.cache_seqlens_int32,
796
1062
  cu_seqlens_q=metadata.cu_seqlens_q,
797
1063
  cu_seqlens_k_new=metadata.cu_seqlens_k,
798
- max_seqlen_q=1,
1064
+ max_seqlen_q=max_seqlen_q,
799
1065
  softmax_scale=layer.scaling,
800
- causal=True,
1066
+ causal=False if use_cascade_attn else causal,
801
1067
  softcap=layer.logit_cap,
802
1068
  k_descale=k_descale,
803
1069
  v_descale=v_descale,
1070
+ return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states
804
1071
  )
1072
+ if use_cascade_attn:
1073
+ o, softmax_lse, *rest = result
1074
+ o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
1075
+ q=q_rope,
1076
+ k_cache=k_rope_cache,
1077
+ v_cache=c_kv_cache,
1078
+ qv=q_nope,
1079
+ page_table=self.forward_metadata_spec_decode_expand.page_table,
1080
+ cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
1081
+ cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
1082
+ cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
1083
+ max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
1084
+ softmax_scale=layer.scaling,
1085
+ causal=False,
1086
+ window_size=window_size,
1087
+ softcap=layer.logit_cap,
1088
+ k_descale=k_descale,
1089
+ v_descale=v_descale,
1090
+ return_softmax_lse=True,
1091
+ )
1092
+ o, _ = merge_state_v2(
1093
+ o,
1094
+ softmax_lse.T.contiguous(),
1095
+ o_expand,
1096
+ softmax_lse_expand.T.contiguous(),
1097
+ )
1098
+ else:
1099
+ o = result
1100
+
805
1101
  return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
806
1102
 
807
1103
  def init_cuda_graph_state(self, max_bs: int):
@@ -813,6 +1109,8 @@ class FlashAttentionBackend(AttentionBackend):
813
1109
  This creates fixed-size tensors that will be reused during CUDA graph replay
814
1110
  to avoid memory allocations.
815
1111
  """
1112
+
1113
+ # This is being used by normal decode and draft decode when topk == 1
816
1114
  self.decode_cuda_graph_metadata = {
817
1115
  "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
818
1116
  "cu_seqlens_q": torch.arange(
@@ -838,24 +1136,136 @@ class FlashAttentionBackend(AttentionBackend):
838
1136
  ),
839
1137
  }
840
1138
 
841
- self.target_verify_metadata = {
842
- "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
843
- "cu_seqlens_q": torch.zeros(
844
- max_bs + 1, dtype=torch.int32, device=self.device
845
- ),
846
- "cu_seqlens_k": torch.zeros(
847
- max_bs + 1, dtype=torch.int32, device=self.device
848
- ),
849
- "page_table": torch.zeros(
850
- max_bs,
851
- (self.max_context_len + self.page_size - 1) // self.page_size,
852
- dtype=torch.int32,
853
- device=self.device,
854
- ),
855
- "strided_indices": torch.arange(
856
- 0, self.max_context_len, self.page_size, device=self.device
857
- ),
858
- }
1139
+ # This is used by draft decode's first half of metadata when topk > 1
1140
+ if self.topk > 1:
1141
+ self.draft_decode_metadata_topk_normal = {
1142
+ "cache_seqlens": torch.zeros(
1143
+ max_bs, dtype=torch.int32, device=self.device
1144
+ ),
1145
+ "cu_seqlens_q": torch.arange(
1146
+ 0,
1147
+ max_bs * self.topk + 1,
1148
+ step=self.topk,
1149
+ dtype=torch.int32,
1150
+ device=self.device,
1151
+ ),
1152
+ "cu_seqlens_k": torch.zeros(
1153
+ max_bs + 1, dtype=torch.int32, device=self.device
1154
+ ),
1155
+ "page_table": torch.zeros(
1156
+ max_bs,
1157
+ self.max_context_len,
1158
+ dtype=torch.int32,
1159
+ device=self.device,
1160
+ ),
1161
+ }
1162
+
1163
+ # This is used by draft decode's second half of metadata when topk > 1
1164
+ decode_length = self.speculative_step_id + 1
1165
+ self.draft_decode_metadata_topk_expand = {
1166
+ "cache_seqlens": torch.full(
1167
+ (max_bs * self.topk,),
1168
+ decode_length,
1169
+ device=self.device,
1170
+ dtype=torch.int32,
1171
+ ),
1172
+ "cu_seqlens_q": torch.arange(
1173
+ 0,
1174
+ max_bs * self.topk + 1,
1175
+ dtype=torch.int32,
1176
+ device=self.device,
1177
+ ),
1178
+ "cu_seqlens_k": torch.arange(
1179
+ 0,
1180
+ max_bs * self.topk * decode_length + 1,
1181
+ step=decode_length,
1182
+ dtype=torch.int32,
1183
+ device=self.device,
1184
+ ),
1185
+ "page_table": torch.zeros(
1186
+ max_bs * self.topk,
1187
+ decode_length,
1188
+ dtype=torch.int32,
1189
+ device=self.device,
1190
+ ),
1191
+ }
1192
+
1193
+ if (
1194
+ self.speculative_num_draft_tokens is not None
1195
+ and self.speculative_num_draft_tokens > 0
1196
+ ):
1197
+ self.target_verify_metadata = {
1198
+ "cache_seqlens": torch.zeros(
1199
+ max_bs, dtype=torch.int32, device=self.device
1200
+ ),
1201
+ "cu_seqlens_q": torch.arange(
1202
+ 0,
1203
+ max_bs * self.speculative_num_draft_tokens + 1,
1204
+ step=self.speculative_num_draft_tokens,
1205
+ dtype=torch.int32,
1206
+ device=self.device,
1207
+ ),
1208
+ "cu_seqlens_k": torch.zeros(
1209
+ max_bs + 1, dtype=torch.int32, device=self.device
1210
+ ),
1211
+ "page_table": torch.zeros(
1212
+ max_bs,
1213
+ (self.max_context_len + self.page_size - 1) // self.page_size,
1214
+ dtype=torch.int32,
1215
+ device=self.device,
1216
+ ),
1217
+ "strided_indices": torch.arange(
1218
+ 0, self.max_context_len, self.page_size, device=self.device
1219
+ ),
1220
+ }
1221
+
1222
+ if self.topk > 1:
1223
+ self.target_verify_metadata_topk_normal = {
1224
+ "cache_seqlens": torch.zeros(
1225
+ max_bs, dtype=torch.int32, device=self.device
1226
+ ),
1227
+ "cu_seqlens_q": torch.arange(
1228
+ 0,
1229
+ max_bs * self.speculative_num_draft_tokens + 1,
1230
+ step=self.speculative_num_draft_tokens,
1231
+ dtype=torch.int32,
1232
+ device=self.device,
1233
+ ),
1234
+ "cu_seqlens_k": torch.zeros(
1235
+ max_bs + 1, dtype=torch.int32, device=self.device
1236
+ ),
1237
+ "page_table": torch.zeros(
1238
+ max_bs,
1239
+ self.max_context_len,
1240
+ dtype=torch.int32,
1241
+ device=self.device,
1242
+ ),
1243
+ }
1244
+
1245
+ self.target_verify_metadata_topk_expand = {
1246
+ "cache_seqlens": torch.zeros(
1247
+ max_bs * self.speculative_num_draft_tokens,
1248
+ dtype=torch.int32,
1249
+ device=self.device,
1250
+ ),
1251
+ "cu_seqlens_k": torch.zeros(
1252
+ max_bs * self.speculative_num_draft_tokens + 1,
1253
+ dtype=torch.int32,
1254
+ device=self.device,
1255
+ ),
1256
+ "cu_seqlens_q": torch.arange(
1257
+ 0,
1258
+ max_bs * self.speculative_num_draft_tokens + 1,
1259
+ dtype=torch.int32,
1260
+ device=self.device,
1261
+ ),
1262
+ "page_table": torch.zeros(
1263
+ max_bs * self.speculative_num_draft_tokens,
1264
+ self.speculative_num_draft_tokens,
1265
+ dtype=torch.int32,
1266
+ device=self.device,
1267
+ ),
1268
+ }
859
1269
 
860
1270
  self.encoder_metadata = {
861
1271
  "encoder_page_table": torch.zeros(
@@ -884,28 +1294,78 @@ class FlashAttentionBackend(AttentionBackend):
884
1294
  ):
885
1295
  """Initialize forward metadata for capturing CUDA graph."""
886
1296
  metadata = FlashAttentionMetadata()
1297
+
1298
+ # metadata_expand is needed for Spec Decoding when top k > 1
1299
+ metadata_expand = FlashAttentionMetadata()
1300
+
887
1301
  device = seq_lens.device
888
1302
  if forward_mode.is_decode_or_idle():
889
1303
  if spec_info is not None:
890
1304
  # Draft Decode
891
- metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
892
- "cache_seqlens"
893
- ][:bs]
894
- metadata.max_seq_len_k = seq_lens.max().item() + (
895
- self.speculative_step_id + 1
896
- )
897
- metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
898
- : bs + 1
899
- ]
900
- metadata.cu_seqlens_k = torch.nn.functional.pad(
901
- torch.cumsum(
902
- metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
903
- ),
904
- (1, 0),
905
- )
906
- metadata.page_table = self.decode_cuda_graph_metadata[
907
- "page_table_draft_decode"
908
- ][req_pool_indices, :]
1305
+ if self.topk <= 1:
1306
+ # When topk = 1, we use the normal decode metadata
1307
+ metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
1308
+ "cache_seqlens"
1309
+ ][:bs]
1310
+ metadata.max_seq_len_k = seq_lens.max().item() + (
1311
+ self.speculative_step_id + 1
1312
+ )
1313
+ metadata.cu_seqlens_q = self.decode_cuda_graph_metadata[
1314
+ "cu_seqlens_q"
1315
+ ][: bs + 1]
1316
+ metadata.cu_seqlens_k = torch.nn.functional.pad(
1317
+ torch.cumsum(
1318
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
1319
+ ),
1320
+ (1, 0),
1321
+ )
1322
+ metadata.page_table = self.decode_cuda_graph_metadata[
1323
+ "page_table_draft_decode"
1324
+ ][req_pool_indices, :]
1325
+ self.decode_cuda_graph_metadata[bs] = metadata
1326
+ else:
1327
+ # When top k > 1, we need two specific draft decode metadata, and then merge states
1328
+ # 1. The first half of metadata for prefix tokens
1329
+ metadata.cache_seqlens_int32 = (
1330
+ self.draft_decode_metadata_topk_normal["cache_seqlens"][:bs]
1331
+ )
1332
+ metadata.max_seq_len_q = self.topk
1333
+ metadata.max_seq_len_k = seq_lens.max().item()
1334
+ metadata.cu_seqlens_q = self.draft_decode_metadata_topk_normal[
1335
+ "cu_seqlens_q"
1336
+ ][: bs + 1]
1337
+ metadata.cu_seqlens_k = self.draft_decode_metadata_topk_normal[
1338
+ "cu_seqlens_k"
1339
+ ][: bs + 1]
1340
+ metadata.page_table = self.draft_decode_metadata_topk_normal[
1341
+ "page_table"
1342
+ ][req_pool_indices, :]
1343
+
1344
+ # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
1345
+ metadata_expand.cache_seqlens_int32 = (
1346
+ self.draft_decode_metadata_topk_expand["cache_seqlens"][
1347
+ : bs * self.topk
1348
+ ]
1349
+ )
1350
+ metadata_expand.max_seq_len_q = 1
1351
+ metadata_expand.max_seq_len_k = (
1352
+ self.speculative_step_id + 1
1353
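+ ) # TODO(review): original note said "do this in replay", but the replay branch for this metadata never reassigns max_seq_len_k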
1354
+ metadata_expand.cu_seqlens_q = (
1355
+ self.draft_decode_metadata_topk_expand["cu_seqlens_q"][
1356
+ : bs * self.topk + 1
1357
+ ]
1358
+ )
1359
+ metadata_expand.cu_seqlens_k = (
1360
+ self.draft_decode_metadata_topk_expand["cu_seqlens_k"][
1361
+ : bs * self.topk + 1
1362
+ ]
1363
+ )
1364
+ metadata_expand.page_table = self.draft_decode_metadata_topk_expand[
1365
+ "page_table"
1366
+ ][: bs * self.topk]
1367
+ self.draft_decode_metadata_topk_normal[bs] = metadata
1368
+ self.draft_decode_metadata_topk_expand[bs] = metadata_expand
909
1369
  else:
910
1370
  # Normal Decode
911
1371
  # Get sequence information
@@ -925,37 +1385,77 @@ class FlashAttentionBackend(AttentionBackend):
925
1385
  metadata.cu_seqlens_q = torch.arange(
926
1386
  0, batch_size + 1, dtype=torch.int32, device=device
927
1387
  )
928
- self.decode_cuda_graph_metadata[bs] = metadata
1388
+ self.decode_cuda_graph_metadata[bs] = metadata
1389
+
929
1390
  elif forward_mode.is_target_verify():
930
- metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
931
- :bs
932
- ]
933
- metadata.cache_seqlens_int32.copy_(
934
- (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
935
- )
1391
+ if self.topk <= 1:
1392
+ metadata.cache_seqlens_int32 = self.target_verify_metadata[
1393
+ "cache_seqlens"
1394
+ ][:bs]
1395
+ metadata.cache_seqlens_int32.copy_(
1396
+ (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
1397
+ )
936
1398
 
937
- metadata.max_seq_len_q = self.speculative_num_draft_tokens
938
- metadata.max_seq_len_k = (
939
- seq_lens.max().item() + self.speculative_num_draft_tokens
940
- )
1399
+ metadata.max_seq_len_q = self.speculative_num_draft_tokens
1400
+ metadata.max_seq_len_k = (
1401
+ seq_lens.max().item() + self.speculative_num_draft_tokens
1402
+ )
941
1403
 
942
- metadata.cu_seqlens_q = torch.arange(
943
- 0,
944
- bs * self.speculative_num_draft_tokens + 1,
945
- self.speculative_num_draft_tokens,
946
- dtype=torch.int32,
947
- device=device,
948
- )
1404
+ metadata.cu_seqlens_q = torch.arange(
1405
+ 0,
1406
+ bs * self.speculative_num_draft_tokens + 1,
1407
+ self.speculative_num_draft_tokens,
1408
+ dtype=torch.int32,
1409
+ device=device,
1410
+ )
949
1411
 
950
- metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
951
- : (bs + 1)
952
- ]
1412
+ metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
1413
+ : (bs + 1)
1414
+ ]
953
1415
 
954
- metadata.page_table = self.target_verify_metadata["page_table"][
955
- req_pool_indices, :
956
- ]
1416
+ metadata.page_table = self.target_verify_metadata["page_table"][
1417
+ req_pool_indices, :
1418
+ ]
1419
+
1420
+ self.target_verify_metadata[bs] = metadata
1421
+ else:
1422
+ # When topk > 1, we need two specific target verify metadata, and then merge states
1423
+ # 1. The first half of metadata for prefix tokens
1424
+ metadata.cache_seqlens_int32 = self.target_verify_metadata_topk_normal[
1425
+ "cache_seqlens"
1426
+ ][:bs]
1427
+ metadata.max_seq_len_q = self.speculative_num_draft_tokens
1428
+ # metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item(), do this in replay
1429
+ metadata.cu_seqlens_q = self.target_verify_metadata_topk_normal[
1430
+ "cu_seqlens_q"
1431
+ ][: bs + 1]
1432
+ metadata.cu_seqlens_k = self.target_verify_metadata_topk_normal[
1433
+ "cu_seqlens_k"
1434
+ ][: bs + 1]
1435
+ metadata.page_table = self.target_verify_metadata_topk_normal[
1436
+ "page_table"
1437
+ ][req_pool_indices, :]
1438
+
1439
+ # 2. The second half of metadata for draft tokens (per_batch_num_tokens = speculative_num_draft_tokens)
1440
+ metadata_expand.cache_seqlens_int32 = (
1441
+ self.target_verify_metadata_topk_expand["cache_seqlens"][
1442
+ : bs * self.speculative_num_draft_tokens
1443
+ ]
1444
+ )
1445
+ metadata_expand.max_seq_len_q = 1
1446
+ metadata_expand.cu_seqlens_q = self.target_verify_metadata_topk_expand[
1447
+ "cu_seqlens_q"
1448
+ ][: bs * self.speculative_num_draft_tokens + 1]
1449
+ metadata_expand.cu_seqlens_k = self.target_verify_metadata_topk_expand[
1450
+ "cu_seqlens_k"
1451
+ ][: bs * self.speculative_num_draft_tokens + 1]
1452
+
1453
+ metadata_expand.page_table = self.target_verify_metadata_topk_expand[
1454
+ "page_table"
1455
+ ][: bs * self.speculative_num_draft_tokens]
957
1456
 
958
- self.target_verify_metadata[bs] = metadata
1457
+ self.target_verify_metadata_topk_normal[bs] = metadata
1458
+ self.target_verify_metadata_topk_expand[bs] = metadata_expand
959
1459
 
960
1460
  if encoder_lens is not None:
961
1461
  encoder_bs = encoder_lens.numel()
@@ -971,6 +1471,7 @@ class FlashAttentionBackend(AttentionBackend):
971
1471
  ]
972
1472
 
973
1473
  self.forward_metadata = metadata
1474
+ self.forward_metadata_spec_decode_expand = metadata_expand
974
1475
 
975
1476
  def init_forward_metadata_replay_cuda_graph(
976
1477
  self,
@@ -984,37 +1485,85 @@ class FlashAttentionBackend(AttentionBackend):
984
1485
  seq_lens_cpu: Optional[torch.Tensor],
985
1486
  out_cache_loc: torch.Tensor = None,
986
1487
  ):
987
- # """Initialize forward metadata for replaying CUDA graph."""
1488
+ """Initialize forward metadata for replaying CUDA graph."""
988
1489
  seq_lens = seq_lens[:bs]
989
1490
  seq_lens_cpu = seq_lens_cpu[:bs]
990
1491
  req_pool_indices = req_pool_indices[:bs]
1492
+ device = seq_lens.device
1493
+ metadata = None
1494
+ metadata_expand = None
1495
+
991
1496
  if forward_mode.is_decode_or_idle():
992
- metadata = self.decode_cuda_graph_metadata[bs]
993
1497
 
994
1498
  if spec_info is not None:
995
1499
  # Draft Decode
996
- metadata.cache_seqlens_int32.copy_(
997
- (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
998
- )
1500
+ if self.topk <= 1:
1501
+ metadata = self.decode_cuda_graph_metadata[bs]
1502
+ # When topk = 1, we use the normal decode metadata
1503
+ metadata.cache_seqlens_int32.copy_(
1504
+ (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
1505
+ )
999
1506
 
1000
- metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
1001
- self.speculative_step_id + 1
1002
- )
1003
- metadata.cu_seqlens_k.copy_(
1004
- torch.nn.functional.pad(
1005
- torch.cumsum(
1006
- metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
1007
- ),
1008
- (1, 0),
1507
+ metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
1508
+ self.speculative_step_id + 1
1509
+ )
1510
+ metadata.cu_seqlens_k.copy_(
1511
+ torch.nn.functional.pad(
1512
+ torch.cumsum(
1513
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
1514
+ ),
1515
+ (1, 0),
1516
+ )
1009
1517
  )
1010
- )
1011
1518
 
1012
- page_table = self.req_to_token[
1013
- req_pool_indices, : metadata.max_seq_len_k
1014
- ]
1519
+ max_seq_pages = (
1520
+ metadata.max_seq_len_k + self.page_size - 1
1521
+ ) // self.page_size
1522
+ page_indices = self.req_to_token[
1523
+ req_pool_indices[:, None],
1524
+ self.decode_cuda_graph_metadata["strided_indices"][
1525
+ :max_seq_pages
1526
+ ],
1527
+ ]
1528
+
1529
+ page_indices //= self.page_size
1530
+ metadata.page_table[:, :max_seq_pages].copy_(page_indices)
1531
+ else:
1532
+ # When top k > 1, we need two specific draft decode metadata, and then merge states
1533
+ # 1. The first half of metadata for prefix tokens
1534
+ metadata = self.draft_decode_metadata_topk_normal[bs]
1535
+ metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
1536
+ # metadata.max_seq_len_q = self.topk, already set in capture
1537
+ metadata.max_seq_len_k = seq_lens_cpu.max().item()
1538
+ # metadata.cu_seqlens_q already set in capture
1539
+ metadata.cu_seqlens_k.copy_(
1540
+ torch.nn.functional.pad(
1541
+ torch.cumsum(
1542
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
1543
+ ),
1544
+ (1, 0),
1545
+ )
1546
+ )
1015
1547
 
1016
- metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
1548
+ page_table = self.req_to_token[
1549
+ req_pool_indices, : metadata.max_seq_len_k
1550
+ ]
1551
+
1552
+ metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
1553
+
1554
+ # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
1555
+ metadata_expand = self.draft_decode_metadata_topk_expand[bs]
1556
+ decode_length = self.speculative_step_id + 1
1557
+ cache_loc = out_cache_loc.view(
1558
+ self.speculative_num_steps, -1
1559
+ ).T.contiguous()
1560
+ metadata_expand.page_table[: cache_loc.shape[0]].copy_(
1561
+ cache_loc[:, :decode_length].contiguous().to(torch.int32)
1562
+ )
1563
+ # TODO: we need to test this part for llama 4 eagle case
1564
+ self._init_local_attn_metadata(metadata, device)
1017
1565
  else:
1566
+ metadata = self.decode_cuda_graph_metadata[bs]
1018
1567
  # Normal Decode
1019
1568
  max_len = seq_lens_cpu.max().item()
1020
1569
  metadata.max_seq_len_k = max_len
@@ -1037,25 +1586,119 @@ class FlashAttentionBackend(AttentionBackend):
1037
1586
  metadata.page_table[:, :max_seq_pages].copy_(page_indices)
1038
1587
  metadata.page_table[:, max_seq_pages:].fill_(0)
1039
1588
 
1589
+ self._init_local_attn_metadata(metadata, device)
1040
1590
  elif forward_mode.is_target_verify():
1041
- metadata = self.target_verify_metadata[bs]
1042
- metadata.cache_seqlens_int32.copy_(
1043
- (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
1044
- )
1591
+ if self.topk <= 1:
1592
+ metadata = self.target_verify_metadata[bs]
1593
+ metadata.cache_seqlens_int32.copy_(
1594
+ (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
1595
+ )
1045
1596
 
1046
- metadata.max_seq_len_k = (
1047
- seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
1048
- )
1049
- metadata.cu_seqlens_k.copy_(
1050
- torch.nn.functional.pad(
1597
+ metadata.max_seq_len_k = (
1598
+ seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
1599
+ )
1600
+ metadata.cu_seqlens_k.copy_(
1601
+ torch.nn.functional.pad(
1602
+ torch.cumsum(
1603
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
1604
+ ),
1605
+ (1, 0),
1606
+ )
1607
+ )
1608
+ max_seq_pages = (
1609
+ metadata.max_seq_len_k + self.page_size - 1
1610
+ ) // self.page_size
1611
+ page_indices = self.req_to_token[
1612
+ req_pool_indices[:, None],
1613
+ self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
1614
+ ]
1615
+ page_indices //= self.page_size
1616
+ metadata.page_table[:, :max_seq_pages].copy_(page_indices)
1617
+ else:
1618
+ # When topk > 1, we need two specific target verify metadata, and then merge states
1619
+ # 1. The first half of metadata for prefix tokens
1620
+ metadata = self.target_verify_metadata_topk_normal[bs]
1621
+ metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
1622
+ # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
1623
+ metadata.max_seq_len_k = seq_lens_cpu.max().item()
1624
+ # metadata.cu_seqlens_q already set in capture
1625
+ metadata.cu_seqlens_k.copy_(
1626
+ torch.nn.functional.pad(
1627
+ torch.cumsum(
1628
+ metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
1629
+ ),
1630
+ (1, 0),
1631
+ )
1632
+ )
1633
+ page_table = self.req_to_token[
1634
+ req_pool_indices, : metadata.max_seq_len_k
1635
+ ]
1636
+ metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
1637
+
1638
+ # 2. The second half of metadata for draft tokens (per_batch_num_tokens = speculative_num_draft_tokens)
1639
+ metadata_expand = self.target_verify_metadata_topk_expand[bs]
1640
+ # metadata_expand.max_seq_len_q = 1, already set in capture
1641
+ # metadata_expand.cu_seqlens_q already set in capture
1642
+
1643
+ offsets = torch.arange(
1644
+ self.speculative_num_draft_tokens, device=device
1645
+ ).unsqueeze(
1646
+ 0
1647
+ ) # shape: (1, self.speculative_num_draft_tokens)
1648
+ cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
1649
+ cum_len = torch.nn.functional.pad(
1051
1650
  torch.cumsum(
1052
- metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
1651
+ (
1652
+ seq_lens + self.speculative_num_draft_tokens
1653
+ ).repeat_interleave(self.speculative_num_draft_tokens),
1654
+ dim=0,
1053
1655
  ),
1054
1656
  (1, 0),
1657
+ )[:-1]
1658
+ mask_extraction_indices = (
1659
+ cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
1660
+ + cum_len[:, None]
1661
+ ).view(1, -1)
1662
+ # avoid extracting padded seq indices which will be out of boundary
1663
+ mask_extraction_indices[
1664
+ :, spec_info.positions.numel() * self.speculative_num_draft_tokens :
1665
+ ].fill_(0)
1666
+
1667
+ mask = spec_info.custom_mask[mask_extraction_indices].view(
1668
+ -1, self.speculative_num_draft_tokens
1669
+ ) # (bsz * draft_num, draft_num)
1670
+ col_indices = offsets.expand(
1671
+ mask.shape[0], self.speculative_num_draft_tokens
1672
+ )
1673
+ keys = torch.where(
1674
+ mask, col_indices, col_indices + self.speculative_num_draft_tokens
1675
+ )
1676
+ _, sort_order = torch.sort(keys, dim=1)
1677
+
1678
+ non_masked_page_table = (
1679
+ self.req_to_token[req_pool_indices, :]
1680
+ .gather(1, cols)
1681
+ .repeat_interleave(self.speculative_num_draft_tokens, dim=0)
1682
+ ) # (bsz * draft_num, draft_num)
1683
+ metadata_expand.page_table.copy_(
1684
+ non_masked_page_table.gather(1, sort_order)
1685
+ )
1686
+ metadata_expand.cache_seqlens_int32.copy_(
1687
+ mask.sum(dim=1).to(torch.int32)
1688
+ )
1689
+ metadata_expand.cu_seqlens_k.copy_(
1690
+ torch.nn.functional.pad(
1691
+ torch.cumsum(
1692
+ metadata_expand.cache_seqlens_int32,
1693
+ dim=0,
1694
+ dtype=torch.int32,
1695
+ ),
1696
+ (1, 0),
1697
+ )
1698
+ )
1699
+ metadata_expand.max_seq_len_k = (
1700
+ metadata_expand.cache_seqlens_int32.max().item()
1055
1701
  )
1056
- )
1057
- page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
1058
- metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
1059
1702
 
1060
1703
  if encoder_lens is not None:
1061
1704
  # Only support encoder size 1 for now
@@ -1082,11 +1725,48 @@ class FlashAttentionBackend(AttentionBackend):
1082
1725
  metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
1083
1726
 
1084
1727
  self.forward_metadata = metadata
1728
+ self.forward_metadata_spec_decode_expand = metadata_expand
1085
1729
 
1086
1730
  def get_cuda_graph_seq_len_fill_value(self):
1087
1731
  """Return the value used to fill sequence lengths in CUDA graph; this backend always returns 0."""
1088
1732
  return 0
1089
1733
 
1734
+ def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
1735
+ """Populate ``metadata.local_attn_metadata`` in place when chunked attention is enabled.
+
+ Sets ``metadata.local_attn_metadata`` to None and returns early when
+ ``self.attention_chunk_size`` is None, or when any of
+ ``metadata.cu_seqlens_q``, ``metadata.cache_seqlens_int32`` or
+ ``metadata.page_table`` is None. Otherwise it calls
+ ``make_local_attention_virtual_batches`` and stores the result as a
+ ``FlashAttentionMetadata.LocalAttentionMetadata`` whose tensors are moved to ``device``.
+
+ Args:
+     metadata: The attention metadata to read from and update.
+     device: Target device for the tensors of the local attention metadata.
+
+ Returns:
+     None. The result is written to ``metadata.local_attn_metadata``.
+ """
1736
+ if self.attention_chunk_size is None:
1737
+ metadata.local_attn_metadata = None
1738
+ return  # no chunk size configured, nothing to build
1739
+
1740
+ cu_seqlens_q = metadata.cu_seqlens_q
1741
+ cache_seqlens_int32 = metadata.cache_seqlens_int32
1742
+ page_table = metadata.page_table
1743
+ if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:  # any missing input disables local metadata
1744
+ metadata.local_attn_metadata = None
1745
+ return
1746
+
1747
+ cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()  # NOTE(review): .cpu() copies to host; presumably blocks on the device when this lives on GPU - confirm acceptable on the replay path
1748
+ seq_lens_np = cache_seqlens_int32.cpu().numpy()  # the helper below is fed NumPy arrays for the sequence lengths
1749
+ (
1750
+ seqlens_q_local_np,
1751
+ cu_seqlens_q_local_np,
1752
+ seqlens_k_local_np,
1753
+ block_table_local,
1754
+ ) = make_local_attention_virtual_batches(
1755
+ self.attention_chunk_size,
1756
+ cu_seqlens_q_np,
1757
+ seq_lens_np,
1758
+ page_table,
1759
+ self.page_size,
1760
+ )
1761
+ local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
1762
+ local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
1763
+ local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
1764
+ local_block_table=block_table_local.to(device),
1765
+ local_max_query_len=int(seqlens_q_local_np.max()),
1766
+ local_max_seq_len=int(seqlens_k_local_np.max()),
1767
+ )
1768
+ metadata.local_attn_metadata = local_metadata
1769
+
1090
1770
 
1091
1771
  class FlashAttentionMultiStepBackend:
1092
1772
 
@@ -1096,12 +1776,6 @@ class FlashAttentionMultiStepBackend:
1096
1776
  self.model_runner = model_runner
1097
1777
  self.topk = topk
1098
1778
  self.speculative_num_steps = speculative_num_steps
1099
-
1100
- # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
1101
- assert (
1102
- self.topk == 1
1103
- ), "speculative_eagle_topk must be 1 for FlashAttentionMultiStepBackend"
1104
-
1105
1779
  self.attn_backends = []
1106
1780
  for i in range(self.speculative_num_steps):
1107
1781
  self.attn_backends.append(